"""Continuous distribution fitting engine for Spark."""
import logging
from functools import reduce
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
# PySpark is optional - only import if available
try:
from pyspark.sql import DataFrame, SparkSession
_PYSPARK_AVAILABLE = True
except ImportError:
DataFrame = None # type: ignore[assignment,misc]
SparkSession = None # type: ignore[assignment,misc]
_PYSPARK_AVAILABLE = False
from spark_bestfit.base_fitter import BaseFitter
from spark_bestfit.config import FitterConfig
from spark_bestfit.distributions import DistributionRegistry
from spark_bestfit.fitting import (
FIT_RESULT_SCHEMA,
FITTING_SAMPLE_SIZE,
compute_data_stats,
detect_heavy_tail,
fit_single_distribution,
)
from spark_bestfit.histogram import HistogramComputer
from spark_bestfit.results import DistributionFitResult, FitResultsType, LazyMetricsContext, create_fit_results
if TYPE_CHECKING:
from scipy.stats import rv_continuous
from spark_bestfit.protocols import ExecutionBackend
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Re-export for convenience: a tuple copy of the registry's default exclusions.
# DistributionFitter uses this value as its ``_default_exclusions`` class attribute.
DEFAULT_EXCLUDED_DISTRIBUTIONS: Tuple[str, ...] = tuple(DistributionRegistry.DEFAULT_EXCLUSIONS)
[docs]
class DistributionFitter(BaseFitter):
"""Modern Spark distribution fitting engine.
Efficiently fits ~90 scipy.stats distributions to data using Spark's
parallel processing capabilities. Uses broadcast variables and Pandas UDFs
to avoid data collection and minimize serialization overhead.
Example:
>>> from pyspark.sql import SparkSession
>>> from spark_bestfit import DistributionFitter
>>>
>>> # Create your own SparkSession
>>> spark = SparkSession.builder.appName("my-app").getOrCreate()
>>> df = spark.createDataFrame([(float(x),) for x in data], ['value'])
>>>
>>> # Simple usage
>>> fitter = DistributionFitter(spark)
>>> results = fitter.fit(df, column='value')
>>> best = results.best(n=1)[0]
>>> print(f"Best: {best.distribution} with SSE={best.sse}")
>>>
>>> # With custom parameters
>>> fitter = DistributionFitter(spark, random_seed=123)
>>> results = fitter.fit(df, 'value', bins=100, support_at_zero=True)
>>>
>>> # Plot the best fit
>>> fitter.plot(best, df, 'value', title='Best Fit')
"""
# Class attributes for BaseFitter
_registry_class = DistributionRegistry
_default_exclusions = DEFAULT_EXCLUDED_DISTRIBUTIONS
def __init__(
    self,
    spark: Optional[SparkSession] = None,
    excluded_distributions: Optional[Tuple[str, ...]] = None,
    random_seed: int = 42,
    backend: Optional["ExecutionBackend"] = None,
):
    """Build a continuous-distribution fitter.

    All arguments are forwarded unchanged to ``BaseFitter.__init__``; this
    subclass only adds a ``HistogramComputer`` bound to the resolved backend.

    Args:
        spark: SparkSession to use. When None the active session is used.
            Has no effect when ``backend`` is supplied.
        excluded_distributions: Names of distributions that must not be
            fitted. None selects DEFAULT_EXCLUDED_DISTRIBUTIONS (slow
            distributions); an empty tuple ``()`` fits ALL scipy
            distributions.
        random_seed: Seed used for reproducible sampling.
        backend: Optional execution backend (v2.0). When None a SparkBackend
            is created from the Spark session; alternative backends such as
            LocalBackend can be plugged in for testing.

    Raises:
        RuntimeError: If no SparkSession is given and none is active.
    """
    base_options = {
        "spark": spark,
        "excluded_distributions": excluded_distributions,
        "random_seed": random_seed,
        "backend": backend,
    }
    super().__init__(**base_options)
    # The histogram helper shares the backend that the base class resolved.
    resolved_backend = self._backend
    self._histogram_computer = HistogramComputer(backend=resolved_backend)
[docs]
def register_distribution(
    self,
    name: str,
    distribution: "rv_continuous",
    overwrite: bool = False,
) -> "DistributionFitter":
    """Add a user-defined distribution to this fitter's registry.

    The call is delegated to the fitter's registry. A custom distribution
    must follow the scipy ``rv_continuous`` interface (``fit()``, ``pdf()``
    and ``cdf()``); once registered it is fitted alongside the
    scipy.stats distributions.

    Args:
        name: Unique name for the distribution (this is what shows up in
            results).
        distribution: scipy ``rv_continuous`` instance or subclass
            implementing ``fit()``, ``pdf()`` and ``cdf()``.
        overwrite: When True an existing entry with the same name is
            replaced. The default (False) raises ValueError instead.

    Returns:
        The fitter itself, so calls can be chained.

    Raises:
        ValueError: If the name is already taken (and ``overwrite`` is
            False) or clashes with a scipy.stats distribution name.
        TypeError: If the distribution lacks the required interface.

    Example:
        >>> from scipy.stats import rv_continuous
        >>>
        >>> class PowerDistribution(rv_continuous):
        ...     def _pdf(self, x, alpha):
        ...         return alpha * x ** (alpha - 1)
        ...     def _cdf(self, x, alpha):
        ...         return x ** alpha
        >>>
        >>> fitter = DistributionFitter(spark)
        >>> fitter.register_distribution("power", PowerDistribution(a=0, b=1))
        >>> results = fitter.fit(df, "column")
        >>> # Results will include "power" if it fits well
    """
    registry = self._registry
    registry.register_distribution(name, distribution, overwrite=overwrite)
    return self
[docs]
def unregister_distribution(self, name: str) -> "DistributionFitter":
    """Drop a previously registered custom distribution.

    The removal is delegated to the fitter's registry.

    Args:
        name: Name of the custom distribution to remove.

    Returns:
        The fitter itself, so calls can be chained.

    Raises:
        KeyError: If no distribution with that name is registered.
    """
    registry = self._registry
    registry.unregister_distribution(name)
    return self
[docs]
def get_custom_distributions(self) -> dict:
    """Return the custom distributions known to this fitter's registry.

    Returns:
        Dict mapping distribution names to rv_continuous objects, exactly
        as reported by the registry.
    """
    custom_distributions = self._registry.get_custom_distributions()
    return custom_distributions
