Source code for spark_bestfit.continuous_fitter

"""Continuous distribution fitting engine for Spark."""

import logging
from functools import reduce
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import numpy as np

# PySpark is optional - only import if available
try:
    from pyspark.sql import DataFrame, SparkSession

    _PYSPARK_AVAILABLE = True
except ImportError:
    DataFrame = None  # type: ignore[assignment,misc]
    SparkSession = None  # type: ignore[assignment,misc]
    _PYSPARK_AVAILABLE = False

from spark_bestfit.base_fitter import BaseFitter
from spark_bestfit.config import FitterConfig
from spark_bestfit.distributions import DistributionRegistry
from spark_bestfit.fitting import (
    FIT_RESULT_SCHEMA,
    FITTING_SAMPLE_SIZE,
    compute_data_stats,
    detect_heavy_tail,
    fit_single_distribution,
)
from spark_bestfit.histogram import HistogramComputer
from spark_bestfit.results import DistributionFitResult, FitResultsType, LazyMetricsContext, create_fit_results

if TYPE_CHECKING:
    from scipy.stats import rv_continuous

    from spark_bestfit.protocols import ExecutionBackend

logger = logging.getLogger(__name__)

# Re-export for convenience
DEFAULT_EXCLUDED_DISTRIBUTIONS: Tuple[str, ...] = tuple(DistributionRegistry.DEFAULT_EXCLUSIONS)