[docs]
def fit(
    self,
    df: DataFrame,
    column: Optional[str] = None,
    columns: Optional[List[str]] = None,
    config: Optional[FitterConfig] = None,
    *,
    bins: Union[int, Tuple[float, ...]] = 50,
    use_rice_rule: bool = True,
    support_at_zero: bool = False,
    max_distributions: Optional[int] = None,
    enable_sampling: bool = True,
    sample_fraction: Optional[float] = None,
    max_sample_size: int = 1_000_000,
    sample_threshold: int = 10_000_000,
    num_partitions: Optional[int] = None,
    progress_callback: Optional[Callable[[int, int, float], None]] = None,
    bounded: bool = False,
    lower_bound: Optional[Union[float, Dict[str, float]]] = None,
    upper_bound: Optional[Union[float, Dict[str, float]]] = None,
    lazy_metrics: bool = False,
    prefilter: Union[bool, str] = False,
    estimation_method: str = "mle",
) -> FitResultsType:
    """Fit distributions to data column(s).

    The row count, the (optional) down-sampling and the list of candidate
    distributions are computed once and shared by all columns; each column
    is then fitted in turn and the per-column results are combined.

    Args:
        df: Spark DataFrame containing data
        column: Name of single column to fit distributions to
        columns: List of column names for multi-column fitting
        config: FitterConfig object (v2.2.0). Provides a cleaner way to
            configure fitting with many parameters. If provided, individual
            parameters below are ignored (except progress_callback which
            can override the config's callback). Use FitterConfigBuilder
            for fluent configuration. Settings that have no keyword here
            (censoring column, adaptive sampling options) are always read
            from the config object; without ``config`` the FitterConfig
            defaults apply to them.
        bins: Number of histogram bins or tuple of bin edges
        use_rice_rule: Use Rice rule to auto-determine bin count
        support_at_zero: Only fit non-negative distributions
        max_distributions: Limit number of distributions (for testing)
        enable_sampling: Enable sampling for large datasets
        sample_fraction: Fraction to sample (None = auto-determine)
        max_sample_size: Maximum rows to sample when auto-determining
        sample_threshold: Row count above which sampling is applied
        num_partitions: Spark partitions (None = auto-determine)
        progress_callback: Optional callback for progress updates.
            Called with (completed_tasks, total_tasks, percent_complete).
            Callback is invoked from background thread - ensure thread-safety.
        bounded: If True, fit truncated distributions (v1.4.0).
            When enabled, distributions are truncated to [lower_bound, upper_bound]
            using scipy.stats.truncate(). Requires scipy >= 1.14.0.
        lower_bound: Lower bound for truncated distribution fitting.
            Can be a float (applied to all columns) or a dict mapping
            column names to bounds (v1.5.0). If None and bounded=True,
            auto-detects from each column's minimum.
        upper_bound: Upper bound for truncated distribution fitting.
            Can be a float (applied to all columns) or a dict mapping
            column names to bounds (v1.5.0). If None and bounded=True,
            auto-detects from each column's maximum.
        lazy_metrics: If True, defer computation of expensive KS/AD metrics
            until accessed (v1.5.0). Improves fitting performance when only
            using AIC/BIC/SSE for model selection. Default False for
            backward compatibility.
        prefilter: Pre-filter distributions based on data characteristics (v1.6.0).
            Skips distributions whose shape is incompatible with the data,
            reducing fitting time for non-normal data.
            - False (default): No pre-filtering, fit all distributions
            - True: Safe mode - filters by skewness sign
            - 'aggressive': Also filters by kurtosis (may skip valid distributions)
            Pre-filtering relies on sample shape moments (skewness, kurtosis).
            Support bounds (dist.a, dist.b) are intentionally NOT used, because
            loc/scale fitting can shift any distribution over any data range.
            Filtered distributions are logged for transparency. If the filter
            removes every candidate, all distributions are fitted instead.
        estimation_method: Parameter estimation method (v2.5.0):
            - "mle": Maximum Likelihood Estimation (default). Fast and accurate
              for most distributions.
            - "mse": Maximum Spacing Estimation. More robust for heavy-tailed
              distributions (Pareto, Cauchy, etc.) where MLE may fail.
            - "auto": Automatically select MSE for heavy-tailed data based on
              kurtosis and extreme value analysis, MLE otherwise.

    Returns:
        FitResults object with fitted distributions

    Raises:
        ValueError: If no column is specified, column not found, DataFrame
            empty, or invalid params
        TypeError: If column is not numeric

    Example:
        >>> # Using FitterConfig (recommended for complex configs, v2.2.0)
        >>> from spark_bestfit import FitterConfigBuilder
        >>> config = (FitterConfigBuilder()
        ...     .with_bins(100)
        ...     .with_bounds(lower=0, upper=100)
        ...     .with_sampling(fraction=0.1)
        ...     .build())
        >>> results = fitter.fit(df, column='value', config=config)
        >>>
        >>> # Single column (backward compatible)
        >>> results = fitter.fit(df, column='value')
        >>> results = fitter.fit(df, 'value', bins=100, support_at_zero=True)
        >>>
        >>> # Multi-column
        >>> results = fitter.fit(df, columns=['col1', 'col2', 'col3'])
        >>> best_col1 = results.for_column('col1').best(n=1)[0]
        >>> best_per_col = results.best_per_column(n=1)
        >>>
        >>> # Bounded fitting (v1.4.0)
        >>> results = fitter.fit(df, 'value', bounded=True)  # Auto-detect bounds
        >>> results = fitter.fit(df, 'value', bounded=True, lower_bound=0, upper_bound=100)
        >>>
        >>> # Lazy metrics for faster fitting when only using AIC/BIC (v1.5.0)
        >>> results = fitter.fit(df, 'value', lazy_metrics=True)
        >>> best_aic = results.best(n=1, metric='aic')[0]  # Fast, no KS/AD computed
    """
    # Resolve config: an explicit config takes precedence over the individual
    # keyword parameters; only progress_callback may still override it.
    if config is not None:
        cfg = config
        if progress_callback is not None:
            cfg = config.with_progress_callback(progress_callback)
    else:
        # Backward compatibility: build a config from the keyword parameters.
        cfg = FitterConfig(
            bins=bins,
            use_rice_rule=use_rice_rule,
            support_at_zero=support_at_zero,
            max_distributions=max_distributions,
            prefilter=prefilter,
            enable_sampling=enable_sampling,
            sample_fraction=sample_fraction,
            max_sample_size=max_sample_size,
            sample_threshold=sample_threshold,
            bounded=bounded,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            num_partitions=num_partitions,
            lazy_metrics=lazy_metrics,
            progress_callback=progress_callback,
            estimation_method=estimation_method,
        )

    # Normalize column/columns to a list.
    target_columns = self._normalize_columns(column, columns)
    # Fail early with a clear message: an empty list would otherwise only
    # surface later as an obscure error when combining zero result frames.
    if not target_columns:
        raise ValueError("At least one column must be specified via 'column' or 'columns'")

    # Input validation for all columns.
    for col in target_columns:
        self._validate_inputs(df, col, cfg.max_distributions, cfg.bins, cfg.sample_fraction)

    # Validate bounds - handles both scalar and dict forms.
    self._validate_bounds(cfg.lower_bound, cfg.upper_bound, target_columns)

    # Validate censoring column if specified (v2.9.0).
    if cfg.censoring_column is not None:
        self._validate_censoring_column(df, cfg.censoring_column)

    # Get row count (single operation for all columns).
    row_count = self._get_row_count(df)
    if row_count == 0:
        raise ValueError("DataFrame is empty")
    logger.info("Row count: %s", row_count)

    # Build per-column bounds dict: {col: (lower, upper)}; stays empty when
    # bounded fitting is disabled.
    column_bounds: Dict[str, Tuple[Optional[float], Optional[float]]] = {}
    if cfg.bounded:
        column_bounds = self._resolve_bounds(df, target_columns, cfg.lower_bound, cfg.upper_bound)

    # Sample if needed (single operation for all columns).
    # For adaptive sampling, the first column is the representative for skew analysis.
    df_sample = self._apply_sampling(
        df,
        row_count,
        cfg.enable_sampling,
        cfg.sample_fraction,
        cfg.max_sample_size,
        cfg.sample_threshold,
        column=target_columns[0],
        adaptive_sampling=cfg.adaptive_sampling,
        sampling_mode=cfg.sampling_mode,
        skew_threshold_mild=cfg.skew_threshold_mild,
        skew_threshold_high=cfg.skew_threshold_high,
    )

    # Get distributions to fit (same for all columns).
    distributions = self._registry.get_distributions(
        support_at_zero=cfg.support_at_zero,
        additional_exclusions=list(self.excluded_distributions),
    )
    if cfg.max_distributions is not None and cfg.max_distributions > 0:
        distributions = distributions[: cfg.max_distributions]

    # Fit each column and collect results.
    all_results_dfs = []
    lazy_contexts: Dict[str, LazyMetricsContext] = {}
    cached_samples: Dict[str, np.ndarray] = {}
    for col in target_columns:
        # Per-column bounds; (None, None) when the column has no entry.
        col_lower, col_upper = column_bounds.get(col, (None, None))
        logger.info("Fitting column '%s'...", col)

        # Create the fitting sample once; it is reused for fitting and cached
        # on the results (v2.10.0: instant mode).
        data_sample = self._create_fitting_sample(df_sample, col, row_count)
        cached_samples[col] = data_sample

        results_df = self._fit_single_column(
            df_sample=df_sample,
            column=col,
            row_count=row_count,
            bins=cfg.bins,
            use_rice_rule=cfg.use_rice_rule,
            distributions=distributions,
            num_partitions=cfg.num_partitions,
            lower_bound=col_lower,
            upper_bound=col_upper,
            lazy_metrics=cfg.lazy_metrics,
            prefilter=cfg.prefilter,
            progress_callback=cfg.progress_callback,
            estimation_method=cfg.estimation_method,
            censoring_column=cfg.censoring_column,
            data_sample=data_sample,
        )
        all_results_dfs.append(results_df)

        # Build lazy context for on-demand metric computation.
        if cfg.lazy_metrics:
            lazy_contexts[col] = LazyMetricsContext(
                source_df=df_sample,
                column=col,
                random_seed=self.random_seed,
                row_count=row_count,
                lower_bound=col_lower,
                upper_bound=col_upper,
                is_discrete=False,
                cached_sample=data_sample,
            )

    # Combine per-column results - handles both Spark and pandas DataFrames.
    if self.spark is not None:
        # Spark: union the DataFrames and cache before counting.
        combined_df = reduce(DataFrame.union, all_results_dfs)
        combined_df = combined_df.cache()
        total_results = combined_df.count()
    else:
        # Non-Spark backend: concatenate pandas DataFrames.
        import pandas as pd

        combined_df = pd.concat(all_results_dfs, ignore_index=True)
        total_results = len(combined_df)
    logger.info(
        "Total results: %s (%s columns × ~%s distributions)",
        total_results,
        len(target_columns),
        len(distributions),
    )

    # Pass lazy contexts and cached samples to FitResults for instant plotting.
    return create_fit_results(
        combined_df,
        lazy_contexts=lazy_contexts if cfg.lazy_metrics else None,
        samples=cached_samples,
    )
def _fit_single_column(
    self,
    df_sample: DataFrame,
    column: str,
    row_count: int,
    bins: Union[int, Tuple[float, ...]],
    use_rice_rule: bool,
    distributions: List[str],
    num_partitions: Optional[int],
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
    lazy_metrics: bool = False,
    prefilter: Union[bool, str] = False,
    progress_callback: Optional[Callable[[int, int, float], None]] = None,
    estimation_method: str = "mle",
    censoring_column: Optional[str] = None,
    data_sample: Optional[np.ndarray] = None,
) -> DataFrame:
    """Fit distributions to a single column (internal method).

    Steps: compute the histogram, obtain the fitting sample (and optional
    censoring indicator), optionally pre-filter the candidate list, resolve
    the estimation method, run the parallel fit through the backend and
    wrap the results in a Spark or pandas DataFrame.

    Args:
        df_sample: Sampled DataFrame
        column: Column name
        row_count: Original row count (for histogram computation)
        bins: Number of histogram bins
        use_rice_rule: Use Rice rule for bin count
        distributions: List of distribution names to fit
        num_partitions: Number of Spark partitions
        lower_bound: Lower bound for truncated distribution fitting (v1.4.0)
        upper_bound: Upper bound for truncated distribution fitting (v1.4.0)
        lazy_metrics: If True, skip KS/AD computation for performance (v1.5.0)
        prefilter: Pre-filter mode (False, True, or 'aggressive') (v1.6.0)
        progress_callback: Optional callback for progress updates (v2.0.0)
        estimation_method: Parameter estimation method (v2.5.0):
            - "mle": Maximum Likelihood Estimation (default)
            - "mse": Maximum Spacing Estimation (robust for heavy-tailed data)
            - "auto": Automatically select MSE for heavy-tailed data
        censoring_column: Column name containing censoring indicator (v2.9.0).
            True/1 = observed event, False/0 = right-censored observation.
        data_sample: Optional pre-computed sample data (v2.10.0). When None
            the sample is drawn here.

    Returns:
        DataFrame with fit results for this column: a Spark DataFrame when
        ``self.spark`` is set, otherwise a pandas DataFrame. An empty frame
        with the result columns is returned when there is no valid data or
        no distribution could be fitted.
    """
    # Compute histogram (returns bin edges for CDF-based fitting).
    y_hist, bin_edges = self._histogram_computer.compute_histogram(
        df_sample, column, bins=bins, use_rice_rule=use_rice_rule, approx_count=row_count
    )
    logger.info(f"  Histogram for '{column}': {len(bin_edges) - 1} bins")

    # Create fitting sample if not provided.
    if data_sample is None:
        data_sample = self._create_fitting_sample(df_sample, column, row_count)

    # Extract censoring indicator if specified (v2.9.0).
    censoring_indicator: Optional[np.ndarray] = None
    if censoring_column is not None:
        # NOTE(review): the indicator is drawn by a separate sampling call; this
        # presumably relies on identical fraction/seed to stay row-aligned with
        # data_sample - confirm against the backend's sample_column.
        censoring_indicator = self._create_fitting_sample(df_sample, censoring_column, row_count)
        censoring_indicator = censoring_indicator.astype(bool)
        logger.info(
            f"  Censored data: {int(np.sum(censoring_indicator))}/{len(censoring_indicator)} "
            f"observed events ({100 * np.mean(censoring_indicator):.1f}%)"
        )

    # Handle empty sample (all NaN/inf data filtered out).
    if len(data_sample) == 0:
        logger.warning(f"  No valid data for '{column}' after filtering NaN/inf values")
        return self._empty_fit_results()

    # Apply pre-filtering if enabled (v1.6.0).
    original_distributions = distributions
    original_count = len(distributions)
    if prefilter:
        distributions, filtered = self._prefilter_distributions(distributions, data_sample, prefilter)
        if filtered:
            filtered_names = [f[0] for f in filtered]
            logger.info(
                f"  Pre-filter: skipped {len(filtered)}/{original_count} distributions "
                f"({', '.join(filtered_names[:5])}{'...' if len(filtered_names) > 5 else ''})"
            )
        # Safeguard: if all distributions were filtered, fall back to fitting all.
        if not distributions:
            logger.warning(
                f"  Pre-filter removed all {original_count} distributions; "
                f"falling back to fitting all distributions"
            )
            distributions = original_distributions

    # Compute data stats for provenance (once per column).
    data_stats = compute_data_stats(data_sample)

    # Detect heavy-tail characteristics and warn (#64).
    heavy_tail_info = detect_heavy_tail(data_sample)
    if heavy_tail_info["is_heavy_tailed"]:
        indicators = ", ".join(heavy_tail_info["indicators"])
        # Only warn if not already using MSE (which handles heavy tails well).
        if estimation_method != "mse":
            import warnings

            warnings.warn(
                f"Column '{column}' exhibits heavy-tail characteristics ({indicators}). "
                f"Consider: (1) heavy-tail distributions like pareto, cauchy, t; "
                f"(2) using estimation_method='mse' for robust fitting; "
                f"(3) data transformation (log, sqrt); "
                f"(4) checking for outliers. "
                f"Standard distributions may provide poor fits.",
                UserWarning,
                stacklevel=4,
            )
        logger.warning(f"  Heavy-tail detected for '{column}': {indicators}")

    # Resolve "auto" estimation method: use MSE for heavy-tailed data.
    resolved_estimation_method = estimation_method
    if estimation_method == "auto":
        resolved_estimation_method = "mse" if heavy_tail_info["is_heavy_tailed"] else "mle"
        if resolved_estimation_method == "mse":
            logger.info("  Auto-selected MSE estimation for heavy-tailed data")

    # Interleave slow distributions for better partition balance.
    # Lazy import to avoid circular dependency with core.py.
    from spark_bestfit.core import _interleave_distributions

    distributions = _interleave_distributions(distributions)

    # Execute parallel fitting via backend (v2.0 abstraction).
    # Pass custom distributions only if any are registered (v2.4.0).
    custom_dists = self._registry.get_custom_distributions() if self._registry.has_custom_distributions() else None
    results = self._backend.parallel_fit(
        distributions=distributions,
        histogram=(y_hist, bin_edges),
        data_sample=data_sample,
        fit_func=fit_single_distribution,
        column_name=column,
        data_stats=data_stats,
        num_partitions=num_partitions,
        lower_bound=lower_bound,
        upper_bound=upper_bound,
        lazy_metrics=lazy_metrics,
        is_discrete=False,
        progress_callback=progress_callback,
        custom_distributions=custom_dists,
        estimation_method=resolved_estimation_method,
        censoring_indicator=censoring_indicator,
    )

    # Convert results to a DataFrame of the backend's flavour.
    if not results:
        # Empty frame with the proper columns preserves the API contract.
        results_df = self._empty_fit_results()
    elif self.spark is not None:
        results_df = self.spark.createDataFrame(results, schema=FIT_RESULT_SCHEMA)
    else:
        import pandas as pd

        results_df = pd.DataFrame(results)

    num_results = len(results)
    logger.info(f"  Fit {num_results}/{len(distributions)} distributions for '{column}'")
    return results_df