[docs] class DistributionFitter(BaseFitter): """Modern Spark distribution fitting engine. Efficiently fits ~90 scipy.stats distributions to data using Spark's parallel processing capabilities. Uses broadcast variables and Pandas UDFs to avoid data collection and minimize serialization overhead. Example: >>> from pyspark.sql import SparkSession >>> from spark_bestfit import DistributionFitter >>> >>> # Create your own SparkSession >>> spark = SparkSession.builder.appName("my-app").getOrCreate() >>> df = spark.createDataFrame([(float(x),) for x in data], ['value']) >>> >>> # Simple usage >>> fitter = DistributionFitter(spark) >>> results = fitter.fit(df, column='value') >>> best = results.best(n=1)[0] >>> print(f"Best: {best.distribution} with SSE={best.sse}") >>> >>> # With custom parameters >>> fitter = DistributionFitter(spark, random_seed=123) >>> results = fitter.fit(df, 'value', bins=100, support_at_zero=True) >>> >>> # Plot the best fit >>> fitter.plot(best, df, 'value', title='Best Fit') """ # Class attributes for BaseFitter _registry_class = DistributionRegistry _default_exclusions = DEFAULT_EXCLUDED_DISTRIBUTIONS def __init__( self, spark: Optional[SparkSession] = None, excluded_distributions: Optional[Tuple[str, ...]] = None, random_seed: int = 42, backend: Optional["ExecutionBackend"] = None, ): """Initialize DistributionFitter. Args: spark: SparkSession. If None, uses the active session. Ignored if ``backend`` is provided. excluded_distributions: Distributions to exclude from fitting. Defaults to DEFAULT_EXCLUDED_DISTRIBUTIONS (slow distributions). Pass an empty tuple ``()`` to include ALL scipy distributions. random_seed: Random seed for reproducible sampling. backend: Optional execution backend (v2.0). If None, creates a SparkBackend from the spark session. Allows plugging in alternative backends like LocalBackend for testing. Raises: RuntimeError: If no SparkSession provided and no active session exists """ super().__init__( spark=spark, excluded_distributions=excluded_distributions, random_seed=random_seed, backend=backend, ) self._histogram_computer = HistogramComputer(backend=self._backend)
[docs] def register_distribution( self, name: str, distribution: "rv_continuous", overwrite: bool = False, ) -> "DistributionFitter": """Register a custom distribution for fitting. Custom distributions must implement the scipy rv_continuous interface, specifically the fit(), pdf(), and cdf() methods. The distribution will be included in fitting alongside scipy.stats distributions. Args: name: Unique name for the distribution (used in results) distribution: scipy rv_continuous instance or subclass. Must implement fit(), pdf(), cdf() methods. overwrite: If True, overwrite existing distribution with same name. Default False raises ValueError if name exists. Returns: Self (for method chaining) Raises: ValueError: If name already exists (and overwrite=False) or conflicts with a scipy.stats distribution name TypeError: If distribution doesn't implement required interface Example: >>> from scipy.stats import rv_continuous >>> >>> class PowerDistribution(rv_continuous): ... def _pdf(self, x, alpha): ... return alpha * x ** (alpha - 1) ... def _cdf(self, x, alpha): ... return x ** alpha >>> >>> fitter = DistributionFitter(spark) >>> fitter.register_distribution("power", PowerDistribution(a=0, b=1)) >>> results = fitter.fit(df, "column") >>> # Results will include "power" if it fits well """ self._registry.register_distribution(name, distribution, overwrite=overwrite) return self
[docs] def unregister_distribution(self, name: str) -> "DistributionFitter": """Remove a custom distribution from the registry. Args: name: Name of the custom distribution to remove Returns: Self (for method chaining) Raises: KeyError: If distribution not found in registry """ self._registry.unregister_distribution(name) return self
[docs] def get_custom_distributions(self) -> dict: """Get all registered custom distributions. Returns: Dict mapping distribution names to rv_continuous objects """ return self._registry.get_custom_distributions()
[docs] def fit( self, df: DataFrame, column: Optional[str] = None, columns: Optional[List[str]] = None, config: Optional[FitterConfig] = None, *, bins: Union[int, Tuple[float, ...]] = 50, use_rice_rule: bool = True, support_at_zero: bool = False, max_distributions: Optional[int] = None, enable_sampling: bool = True, sample_fraction: Optional[float] = None, max_sample_size: int = 1_000_000, sample_threshold: int = 10_000_000, num_partitions: Optional[int] = None, progress_callback: Optional[Callable[[int, int, float], None]] = None, bounded: bool = False, lower_bound: Optional[Union[float, Dict[str, float]]] = None, upper_bound: Optional[Union[float, Dict[str, float]]] = None, lazy_metrics: bool = False, prefilter: Union[bool, str] = False, estimation_method: str = "mle", ) -> FitResultsType: """Fit distributions to data column(s). Args: df: Spark DataFrame containing data column: Name of single column to fit distributions to columns: List of column names for multi-column fitting config: FitterConfig object (v2.2.0). Provides a cleaner way to configure fitting with many parameters. If provided, individual parameters below are ignored (except progress_callback which can override the config's callback). Use FitterConfigBuilder for fluent configuration. bins: Number of histogram bins or tuple of bin edges use_rice_rule: Use Rice rule to auto-determine bin count support_at_zero: Only fit non-negative distributions max_distributions: Limit number of distributions (for testing) enable_sampling: Enable sampling for large datasets sample_fraction: Fraction to sample (None = auto-determine) max_sample_size: Maximum rows to sample when auto-determining sample_threshold: Row count above which sampling is applied num_partitions: Spark partitions (None = auto-determine) progress_callback: Optional callback for progress updates. Called with (completed_tasks, total_tasks, percent_complete). Callback is invoked from background thread - ensure thread-safety. bounded: If True, fit truncated distributions (v1.4.0). When enabled, distributions are truncated to [lower_bound, upper_bound] using scipy.stats.truncate(). Requires scipy >= 1.14.0. lower_bound: Lower bound for truncated distribution fitting. Can be a float (applied to all columns) or a dict mapping column names to bounds (v1.5.0). If None and bounded=True, auto-detects from each column's minimum. upper_bound: Upper bound for truncated distribution fitting. Can be a float (applied to all columns) or a dict mapping column names to bounds (v1.5.0). If None and bounded=True, auto-detects from each column's maximum. lazy_metrics: If True, defer computation of expensive KS/AD metrics until accessed (v1.5.0). Improves fitting performance when only using AIC/BIC/SSE for model selection. Default False for backward compatibility. prefilter: Pre-filter distributions based on data characteristics (v1.6.0). Skips distributions that are mathematically incompatible with the data, reducing fitting time by 30-70% for non-normal data. - False (default): No pre-filtering, fit all distributions - True: Safe mode - filters by support bounds and skewness sign - 'aggressive': Also filters by kurtosis (may skip valid distributions) Pre-filtering uses scipy's distribution support bounds (dist.a, dist.b) and sample moments. Filtered distributions are logged for transparency. estimation_method: Parameter estimation method (v2.5.0): - "mle": Maximum Likelihood Estimation (default). Fast and accurate for most distributions. Uses scipy.stats.fit(). - "mse": Maximum Spacing Estimation. More robust for heavy-tailed distributions (Pareto, Cauchy, etc.) where MLE may fail. - "auto": Automatically select MSE for heavy-tailed data based on kurtosis and extreme value analysis, MLE otherwise. Returns: FitResults object with fitted distributions Raises: ValueError: If column not found, DataFrame empty, or invalid params TypeError: If column is not numeric Example: >>> # Using FitterConfig (recommended for complex configs, v2.2.0) >>> from spark_bestfit import FitterConfigBuilder >>> config = (FitterConfigBuilder() ... .with_bins(100) ... .with_bounds(lower=0, upper=100) ... .with_sampling(fraction=0.1) ... .build()) >>> results = fitter.fit(df, column='value', config=config) >>> >>> # Single column (backward compatible) >>> results = fitter.fit(df, column='value') >>> results = fitter.fit(df, 'value', bins=100, support_at_zero=True) >>> >>> # Multi-column >>> results = fitter.fit(df, columns=['col1', 'col2', 'col3']) >>> best_col1 = results.for_column('col1').best(n=1)[0] >>> best_per_col = results.best_per_column(n=1) >>> >>> # Bounded fitting (v1.4.0) >>> results = fitter.fit(df, 'value', bounded=True) # Auto-detect bounds >>> results = fitter.fit(df, 'value', bounded=True, lower_bound=0, upper_bound=100) >>> >>> # Lazy metrics for faster fitting when only using AIC/BIC (v1.5.0) >>> results = fitter.fit(df, 'value', lazy_metrics=True) >>> best_aic = results.best(n=1, metric='aic')[0] # Fast, no KS/AD computed """ # Resolve config: explicit config takes precedence over individual parameters if config is not None: # Use config values, but allow progress_callback override cfg = config if progress_callback is not None: cfg = config.with_progress_callback(progress_callback) else: # Create config from individual parameters (backward compatibility) cfg = FitterConfig( bins=bins, use_rice_rule=use_rice_rule, support_at_zero=support_at_zero, max_distributions=max_distributions, prefilter=prefilter, enable_sampling=enable_sampling, sample_fraction=sample_fraction, max_sample_size=max_sample_size, sample_threshold=sample_threshold, bounded=bounded, lower_bound=lower_bound, upper_bound=upper_bound, num_partitions=num_partitions, lazy_metrics=lazy_metrics, progress_callback=progress_callback, estimation_method=estimation_method, ) # Normalize column/columns to list target_columns = self._normalize_columns(column, columns) # Input validation for all columns for col in target_columns: self._validate_inputs(df, col, cfg.max_distributions, cfg.bins, cfg.sample_fraction) # Validate bounds - handle both scalar and dict forms self._validate_bounds(cfg.lower_bound, cfg.upper_bound, target_columns) # Validate censoring column if specified (v2.9.0) if cfg.censoring_column is not None: self._validate_censoring_column(df, cfg.censoring_column) # Get row count (single operation for all columns) row_count = self._get_row_count(df) if row_count == 0: raise ValueError("DataFrame is empty") logger.info(f"Row count: {row_count}") # Build per-column bounds dict: {col: (lower, upper)} column_bounds: Dict[str, Tuple[Optional[float], Optional[float]]] = {} if cfg.bounded: column_bounds = self._resolve_bounds(df, target_columns, cfg.lower_bound, cfg.upper_bound) # Sample if needed (single operation for all columns) # For adaptive sampling, use first column as representative for skew analysis df_sample = self._apply_sampling( df, row_count, cfg.enable_sampling, cfg.sample_fraction, cfg.max_sample_size, cfg.sample_threshold, column=target_columns[0] if target_columns else None, adaptive_sampling=cfg.adaptive_sampling, sampling_mode=cfg.sampling_mode, skew_threshold_mild=cfg.skew_threshold_mild, skew_threshold_high=cfg.skew_threshold_high, ) # Get distributions to fit (same for all columns) distributions = self._registry.get_distributions( support_at_zero=cfg.support_at_zero, additional_exclusions=list(self.excluded_distributions), ) if cfg.max_distributions is not None and cfg.max_distributions > 0: distributions = distributions[: cfg.max_distributions] # Fit each column and collect results all_results_dfs = [] lazy_contexts: Dict[str, LazyMetricsContext] = {} cached_samples: Dict[str, np.ndarray] = {} for col in target_columns: # Get per-column bounds (empty dict if not bounded) col_lower, col_upper = column_bounds.get(col, (None, None)) logger.info(f"Fitting column '{col}'...") # Create fitting sample - this is what we'll cache (v2.10.0: Instant mode) data_sample = self._create_fitting_sample(df_sample, col, row_count) cached_samples[col] = data_sample results_df = self._fit_single_column( df_sample=df_sample, column=col, row_count=row_count, bins=cfg.bins, use_rice_rule=cfg.use_rice_rule, distributions=distributions, num_partitions=cfg.num_partitions, lower_bound=col_lower, upper_bound=col_upper, lazy_metrics=cfg.lazy_metrics, prefilter=cfg.prefilter, progress_callback=cfg.progress_callback, estimation_method=cfg.estimation_method, censoring_column=cfg.censoring_column, data_sample=data_sample, ) all_results_dfs.append(results_df) # Build lazy context for on-demand metric computation if cfg.lazy_metrics: lazy_contexts[col] = LazyMetricsContext( source_df=df_sample, column=col, random_seed=self.random_seed, row_count=row_count, lower_bound=col_lower, upper_bound=col_upper, is_discrete=False, cached_sample=data_sample, ) # Union all results - handle both Spark and pandas DataFrames if self.spark is not None: # Spark: union DataFrames combined_df = reduce(DataFrame.union, all_results_dfs) combined_df = combined_df.cache() total_results = combined_df.count() else: # Non-Spark backend: concatenate pandas DataFrames import pandas as pd combined_df = pd.concat(all_results_dfs, ignore_index=True) total_results = len(combined_df) logger.info( f"Total results: {total_results} ({len(target_columns)} columns × ~{len(distributions)} distributions)" ) # Pass lazy contexts and cached samples to FitResults for instant plotting return create_fit_results( combined_df, lazy_contexts=lazy_contexts if cfg.lazy_metrics else None, samples=cached_samples, )
def _fit_single_column( self, df_sample: DataFrame, column: str, row_count: int, bins: Union[int, Tuple[float, ...]], use_rice_rule: bool, distributions: List[str], num_partitions: Optional[int], lower_bound: Optional[float] = None, upper_bound: Optional[float] = None, lazy_metrics: bool = False, prefilter: Union[bool, str] = False, progress_callback: Optional[Callable[[int, int, float], None]] = None, estimation_method: str = "mle", censoring_column: Optional[str] = None, data_sample: Optional[np.ndarray] = None, ) -> DataFrame: """Fit distributions to a single column (internal method). Args: df_sample: Sampled DataFrame column: Column name row_count: Original row count (for histogram computation) bins: Number of histogram bins use_rice_rule: Use Rice rule for bin count distributions: List of distribution names to fit num_partitions: Number of Spark partitions lower_bound: Lower bound for truncated distribution fitting (v1.4.0) upper_bound: Upper bound for truncated distribution fitting (v1.4.0) lazy_metrics: If True, skip KS/AD computation for performance (v1.5.0) prefilter: Pre-filter mode (False, True, or 'aggressive') (v1.6.0) progress_callback: Optional callback for progress updates (v2.0.0) estimation_method: Parameter estimation method (v2.5.0): - "mle": Maximum Likelihood Estimation (default) - "mse": Maximum Spacing Estimation (robust for heavy-tailed data) - "auto": Automatically select MSE for heavy-tailed data censoring_column: Column name containing censoring indicator (v2.9.0). True/1 = observed event, False/0 = right-censored observation. data_sample: Optional pre-computed sample data (v2.10.0) Returns: Spark DataFrame with fit results for this column """ # Compute histogram (returns bin edges for CDF-based fitting) y_hist, bin_edges = self._histogram_computer.compute_histogram( df_sample, column, bins=bins, use_rice_rule=use_rice_rule, approx_count=row_count ) logger.info(f" Histogram for '{column}': {len(bin_edges) - 1} bins") # Create fitting sample if not provided if data_sample is None: data_sample = self._create_fitting_sample(df_sample, column, row_count) # Extract censoring indicator if specified (v2.9.0) censoring_indicator: Optional[np.ndarray] = None if censoring_column is not None: censoring_indicator = self._create_fitting_sample(df_sample, censoring_column, row_count) censoring_indicator = censoring_indicator.astype(bool) logger.info( f" Censored data: {int(np.sum(censoring_indicator))}/{len(censoring_indicator)} " f"observed events ({100 * np.mean(censoring_indicator):.1f}%)" ) # Handle empty sample (all NaN/inf data filtered out) if len(data_sample) == 0: logger.warning(f" No valid data for '{column}' after filtering NaN/inf values") import pandas as pd if self.spark is not None: return self.spark.createDataFrame([], schema=FIT_RESULT_SCHEMA) else: return pd.DataFrame( columns=[ "column_name", "distribution", "parameters", "sse", "aic", "bic", "ks_statistic", "pvalue", "ad_statistic", "ad_pvalue", "data_min", "data_max", "data_mean", "data_stddev", "data_count", "data_kurtosis", "data_skewness", "lower_bound", "upper_bound", ] ) # Apply pre-filtering if enabled (v1.6.0) original_distributions = distributions original_count = len(distributions) if prefilter: distributions, filtered = self._prefilter_distributions(distributions, data_sample, prefilter) if filtered: filtered_names = [f[0] for f in filtered] logger.info( f" Pre-filter: skipped {len(filtered)}/{original_count} distributions " f"({', '.join(filtered_names[:5])}{'...' if len(filtered_names) > 5 else ''})" ) # Safeguard: if all distributions filtered, fall back to fitting all if not distributions: logger.warning( f" Pre-filter removed all {original_count} distributions; " f"falling back to fitting all distributions" ) distributions = original_distributions # Compute data stats for provenance (once per column) data_stats = compute_data_stats(data_sample) # Detect heavy-tail characteristics and warn (#64) heavy_tail_info = detect_heavy_tail(data_sample) if heavy_tail_info["is_heavy_tailed"]: indicators = ", ".join(heavy_tail_info["indicators"]) import warnings # Only warn if not already using MSE (which handles heavy tails well) if estimation_method != "mse": warnings.warn( f"Column '{column}' exhibits heavy-tail characteristics ({indicators}). " f"Consider: (1) heavy-tail distributions like pareto, cauchy, t; " f"(2) using estimation_method='mse' for robust fitting; " f"(3) data transformation (log, sqrt); " f"(4) checking for outliers. " f"Standard distributions may provide poor fits.", UserWarning, stacklevel=4, ) logger.warning(f" Heavy-tail detected for '{column}': {indicators}") # Resolve "auto" estimation method: use MSE for heavy-tailed data resolved_estimation_method = estimation_method if estimation_method == "auto": resolved_estimation_method = "mse" if heavy_tail_info["is_heavy_tailed"] else "mle" if resolved_estimation_method == "mse": logger.info(" Auto-selected MSE estimation for heavy-tailed data") # Interleave slow distributions for better partition balance # Lazy import to avoid circular dependency with core.py from spark_bestfit.core import _interleave_distributions distributions = _interleave_distributions(distributions) # Execute parallel fitting via backend (v2.0 abstraction) # Backend handles: broadcast, partitioning, UDF application, collection # Pass custom distributions if any are registered (v2.4.0) custom_dists = self._registry.get_custom_distributions() if self._registry.has_custom_distributions() else None results = self._backend.parallel_fit( distributions=distributions, histogram=(y_hist, bin_edges), data_sample=data_sample, fit_func=fit_single_distribution, column_name=column, data_stats=data_stats, num_partitions=num_partitions, lower_bound=lower_bound, upper_bound=upper_bound, lazy_metrics=lazy_metrics, is_discrete=False, progress_callback=progress_callback, custom_distributions=custom_dists, estimation_method=resolved_estimation_method, censoring_indicator=censoring_indicator, ) # Convert results to DataFrame if self.spark is not None: # Spark backend if results: results_df = self.spark.createDataFrame(results, schema=FIT_RESULT_SCHEMA) else: results_df = self.spark.createDataFrame([], schema=FIT_RESULT_SCHEMA) else: # Non-Spark backend: use pandas DataFrame import pandas as pd if results: results_df = pd.DataFrame(results) else: # Create empty DataFrame with proper schema to preserve API contract results_df = pd.DataFrame( columns=[ "column_name", "distribution", "parameters", "sse", "aic", "bic", "ks_statistic", "pvalue", "ad_statistic", "ad_pvalue", "data_min", "data_max", "data_mean", "data_stddev", "data_count", "lower_bound", "upper_bound", ] ) num_results = len(results) logger.info(f" Fit {num_results}/{len(distributions)} distributions for '{column}'") return results_df
[docs] def plot( self, result: DistributionFitResult, df: Optional[DataFrame] = None, column: Optional[str] = None, bins: Union[int, Tuple[float, ...]] = 50, use_rice_rule: bool = True, title: str = "", xlabel: str = "Value", ylabel: str = "Density", figsize: Tuple[int, int] = (12, 8), dpi: int = 100, show_histogram: bool = True, histogram_alpha: float = 0.5, pdf_linewidth: int = 2, title_fontsize: int = 14, label_fontsize: int = 12, legend_fontsize: int = 10, grid_alpha: float = 0.3, save_path: Optional[str] = None, save_format: str = "png", force_recompute: bool = False, ): """Plot fitted distribution against data histogram. Args: result: DistributionFitResult to plot df: DataFrame with data. Optional when result contains a cached sample (the default after fitting). When both a cached sample and df are provided, the cached sample is used unless ``force_recompute=True``. column: Column name. If None, uses column_name from result. bins: Number of histogram bins use_rice_rule: Use Rice rule for bins title: Plot title xlabel: X-axis label ylabel: Y-axis label figsize: Figure size (width, height) dpi: Dots per inch for saved figures show_histogram: Show data histogram histogram_alpha: Histogram transparency (0-1) pdf_linewidth: Line width for PDF curve title_fontsize: Title font size label_fontsize: Axis label font size legend_fontsize: Legend font size grid_alpha: Grid transparency (0-1) save_path: Path to save figure (optional) save_format: Save format (png, pdf, svg) force_recompute: If True, ignore cached sample and recompute histogram from ``df`` (requires ``df`` to be provided). Default False. Returns: Tuple of (figure, axis) from matplotlib Example: >>> # Instant plot from cached sample (recommended) >>> fitter.plot(best, title='Instant Plot') >>> # Explicit DataFrame (recomputes histogram via Spark) >>> fitter.plot(best, df, 'value', title='Best Fit', force_recompute=True) """ from spark_bestfit.plotting import plot_distribution # Cache-first: prefer cached sample to avoid Spark DAG recomputation if result.cached_sample is not None and not force_recompute: if df is not None: self._warn_df_with_cache("plot") data = result.cached_sample if use_rice_rule: n_bins = int(np.ceil(len(data) ** (1 / 3)) * 2) else: n_bins = bins if isinstance(bins, int) else len(bins) - 1 y_hist, bin_edges = np.histogram(data, bins=n_bins, density=True) x_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0 elif df is not None: col = column or result.column_name if col is None: raise ValueError("column must be provided if result.column_name is None") y_hist, bin_edges = self._histogram_computer.compute_histogram( df, col, bins=bins, use_rice_rule=use_rice_rule ) x_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0 else: if force_recompute: raise ValueError( "force_recompute=True requires df to be provided, " "since the cached sample is bypassed." ) raise ValueError("Either df must be provided or result must contain a cached sample") return plot_distribution( result=result, y_hist=y_hist, x_hist=x_centers, title=title, xlabel=xlabel, ylabel=ylabel, figsize=figsize, dpi=dpi, show_histogram=show_histogram, histogram_alpha=histogram_alpha, pdf_linewidth=pdf_linewidth, title_fontsize=title_fontsize, label_fontsize=label_fontsize, legend_fontsize=legend_fontsize, grid_alpha=grid_alpha, save_path=save_path, save_format=save_format, )
[docs] def plot_comparison( self, results: List[DistributionFitResult], df: Optional[DataFrame] = None, column: Optional[str] = None, bins: Union[int, Tuple[float, ...]] = 50, use_rice_rule: bool = True, title: str = "Distribution Comparison", xlabel: str = "Value", ylabel: str = "Density", figsize: Tuple[int, int] = (12, 8), dpi: int = 100, show_histogram: bool = True, histogram_alpha: float = 0.5, pdf_linewidth: int = 2, title_fontsize: int = 14, label_fontsize: int = 12, legend_fontsize: int = 10, grid_alpha: float = 0.3, save_path: Optional[str] = None, save_format: str = "png", force_recompute: bool = False, ): """Plot multiple distributions for comparison. Args: results: List of DistributionFitResult objects df: DataFrame with data. Optional when results contain a cached sample. When both a cached sample and df are provided, the cached sample is used unless ``force_recompute=True``. column: Column name. If None, uses column_name from the first result. bins: Number of histogram bins use_rice_rule: Use Rice rule for bins title: Plot title xlabel: X-axis label ylabel: Y-axis label figsize: Figure size (width, height) dpi: Dots per inch show_histogram: Show data histogram histogram_alpha: Histogram transparency pdf_linewidth: PDF line width title_fontsize: Title font size label_fontsize: Label font size legend_fontsize: Legend font size grid_alpha: Grid transparency save_path: Path to save figure save_format: Save format force_recompute: If True, ignore cached samples and recompute histogram from ``df`` (requires ``df`` to be provided). Default False. Returns: Tuple of (figure, axis) Example: >>> top_3 = results.best(n=3) >>> # Instant comparison from cached sample (recommended) >>> fitter.plot_comparison(top_3) >>> # Explicit DataFrame (recomputes histogram via Spark) >>> fitter.plot_comparison(top_3, df, 'value', force_recompute=True) """ from spark_bestfit.plotting import plot_comparison # Try to get cached sample from results first_cached = None if not force_recompute: for r in results: if r.cached_sample is not None: first_cached = r.cached_sample break if first_cached is not None: if df is not None: self._warn_df_with_cache("plot_comparison") data = first_cached if use_rice_rule: n_bins = int(np.ceil(len(data) ** (1 / 3)) * 2) else: n_bins = bins if isinstance(bins, int) else len(bins) - 1 y_hist, bin_edges = np.histogram(data, bins=n_bins, density=True) x_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0 elif df is not None: col = column or (results[0].column_name if results else None) if col is None: raise ValueError("column must be provided when no cached sample is available") y_hist, bin_edges = self._histogram_computer.compute_histogram( df, col, bins=bins, use_rice_rule=use_rice_rule ) x_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0 else: if force_recompute: raise ValueError( "force_recompute=True requires df to be provided, " "since the cached sample is bypassed." ) raise ValueError("Either df must be provided or at least one result must contain a cached sample") return plot_comparison( results=results, y_hist=y_hist, x_hist=x_centers, title=title, xlabel=xlabel, ylabel=ylabel, figsize=figsize, dpi=dpi, show_histogram=show_histogram, histogram_alpha=histogram_alpha, pdf_linewidth=pdf_linewidth, title_fontsize=title_fontsize, label_fontsize=label_fontsize, legend_fontsize=legend_fontsize, grid_alpha=grid_alpha, save_path=save_path, save_format=save_format, )
@staticmethod def _validate_inputs( df: DataFrame, column: str, max_distributions: Optional[int], bins: Union[int, Tuple[float, ...]], sample_fraction: Optional[float], ) -> None: """Validate input parameters for distribution fitting. Args: df: Spark DataFrame containing data column: Column name to validate max_distributions: Maximum distributions to fit (0 is invalid) bins: Number of histogram bins (must be positive if int) sample_fraction: Sampling fraction (must be in (0, 1] if provided) Raises: ValueError: If max_distributions is 0, column not found, bins invalid, or sample_fraction out of range TypeError: If column is not numeric """ # Use base class validation methods BaseFitter._validate_max_distributions(max_distributions) BaseFitter._validate_column_exists(df, column) BaseFitter._validate_column_numeric(df, column) BaseFitter._validate_sample_fraction(sample_fraction) # Continuous-specific: validate bins parameter if isinstance(bins, int) and bins <= 0: raise ValueError(f"bins must be positive, got {bins}") @staticmethod def _validate_censoring_column(df: DataFrame, censoring_column: str) -> None: """Validate censoring column exists and is boolean/binary (v2.9.0). Args: df: DataFrame containing data censoring_column: Name of the censoring indicator column Raises: ValueError: If column doesn't exist or contains non-binary values """ # Use base class method to check column exists BaseFitter._validate_column_exists(df, censoring_column) # Check column data type - should be boolean or numeric with only 0/1 values # For Spark DataFrames if hasattr(df, "schema"): from pyspark.sql.types import BooleanType, IntegerType, LongType col_type = df.schema[censoring_column].dataType if not isinstance(col_type, (BooleanType, IntegerType, LongType)): raise ValueError( f"Censoring column '{censoring_column}' must be boolean or integer type " f"(True/1 = observed, False/0 = censored), got {col_type}" ) # For pandas DataFrames elif hasattr(df, "dtypes"): import pandas as pd col_dtype = df[censoring_column].dtype if col_dtype not in [bool, "bool", pd.BooleanDtype()] and not pd.api.types.is_integer_dtype(col_dtype): raise ValueError( f"Censoring column '{censoring_column}' must be boolean or integer type " f"(True/1 = observed, False/0 = censored), got {col_dtype}" ) # _validate_bounds inherited from BaseFitter # _resolve_bounds inherited from BaseFitter # _apply_sampling inherited from BaseFitter def _create_fitting_sample(self, df: DataFrame, column: str, row_count: int) -> np.ndarray: """Create numpy sample array for scipy distribution fitting. Samples up to FITTING_SAMPLE_SIZE rows from the DataFrame for use in scipy's distribution fitting functions. Args: df: Spark DataFrame or pandas DataFrame containing data column: Column name to sample row_count: Total row count (used to calculate sampling fraction) Returns: Numpy array of sampled values for distribution fitting """ sample_size = min(FITTING_SAMPLE_SIZE, row_count) fraction = min(sample_size / row_count, 1.0) # Use backend's sample_column which handles both Spark and pandas return self._backend.sample_column(df, column, fraction=fraction, seed=self.random_seed) # _calculate_partitions inherited from BaseFitter @staticmethod def _prefilter_distributions( distributions: List[str], data_sample: np.ndarray, mode: Union[bool, str], ) -> Tuple[List[str], List[Tuple[str, str]]]: """Pre-filter distributions based on data characteristics. Uses a layered approach based on SHAPE properties (not location/scale): 1. Skewness sign (~95% reliable): Skip positive-skew-only distributions for clearly left-skewed data (skewness < -1.0) 2. Kurtosis (aggressive mode only, ~80% reliable): Skip low-kurtosis distributions for very heavy-tailed data Note: We do NOT filter by support bounds (dist.a/dist.b) because scipy's fitting process uses loc/scale parameters that can shift any distribution to cover any data range. Shape properties (skewness, kurtosis) are intrinsic and cannot be changed by loc/scale. Args: distributions: List of distribution names to filter data_sample: Numpy array of sample data mode: True for safe mode, 'aggressive' for additional kurtosis filter Returns: Tuple of (compatible_distributions, filtered_with_reasons) """ # Early return if filtering is disabled if not mode: return distributions.copy(), [] from scipy.stats import kurtosis, skew data_skew = float(skew(data_sample)) data_kurt = float(kurtosis(data_sample)) # Excess kurtosis compatible = [] filtered = [] # Distributions that can only have positive skewness (mathematical constraint) positive_skew_only = { "expon", "gamma", "lognorm", "chi2", "weibull_min", "pareto", "rayleigh", "invgamma", "exponweib", "genpareto", "invweibull", "fisk", "burr", "burr12", "loggamma", "invgauss", "genextreme", # When shape > 0 "gompertz", "halfnorm", "halfcauchy", "halflogistic", "halfgennorm", "rice", "nakagami", "wald", "gengamma", "powerlognorm", } for dist_name in distributions: try: # We intentionally do NOT check support bounds (dist.a/dist.b) # because scipy.fit() uses loc/scale parameters that can shift # any distribution to cover any data range. # Layer 1: Skewness sign check (~95% reliable) # Only filter if data is CLEARLY left-skewed (threshold = -1.0) # These distributions are intrinsically right-skewed regardless of loc/scale if data_skew < -1.0 and dist_name in positive_skew_only: filtered.append((dist_name, "positive-skew only")) continue # Layer 2: Kurtosis check (aggressive mode only, ~80% reliable) if mode == "aggressive" and data_kurt > 10: # Very heavy-tailed data - skip uniform which has kurtosis = -1.2 # Uniform's kurtosis is intrinsic and cannot be changed by loc/scale if dist_name == "uniform": filtered.append((dist_name, "low kurtosis distribution")) continue compatible.append(dist_name) except AttributeError: # Unknown distribution - keep it (conservative) compatible.append(dist_name) return compatible, filtered
[docs] def plot_qq( self, result: DistributionFitResult, df: Optional[DataFrame] = None, column: Optional[str] = None, max_points: int = 1000, title: str = "", xlabel: str = "Theoretical Quantiles", ylabel: str = "Sample Quantiles", figsize: Tuple[int, int] = (10, 10), dpi: int = 100, marker: str = "o", marker_size: int = 30, marker_alpha: float = 0.6, marker_color: str = "steelblue", line_color: str = "red", line_style: str = "--", line_width: float = 1.5, title_fontsize: int = 14, label_fontsize: int = 12, grid_alpha: float = 0.3, save_path: Optional[str] = None, save_format: str = "png", force_recompute: bool = False, ): """Create a Q-Q plot to assess goodness-of-fit. A Q-Q (quantile-quantile) plot compares sample quantiles against theoretical quantiles from the fitted distribution. Points falling close to the reference line indicate a good fit. Args: result: DistributionFitResult to plot df: DataFrame with data. Optional when result contains a cached sample (the default after fitting). When both a cached sample and df are provided, the cached sample is used unless ``force_recompute=True``. column: Column name. If None, uses column_name from result. max_points: Maximum data points to sample for plotting title: Plot title xlabel: X-axis label ylabel: Y-axis label figsize: Figure size (width, height) dpi: Dots per inch for saved figures marker: Marker style for data points marker_size: Size of markers marker_alpha: Marker transparency (0-1) marker_color: Color of markers line_color: Color of reference line line_style: Style of reference line line_width: Width of reference line title_fontsize: Title font size label_fontsize: Axis label font size grid_alpha: Grid transparency (0-1) save_path: Path to save figure (optional) save_format: Save format (png, pdf, svg) force_recompute: If True, ignore cached sample and resample from ``df`` (requires ``df`` to be provided). Default False. Returns: Tuple of (figure, axis) from matplotlib Example: >>> best = results.best(n=1)[0] >>> # Instant Q-Q plot from cached sample (recommended) >>> fitter.plot_qq(best, title='Instant Q-Q Plot') >>> # Explicit DataFrame (resamples via Spark) >>> fitter.plot_qq(best, df, 'value', title='Q-Q Plot', force_recompute=True) """ from spark_bestfit.plotting import plot_qq from spark_bestfit.storage import _get_dataframe_row_count, _sample_dataframe_column # Cache-first: prefer cached sample to avoid Spark DAG recomputation if result.cached_sample is not None and not force_recompute: if df is not None: self._warn_df_with_cache("plot_qq") data = result.cached_sample[:max_points] elif df is not None: col = column or result.column_name if col is None: raise ValueError("column must be provided if result.column_name is None") row_count = _get_dataframe_row_count(df) fraction = min(max_points * 3 / row_count, 1.0) if row_count > 0 else 1.0 sampled = _sample_dataframe_column(df, col, fraction, self.random_seed) data = sampled[:max_points] else: if force_recompute: raise ValueError( "force_recompute=True requires df to be provided, " "since the cached sample is bypassed." ) raise ValueError("Either df must be provided or result must contain a cached sample") return plot_qq( result=result, data=data, title=title, xlabel=xlabel, ylabel=ylabel, figsize=figsize, dpi=dpi, marker=marker, marker_size=marker_size, marker_alpha=marker_alpha, marker_color=marker_color, line_color=line_color, line_style=line_style, line_width=line_width, title_fontsize=title_fontsize, label_fontsize=label_fontsize, grid_alpha=grid_alpha, save_path=save_path, save_format=save_format, )
[docs] def plot_pp( self, result: DistributionFitResult, df: Optional[DataFrame] = None, column: Optional[str] = None, max_points: int = 1000, title: str = "", xlabel: str = "Theoretical Probabilities", ylabel: str = "Sample Probabilities", figsize: Tuple[int, int] = (10, 10), dpi: int = 100, marker: str = "o", marker_size: int = 30, marker_alpha: float = 0.6, marker_color: str = "steelblue", line_color: str = "red", line_style: str = "--", line_width: float = 1.5, title_fontsize: int = 14, label_fontsize: int = 12, grid_alpha: float = 0.3, save_path: Optional[str] = None, save_format: str = "png", force_recompute: bool = False, ): """ Create a P-P plot to assess goodness-of-fit. A P-P (probability-probability) plot compares the empirical CDF of the sample data against the theoretical CDF of the fitted distribution. Points falling close to the reference line indicate a good fit, particularly in the center of the distribution. Args: result: DistributionFitResult to plot df: DataFrame with data. Optional when result contains a cached sample (the default after fitting). When both a cached sample and df are provided, the cached sample is used unless ``force_recompute=True``. column: Column name. If None, uses column_name from result. max_points: Maximum data points to sample for plotting title: Plot title xlabel: X-axis label ylabel: Y-axis label figsize: Figure size (width, height) dpi: Dots per inch for saved figures marker: Marker style for data points marker_size: Size of markers marker_alpha: Marker transparency (0-1) marker_color: Color of markers line_color: Color of reference line line_style: Style of reference line line_width: Width of reference line title_fontsize: Title font size label_fontsize: Axis label font size grid_alpha: Grid transparency (0-1) save_path: Path to save figure (optional) save_format: Save format (png, pdf, svg) force_recompute: If True, ignore cached sample and resample from ``df`` (requires ``df`` to be provided). Default False. Returns: Tuple of (figure, axis) from matplotlib Example: >>> best = results.best(n=1)[0] >>> # Instant P-P plot from cached sample (recommended) >>> fitter.plot_pp(best, title='Instant P-P Plot') >>> # Explicit DataFrame (resamples via Spark) >>> fitter.plot_pp(best, df, 'value', title='P-P Plot', force_recompute=True) """ from spark_bestfit.plotting import plot_pp from spark_bestfit.storage import _get_dataframe_row_count, _sample_dataframe_column # Cache-first: prefer cached sample to avoid Spark DAG recomputation if result.cached_sample is not None and not force_recompute: if df is not None: self._warn_df_with_cache("plot_pp") data = result.cached_sample[:max_points] elif df is not None: col = column or result.column_name if col is None: raise ValueError("column must be provided if result.column_name is None") row_count = _get_dataframe_row_count(df) fraction = min(max_points * 3 / row_count, 1.0) if row_count > 0 else 1.0 sampled = _sample_dataframe_column(df, col, fraction, self.random_seed) data = sampled[:max_points] else: if force_recompute: raise ValueError( "force_recompute=True requires df to be provided, " "since the cached sample is bypassed." ) raise ValueError("Either df must be provided or result must contain a cached sample") return plot_pp( result=result, data=data, title=title, xlabel=xlabel, ylabel=ylabel, figsize=figsize, dpi=dpi, marker=marker, marker_size=marker_size, marker_alpha=marker_alpha, marker_color=marker_color, line_color=line_color, line_style=line_style, line_width=line_width, title_fontsize=title_fontsize, label_fontsize=label_fontsize, grid_alpha=grid_alpha, save_path=save_path, save_format=save_format, )