def _empty_fit_results(self) -> DataFrame:
    """Return an empty results frame (Spark or pandas) with the fit-result columns.

    Single source for the empty-result layout so that every empty return of
    ``_fit_single_column`` exposes the same pandas columns.
    """
    if self.spark is not None:
        return self.spark.createDataFrame([], schema=FIT_RESULT_SCHEMA)

    import pandas as pd

    return pd.DataFrame(
        columns=[
            "column_name",
            "distribution",
            "parameters",
            "sse",
            "aic",
            "bic",
            "ks_statistic",
            "pvalue",
            "ad_statistic",
            "ad_pvalue",
            "data_min",
            "data_max",
            "data_mean",
            "data_stddev",
            "data_count",
            "data_kurtosis",
            "data_skewness",
            "lower_bound",
            "upper_bound",
        ]
    )
[docs]
def plot(
    self,
    result: DistributionFitResult,
    df: Optional[DataFrame] = None,
    column: Optional[str] = None,
    bins: Union[int, Tuple[float, ...]] = 50,
    use_rice_rule: bool = True,
    title: str = "",
    xlabel: str = "Value",
    ylabel: str = "Density",
    figsize: Tuple[int, int] = (12, 8),
    dpi: int = 100,
    show_histogram: bool = True,
    histogram_alpha: float = 0.5,
    pdf_linewidth: int = 2,
    title_fontsize: int = 14,
    label_fontsize: int = 12,
    legend_fontsize: int = 10,
    grid_alpha: float = 0.3,
    save_path: Optional[str] = None,
    save_format: str = "png",
    force_recompute: bool = False,
):
    """Plot fitted distribution against data histogram.

    Args:
        result: DistributionFitResult to plot
        df: DataFrame with data. Optional when result contains a cached
            sample (the default after fitting). When both a cached sample
            and df are provided, the cached sample is used unless
            ``force_recompute=True``.
        column: Column name. If None, uses column_name from result.
        bins: Number of histogram bins, or a tuple of bin edges. On the
            cached-sample path explicit edges are honoured when
            ``use_rice_rule`` is False; with ``use_rice_rule=True`` the
            Rice-rule bin count takes precedence.
        use_rice_rule: Use Rice rule for bins
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size (width, height)
        dpi: Dots per inch for saved figures
        show_histogram: Show data histogram
        histogram_alpha: Histogram transparency (0-1)
        pdf_linewidth: Line width for PDF curve
        title_fontsize: Title font size
        label_fontsize: Axis label font size
        legend_fontsize: Legend font size
        grid_alpha: Grid transparency (0-1)
        save_path: Path to save figure (optional)
        save_format: Save format (png, pdf, svg)
        force_recompute: If True, ignore cached sample and recompute
            histogram from ``df`` (requires ``df`` to be provided).
            Default False.

    Returns:
        Tuple of (figure, axis) from matplotlib

    Raises:
        ValueError: If neither a cached sample nor ``df`` is usable, or if
            no column name can be determined for the ``df`` path.

    Example:
        >>> # Instant plot from cached sample (recommended)
        >>> fitter.plot(best, title='Instant Plot')
        >>> # Explicit DataFrame (recomputes histogram via Spark)
        >>> fitter.plot(best, df, 'value', title='Best Fit', force_recompute=True)
    """
    from spark_bestfit.plotting import plot_distribution

    # Cache-first: prefer cached sample to avoid Spark DAG recomputation.
    if result.cached_sample is not None and not force_recompute:
        if df is not None:
            self._warn_df_with_cache("plot")
        data = result.cached_sample
        if use_rice_rule:
            hist_bins = int(np.ceil(len(data) ** (1 / 3)) * 2)
        elif isinstance(bins, int):
            hist_bins = bins
        else:
            # Explicit bin edges: hand them to numpy as edges instead of
            # collapsing them to a count, so the caller's edges are respected
            # just as on the DataFrame path, which forwards ``bins`` unchanged.
            hist_bins = np.asarray(bins)
        y_hist, bin_edges = np.histogram(data, bins=hist_bins, density=True)
        x_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
    elif df is not None:
        col = column or result.column_name
        if col is None:
            raise ValueError("column must be provided if result.column_name is None")
        y_hist, bin_edges = self._histogram_computer.compute_histogram(
            df, col, bins=bins, use_rice_rule=use_rice_rule
        )
        x_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
    else:
        if force_recompute:
            raise ValueError(
                "force_recompute=True requires df to be provided, " "since the cached sample is bypassed."
            )
        raise ValueError("Either df must be provided or result must contain a cached sample")

    return plot_distribution(
        result=result,
        y_hist=y_hist,
        x_hist=x_centers,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        figsize=figsize,
        dpi=dpi,
        show_histogram=show_histogram,
        histogram_alpha=histogram_alpha,
        pdf_linewidth=pdf_linewidth,
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
        legend_fontsize=legend_fontsize,
        grid_alpha=grid_alpha,
        save_path=save_path,
        save_format=save_format,
    )
[docs]
def plot_comparison(
    self,
    results: List[DistributionFitResult],
    df: Optional[DataFrame] = None,
    column: Optional[str] = None,
    bins: Union[int, Tuple[float, ...]] = 50,
    use_rice_rule: bool = True,
    title: str = "Distribution Comparison",
    xlabel: str = "Value",
    ylabel: str = "Density",
    figsize: Tuple[int, int] = (12, 8),
    dpi: int = 100,
    show_histogram: bool = True,
    histogram_alpha: float = 0.5,
    pdf_linewidth: int = 2,
    title_fontsize: int = 14,
    label_fontsize: int = 12,
    legend_fontsize: int = 10,
    grid_alpha: float = 0.3,
    save_path: Optional[str] = None,
    save_format: str = "png",
    force_recompute: bool = False,
):
    """Plot multiple distributions for comparison.

    Args:
        results: List of DistributionFitResult objects
        df: DataFrame with data. Optional when results contain a cached
            sample. When both a cached sample and df are provided, the
            cached sample is used unless ``force_recompute=True``.
        column: Column name. If None, uses column_name from the first result.
        bins: Number of histogram bins, or a tuple of bin edges. On the
            cached-sample path explicit edges are honoured when
            ``use_rice_rule`` is False; with ``use_rice_rule=True`` the
            Rice-rule bin count takes precedence.
        use_rice_rule: Use Rice rule for bins
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size (width, height)
        dpi: Dots per inch
        show_histogram: Show data histogram
        histogram_alpha: Histogram transparency
        pdf_linewidth: PDF line width
        title_fontsize: Title font size
        label_fontsize: Label font size
        legend_fontsize: Legend font size
        grid_alpha: Grid transparency
        save_path: Path to save figure
        save_format: Save format
        force_recompute: If True, ignore cached samples and recompute
            histogram from ``df`` (requires ``df`` to be provided).
            Default False.

    Returns:
        Tuple of (figure, axis)

    Raises:
        ValueError: If neither a cached sample nor ``df`` is usable, or if
            no column name can be determined for the ``df`` path.

    Example:
        >>> top_3 = results.best(n=3)
        >>> # Instant comparison from cached sample (recommended)
        >>> fitter.plot_comparison(top_3)
        >>> # Explicit DataFrame (recomputes histogram via Spark)
        >>> fitter.plot_comparison(top_3, df, 'value', force_recompute=True)
    """
    from spark_bestfit.plotting import plot_comparison

    # Use the first cached sample found among the results, unless the caller
    # asked for the histogram to be recomputed from ``df``.
    first_cached = None
    if not force_recompute:
        first_cached = next(
            (r.cached_sample for r in results if r.cached_sample is not None),
            None,
        )

    if first_cached is not None:
        if df is not None:
            self._warn_df_with_cache("plot_comparison")
        data = first_cached
        if use_rice_rule:
            hist_bins = int(np.ceil(len(data) ** (1 / 3)) * 2)
        elif isinstance(bins, int):
            hist_bins = bins
        else:
            # Explicit bin edges: hand them to numpy as edges instead of
            # collapsing them to a count, so the caller's edges are respected
            # just as on the DataFrame path, which forwards ``bins`` unchanged.
            hist_bins = np.asarray(bins)
        y_hist, bin_edges = np.histogram(data, bins=hist_bins, density=True)
        x_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
    elif df is not None:
        col = column or (results[0].column_name if results else None)
        if col is None:
            raise ValueError("column must be provided when no cached sample is available")
        y_hist, bin_edges = self._histogram_computer.compute_histogram(
            df, col, bins=bins, use_rice_rule=use_rice_rule
        )
        x_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
    else:
        if force_recompute:
            raise ValueError(
                "force_recompute=True requires df to be provided, " "since the cached sample is bypassed."
            )
        raise ValueError("Either df must be provided or at least one result must contain a cached sample")

    return plot_comparison(
        results=results,
        y_hist=y_hist,
        x_hist=x_centers,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        figsize=figsize,
        dpi=dpi,
        show_histogram=show_histogram,
        histogram_alpha=histogram_alpha,
        pdf_linewidth=pdf_linewidth,
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
        legend_fontsize=legend_fontsize,
        grid_alpha=grid_alpha,
        save_path=save_path,
        save_format=save_format,
    )
@staticmethod
def _validate_inputs(
    df: DataFrame,
    column: str,
    max_distributions: Optional[int],
    bins: Union[int, Tuple[float, ...]],
    sample_fraction: Optional[float],
) -> None:
    """Check the fitting inputs for one column before any work is done.

    The shared checks are delegated to ``BaseFitter`` and run in a fixed
    order (max_distributions, column existence, column type, sample
    fraction); the histogram bin count is the only continuous-specific check.

    Args:
        df: Spark DataFrame containing data
        column: Column name to validate
        max_distributions: Maximum distributions to fit (0 is invalid)
        bins: Histogram bins; an integer count must be positive. A tuple
            of bin edges is not inspected here.
        sample_fraction: Sampling fraction (must be in (0, 1] if provided)

    Raises:
        ValueError: If max_distributions is 0, column not found, bins invalid,
            or sample_fraction out of range
        TypeError: If column is not numeric
    """
    # Shared validation from the base class, in the original order.
    BaseFitter._validate_max_distributions(max_distributions)
    BaseFitter._validate_column_exists(df, column)
    BaseFitter._validate_column_numeric(df, column)
    BaseFitter._validate_sample_fraction(sample_fraction)

    # Continuous-specific: only an integer bin count is range-checked.
    if not isinstance(bins, int):
        return
    if bins <= 0:
        raise ValueError(f"bins must be positive, got {bins}")
@staticmethod
def _validate_censoring_column(df: DataFrame, censoring_column: str) -> None:
    """Validate that the censoring column exists and is boolean/integer typed (v2.9.0).

    Only the column's declared type is checked; the values themselves are
    not scanned, so an integer column holding values other than 0/1 passes
    this validation.

    Args:
        df: DataFrame containing data (Spark or pandas)
        censoring_column: Name of the censoring indicator column

    Raises:
        ValueError: If the column's type is neither boolean nor integer.
            A missing column is reported by
            ``BaseFitter._validate_column_exists``.
    """
    # Use base class method to check the column exists.
    BaseFitter._validate_column_exists(df, censoring_column)

    # Spark DataFrames expose ``schema``; pandas DataFrames do not, so the
    # Spark check must come first (Spark DataFrames also expose ``dtypes``).
    if hasattr(df, "schema"):
        from pyspark.sql.types import BooleanType, ByteType, IntegerType, LongType, ShortType

        # All integral widths are accepted, matching the pandas branch,
        # which accepts any integer dtype.
        accepted_types = (BooleanType, ByteType, ShortType, IntegerType, LongType)
        col_type = df.schema[censoring_column].dataType
        if not isinstance(col_type, accepted_types):
            raise ValueError(
                f"Censoring column '{censoring_column}' must be boolean or integer type "
                f"(True/1 = observed, False/0 = censored), got {col_type}"
            )
    # For pandas DataFrames.
    elif hasattr(df, "dtypes"):
        import pandas as pd

        col_dtype = df[censoring_column].dtype
        is_boolean = pd.api.types.is_bool_dtype(col_dtype)
        is_integer = pd.api.types.is_integer_dtype(col_dtype)
        if not (is_boolean or is_integer):
            raise ValueError(
                f"Censoring column '{censoring_column}' must be boolean or integer type "
                f"(True/1 = observed, False/0 = censored), got {col_dtype}"
            )
# _validate_bounds inherited from BaseFitter
# _resolve_bounds inherited from BaseFitter
# _apply_sampling inherited from BaseFitter
def _create_fitting_sample(self, df: DataFrame, column: str, row_count: int) -> np.ndarray:
    """Draw the numpy sample that scipy fits distributions against.

    The requested fraction targets at most FITTING_SAMPLE_SIZE rows out of
    ``row_count``; the actual sampling is done by the execution backend.

    Args:
        df: Spark DataFrame or pandas DataFrame containing data
        column: Column name to sample
        row_count: Total row count (used to calculate sampling fraction)

    Returns:
        Numpy array of sampled values for distribution fitting
    """
    target_rows = min(FITTING_SAMPLE_SIZE, row_count)
    sampling_fraction = min(target_rows / row_count, 1.0)
    # The backend's sample_column handles both Spark and pandas inputs.
    backend = self._backend
    return backend.sample_column(df, column, fraction=sampling_fraction, seed=self.random_seed)
# _calculate_partitions inherited from BaseFitter
@staticmethod
def _prefilter_distributions(
distributions: List[str],
data_sample: np.ndarray,
mode: Union[bool, str],
) -> Tuple[List[str], List[Tuple[str, str]]]:
"""Pre-filter distributions based on data characteristics.
Uses a layered approach based on SHAPE properties (not location/scale):
1. Skewness sign (~95% reliable): Skip positive-skew-only distributions
for clearly left-skewed data (skewness < -1.0)
2. Kurtosis (aggressive mode only, ~80% reliable): Skip low-kurtosis
distributions for very heavy-tailed data
Note: We do NOT filter by support bounds (dist.a/dist.b) because scipy's
fitting process uses loc/scale parameters that can shift any distribution
to cover any data range. Shape properties (skewness, kurtosis) are
intrinsic and cannot be changed by loc/scale.
Args:
distributions: List of distribution names to filter
data_sample: Numpy array of sample data
mode: True for safe mode, 'aggressive' for additional kurtosis filter
Returns:
Tuple of (compatible_distributions, filtered_with_reasons)
"""
# Early return if filtering is disabled
if not mode:
return distributions.copy(), []
from scipy.stats import kurtosis, skew
data_skew = float(skew(data_sample))
data_kurt = float(kurtosis(data_sample)) # Excess kurtosis
compatible = []
filtered = []
# Distributions that can only have positive skewness (mathematical constraint)
positive_skew_only = {
"expon",
"gamma",
"lognorm",
"chi2",
"weibull_min",
"pareto",
"rayleigh",
"invgamma",
"exponweib",
"genpareto",
"invweibull",
"fisk",
"burr",
"burr12",
"loggamma",
"invgauss",
"genextreme", # When shape > 0
"gompertz",
"halfnorm",
"halfcauchy",
"halflogistic",
"halfgennorm",
"rice",
"nakagami",
"wald",
"gengamma",
"powerlognorm",
}
for dist_name in distributions:
try:
# We intentionally do NOT check support bounds (dist.a/dist.b)
# because scipy.fit() uses loc/scale parameters that can shift
# any distribution to cover any data range.
# Layer 1: Skewness sign check (~95% reliable)
# Only filter if data is CLEARLY left-skewed (threshold = -1.0)
# These distributions are intrinsically right-skewed regardless of loc/scale
if data_skew < -1.0 and dist_name in positive_skew_only:
filtered.append((dist_name, "positive-skew only"))
continue
# Layer 2: Kurtosis check (aggressive mode only, ~80% reliable)
if mode == "aggressive" and data_kurt > 10:
# Very heavy-tailed data - skip uniform which has kurtosis = -1.2
# Uniform's kurtosis is intrinsic and cannot be changed by loc/scale
if dist_name == "uniform":
filtered.append((dist_name, "low kurtosis distribution"))
continue
compatible.append(dist_name)
except AttributeError:
# Unknown distribution - keep it (conservative)
compatible.append(dist_name)
return compatible, filtered
[docs]
def plot_qq(
    self,
    result: DistributionFitResult,
    df: Optional[DataFrame] = None,
    column: Optional[str] = None,
    max_points: int = 1000,
    title: str = "",
    xlabel: str = "Theoretical Quantiles",
    ylabel: str = "Sample Quantiles",
    figsize: Tuple[int, int] = (10, 10),
    dpi: int = 100,
    marker: str = "o",
    marker_size: int = 30,
    marker_alpha: float = 0.6,
    marker_color: str = "steelblue",
    line_color: str = "red",
    line_style: str = "--",
    line_width: float = 1.5,
    title_fontsize: int = 14,
    label_fontsize: int = 12,
    grid_alpha: float = 0.3,
    save_path: Optional[str] = None,
    save_format: str = "png",
    force_recompute: bool = False,
):
    """Draw a Q-Q plot for a fitted distribution.

    A Q-Q (quantile-quantile) plot sets the sample quantiles against the
    theoretical quantiles of the fitted distribution; the closer the points
    lie to the reference line, the better the fit.

    The plotted data is chosen as follows:

    * ``result.cached_sample`` (first ``max_points`` values) when it exists
      and ``force_recompute`` is False; a warning helper is invoked if
      ``df`` was passed as well.
    * otherwise a fresh sample of ``df`` (first ``max_points`` values of a
      sample sized at roughly three times ``max_points``).

    Args:
        result: DistributionFitResult to plot
        df: DataFrame with data. Optional when result contains a cached
            sample. Ignored (with a warning) when a cached sample is used.
        column: Column name. Falls back to ``result.column_name``.
        max_points: Maximum data points used for plotting
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size (width, height)
        dpi: Dots per inch for saved figures
        marker: Marker style for data points
        marker_size: Size of markers
        marker_alpha: Marker transparency (0-1)
        marker_color: Color of markers
        line_color: Color of reference line
        line_style: Style of reference line
        line_width: Width of reference line
        title_fontsize: Title font size
        label_fontsize: Axis label font size
        grid_alpha: Grid transparency (0-1)
        save_path: Path to save figure (optional)
        save_format: Save format (png, pdf, svg)
        force_recompute: If True, bypass the cached sample and resample
            from ``df`` (which must then be provided). Default False.

    Returns:
        The return value of ``spark_bestfit.plotting.plot_qq``
        (a (figure, axis) tuple from matplotlib).

    Raises:
        ValueError: If no data source is available, if
            ``force_recompute=True`` is used without ``df``, or if no
            column name can be determined.

    Example:
        >>> best = results.best(n=1)[0]
        >>> # Instant Q-Q plot from cached sample (recommended)
        >>> fitter.plot_qq(best, title='Instant Q-Q Plot')
        >>> # Explicit DataFrame (resamples via Spark)
        >>> fitter.plot_qq(best, df, 'value', title='Q-Q Plot', force_recompute=True)
    """
    from spark_bestfit.plotting import plot_qq as render_qq
    from spark_bestfit.storage import _get_dataframe_row_count, _sample_dataframe_column

    cache_usable = result.cached_sample is not None and not force_recompute
    if cache_usable:
        # Cache-first: avoids touching the DataFrame at all
        if df is not None:
            self._warn_df_with_cache("plot_qq")
        data = result.cached_sample[:max_points]
    elif df is None:
        # Neither a usable cache nor a DataFrame: explain which one is missing
        if force_recompute:
            raise ValueError(
                "force_recompute=True requires df to be provided, since the cached sample is bypassed."
            )
        raise ValueError("Either df must be provided or result must contain a cached sample")
    else:
        col = column if column else result.column_name
        if col is None:
            raise ValueError("column must be provided if result.column_name is None")
        row_count = _get_dataframe_row_count(df)
        if row_count > 0:
            # Ask for about 3x max_points rows, capped at the whole DataFrame
            fraction = min(max_points * 3 / row_count, 1.0)
        else:
            fraction = 1.0
        data = _sample_dataframe_column(df, col, fraction, self.random_seed)[:max_points]

    # Presentation options are forwarded to the plotting function unchanged
    appearance = dict(
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        figsize=figsize,
        dpi=dpi,
        marker=marker,
        marker_size=marker_size,
        marker_alpha=marker_alpha,
        marker_color=marker_color,
        line_color=line_color,
        line_style=line_style,
        line_width=line_width,
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
        grid_alpha=grid_alpha,
        save_path=save_path,
        save_format=save_format,
    )
    return render_qq(result=result, data=data, **appearance)
[docs]
def plot_pp(
    self,
    result: DistributionFitResult,
    df: Optional[DataFrame] = None,
    column: Optional[str] = None,
    max_points: int = 1000,
    title: str = "",
    xlabel: str = "Theoretical Probabilities",
    ylabel: str = "Sample Probabilities",
    figsize: Tuple[int, int] = (10, 10),
    dpi: int = 100,
    marker: str = "o",
    marker_size: int = 30,
    marker_alpha: float = 0.6,
    marker_color: str = "steelblue",
    line_color: str = "red",
    line_style: str = "--",
    line_width: float = 1.5,
    title_fontsize: int = 14,
    label_fontsize: int = 12,
    grid_alpha: float = 0.3,
    save_path: Optional[str] = None,
    save_format: str = "png",
    force_recompute: bool = False,
):
    """Draw a P-P plot for a fitted distribution.

    A P-P (probability-probability) plot sets the empirical CDF of the
    sample against the theoretical CDF of the fitted distribution; points
    near the reference line indicate a good fit, particularly in the
    center of the distribution.

    The plotted data is chosen as follows:

    * ``result.cached_sample`` (first ``max_points`` values) when it exists
      and ``force_recompute`` is False; a warning helper is invoked if
      ``df`` was passed as well.
    * otherwise a fresh sample of ``df`` (first ``max_points`` values of a
      sample sized at roughly three times ``max_points``).

    Args:
        result: DistributionFitResult to plot
        df: DataFrame with data. Optional when result contains a cached
            sample. Ignored (with a warning) when a cached sample is used.
        column: Column name. Falls back to ``result.column_name``.
        max_points: Maximum data points used for plotting
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size (width, height)
        dpi: Dots per inch for saved figures
        marker: Marker style for data points
        marker_size: Size of markers
        marker_alpha: Marker transparency (0-1)
        marker_color: Color of markers
        line_color: Color of reference line
        line_style: Style of reference line
        line_width: Width of reference line
        title_fontsize: Title font size
        label_fontsize: Axis label font size
        grid_alpha: Grid transparency (0-1)
        save_path: Path to save figure (optional)
        save_format: Save format (png, pdf, svg)
        force_recompute: If True, bypass the cached sample and resample
            from ``df`` (which must then be provided). Default False.

    Returns:
        The return value of ``spark_bestfit.plotting.plot_pp``
        (a (figure, axis) tuple from matplotlib).

    Raises:
        ValueError: If no data source is available, if
            ``force_recompute=True`` is used without ``df``, or if no
            column name can be determined.

    Example:
        >>> best = results.best(n=1)[0]
        >>> # Instant P-P plot from cached sample (recommended)
        >>> fitter.plot_pp(best, title='Instant P-P Plot')
        >>> # Explicit DataFrame (resamples via Spark)
        >>> fitter.plot_pp(best, df, 'value', title='P-P Plot', force_recompute=True)
    """
    from spark_bestfit.plotting import plot_pp as render_pp
    from spark_bestfit.storage import _get_dataframe_row_count, _sample_dataframe_column

    cache_usable = result.cached_sample is not None and not force_recompute
    if cache_usable:
        # Cache-first: avoids touching the DataFrame at all
        if df is not None:
            self._warn_df_with_cache("plot_pp")
        data = result.cached_sample[:max_points]
    elif df is None:
        # Neither a usable cache nor a DataFrame: explain which one is missing
        if force_recompute:
            raise ValueError(
                "force_recompute=True requires df to be provided, since the cached sample is bypassed."
            )
        raise ValueError("Either df must be provided or result must contain a cached sample")
    else:
        col = column if column else result.column_name
        if col is None:
            raise ValueError("column must be provided if result.column_name is None")
        row_count = _get_dataframe_row_count(df)
        if row_count > 0:
            # Ask for about 3x max_points rows, capped at the whole DataFrame
            fraction = min(max_points * 3 / row_count, 1.0)
        else:
            fraction = 1.0
        data = _sample_dataframe_column(df, col, fraction, self.random_seed)[:max_points]

    # Presentation options are forwarded to the plotting function unchanged
    appearance = dict(
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        figsize=figsize,
        dpi=dpi,
        marker=marker,
        marker_size=marker_size,
        marker_alpha=marker_alpha,
        marker_color=marker_color,
        line_color=line_color,
        line_style=line_style,
        line_width=line_width,
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
        grid_alpha=grid_alpha,
        save_path=save_path,
        save_format=save_format,
    )
    return render_pp(result=result, data=data, **appearance)