"""Discrete distribution fitting engine for Spark."""
import logging
from functools import reduce
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
# PySpark is optional - only import if available
try:
from pyspark.sql import DataFrame, SparkSession
_PYSPARK_AVAILABLE = True
except ImportError:
DataFrame = None # type: ignore[assignment,misc]
SparkSession = None # type: ignore[assignment,misc]
_PYSPARK_AVAILABLE = False
from spark_bestfit.base_fitter import BaseFitter
from spark_bestfit.config import FitterConfig
from spark_bestfit.discrete_fitting import (
DISCRETE_FIT_RESULT_SCHEMA,
compute_discrete_histogram,
create_discrete_sample_data,
fit_single_discrete_distribution,
)
from spark_bestfit.distributions import DiscreteDistributionRegistry
from spark_bestfit.fitting import FITTING_SAMPLE_SIZE, compute_data_stats
from spark_bestfit.results import DistributionFitResult, FitResultsType, LazyMetricsContext
if TYPE_CHECKING:
from spark_bestfit.protocols import ExecutionBackend
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

# Re-exported so callers can inspect the default discrete exclusions without
# importing DiscreteDistributionRegistry themselves.
DEFAULT_EXCLUDED_DISCRETE_DISTRIBUTIONS: Tuple[str, ...] = tuple(
    DiscreteDistributionRegistry.DEFAULT_EXCLUSIONS
)
class DiscreteDistributionFitter(BaseFitter):
    """Spark distribution fitting engine for discrete (count) data.

    Efficiently fits scipy.stats discrete distributions to integer data using
    Spark's parallel processing capabilities. Uses MLE optimization since
    scipy discrete distributions don't have a built-in fit() method.

    Metric Selection:
        For discrete distributions, **AIC is recommended** for model selection:

        - ``aic``: Proper model selection criterion with complexity penalty
        - ``bic``: Similar to AIC but stronger penalty for complex models
        - ``ks_statistic``: Valid for ranking, but p-values are not reliable
        - ``sse``: Simple comparison metric

        The K-S test assumes continuous distributions. For discrete data,
        the K-S statistic can rank fits, but p-values are conservative and
        should not be used for hypothesis testing.

    Example:
        >>> from pyspark.sql import SparkSession
        >>> from spark_bestfit import DiscreteDistributionFitter
        >>>
        >>> spark = SparkSession.builder.appName("my-app").getOrCreate()
        >>> df = spark.createDataFrame([(x,) for x in count_data], ['counts'])
        >>>
        >>> fitter = DiscreteDistributionFitter(spark)
        >>> results = fitter.fit(df, column='counts')
        >>>
        >>> # Use AIC for model selection (recommended)
        >>> best = results.best(n=1, metric='aic')[0]
        >>> print(f"Best: {best.distribution} (AIC={best.aic:.2f})")
    """

    # Class attributes for BaseFitter
    _registry_class = DiscreteDistributionRegistry
    _default_exclusions = DEFAULT_EXCLUDED_DISCRETE_DISTRIBUTIONS

    # Column layout of an empty results frame on non-Spark backends. Defined once
    # so every "no results" path returns the same columns (previously two
    # hand-written lists disagreed on data_kurtosis/data_skewness).
    _EMPTY_RESULT_COLUMNS: Tuple[str, ...] = (
        "column_name",
        "distribution",
        "parameters",
        "sse",
        "aic",
        "bic",
        "ks_statistic",
        "pvalue",
        "ad_statistic",
        "ad_pvalue",
        "data_min",
        "data_max",
        "data_mean",
        "data_stddev",
        "data_count",
        "data_kurtosis",
        "data_skewness",
        "lower_bound",
        "upper_bound",
    )

    def __init__(
        self,
        spark: Optional[SparkSession] = None,
        excluded_distributions: Optional[Tuple[str, ...]] = None,
        random_seed: int = 42,
        backend: Optional["ExecutionBackend"] = None,
    ):
        """Initialize DiscreteDistributionFitter.

        Args:
            spark: SparkSession. If None, uses the active session.
                Ignored if ``backend`` is provided.
            excluded_distributions: Distributions to exclude from fitting.
                Defaults to DEFAULT_EXCLUDED_DISCRETE_DISTRIBUTIONS.
                Pass an empty tuple ``()`` to include ALL scipy discrete distributions.
            random_seed: Random seed for reproducible sampling.
            backend: Optional execution backend (v2.0). If None, creates a
                SparkBackend from the spark session. Allows plugging in
                alternative backends like LocalBackend for testing.

        Raises:
            RuntimeError: If no SparkSession provided and no active session exists
        """
        super().__init__(
            spark=spark,
            excluded_distributions=excluded_distributions,
            random_seed=random_seed,
            backend=backend,
        )

    def fit(
        self,
        df: DataFrame,
        column: Optional[str] = None,
        columns: Optional[List[str]] = None,
        config: Optional[FitterConfig] = None,
        *,
        max_distributions: Optional[int] = None,
        enable_sampling: bool = True,
        sample_fraction: Optional[float] = None,
        max_sample_size: int = 1_000_000,
        sample_threshold: int = 10_000_000,
        num_partitions: Optional[int] = None,
        progress_callback: Optional[Callable[[int, int, float], None]] = None,
        bounded: bool = False,
        lower_bound: Optional[Union[float, Dict[str, float]]] = None,
        upper_bound: Optional[Union[float, Dict[str, float]]] = None,
        lazy_metrics: bool = False,
        prefilter: Union[bool, str] = False,
    ) -> FitResultsType:
        """Fit discrete distributions to integer data column(s).

        Args:
            df: Spark DataFrame containing integer count data
            column: Name of single column to fit distributions to
            columns: List of column names for multi-column fitting
            config: FitterConfig object (v2.2.0). Provides a cleaner way to
                configure fitting with many parameters. If provided, individual
                parameters below are ignored (except progress_callback which
                can override the config's callback). Note: bins, use_rice_rule,
                support_at_zero, and prefilter in config are ignored for
                discrete fitting.
            max_distributions: Limit number of distributions (for testing)
            enable_sampling: Enable sampling for large datasets
            sample_fraction: Fraction to sample (None = auto-determine)
            max_sample_size: Maximum rows to sample when auto-determining
            sample_threshold: Row count above which sampling is applied
            num_partitions: Spark partitions (None = auto-determine)
            progress_callback: Optional callback for progress updates.
                Called with (completed_tasks, total_tasks, percent_complete).
                Callback is invoked from background thread - ensure thread-safety.
            bounded: Enable bounded distribution fitting. When True, bounds
                are auto-detected from data or use explicit lower_bound/upper_bound.
            lower_bound: Lower bound for truncated distribution fitting.
                Can be a float (applied to all columns) or a dict mapping
                column names to bounds (v1.5.0). If None and bounded=True,
                auto-detects from each column's minimum.
            upper_bound: Upper bound for truncated distribution fitting.
                Can be a float (applied to all columns) or a dict mapping
                column names to bounds (v1.5.0). If None and bounded=True,
                auto-detects from each column's maximum.
            lazy_metrics: If True, defer computation of expensive KS metrics
                until accessed (v1.5.0). Improves fitting performance when only
                using AIC/BIC/SSE for model selection. Default False for
                backward compatibility.
            prefilter: Pre-filter distributions (v1.6.0). Currently only supported
                for continuous distributions. For discrete, this parameter is
                accepted but ignored (logs a warning if enabled).

        Returns:
            FitResults object with fitted distributions

        Raises:
            ValueError: If column not found, DataFrame empty, or invalid params
            TypeError: If column is not numeric

        Example:
            >>> # Using FitterConfig (v2.2.0)
            >>> from spark_bestfit import FitterConfigBuilder
            >>> config = (FitterConfigBuilder()
            ...     .with_bounds(lower=0, upper=100)
            ...     .with_sampling(fraction=0.1)
            ...     .build())
            >>> results = fitter.fit(df, column='counts', config=config)
            >>>
            >>> # Single column (backward compatible)
            >>> results = fitter.fit(df, column='counts')
            >>> best = results.best(n=1, metric='aic')
            >>>
            >>> # Multi-column
            >>> results = fitter.fit(df, columns=['counts1', 'counts2'])
            >>> best_per_col = results.best_per_column(n=1, metric='aic')
            >>>
            >>> # Bounded fitting
            >>> results = fitter.fit(df, column='counts', bounded=True, lower_bound=0, upper_bound=100)
            >>>
            >>> # Lazy metrics for faster fitting when only using AIC/BIC (v1.5.0)
            >>> results = fitter.fit(df, 'counts', lazy_metrics=True)
            >>> best_aic = results.best(n=1, metric='aic')[0]  # Fast, no KS computed
        """
        # Resolve config: an explicit config takes precedence over individual
        # keyword parameters, except that progress_callback may still override it.
        if config is not None:
            cfg = config
            if progress_callback is not None:
                cfg = config.with_progress_callback(progress_callback)
        else:
            # Build a config from the individual parameters (backward compatibility).
            cfg = FitterConfig(
                max_distributions=max_distributions,
                prefilter=prefilter,
                enable_sampling=enable_sampling,
                sample_fraction=sample_fraction,
                max_sample_size=max_sample_size,
                sample_threshold=sample_threshold,
                bounded=bounded,
                lower_bound=lower_bound,
                upper_bound=upper_bound,
                num_partitions=num_partitions,
                lazy_metrics=lazy_metrics,
                progress_callback=progress_callback,
            )

        # Normalize column/columns to a list and validate every target column.
        target_columns = self._normalize_columns(column, columns)
        for col in target_columns:
            self._validate_inputs(df, col, cfg.max_distributions, cfg.sample_fraction)

        # Prefilter is not implemented for discrete distributions yet.
        if cfg.prefilter:
            logger.warning("prefilter is not yet supported for discrete distributions; ignoring")

        # Validate bounds - handles both scalar and per-column dict forms.
        self._validate_bounds(cfg.lower_bound, cfg.upper_bound, target_columns)

        # Single row count shared by all columns.
        row_count = self._get_row_count(df)
        if row_count == 0:
            raise ValueError("DataFrame is empty")
        logger.info(f"Row count: {row_count}")

        # Per-column bounds: {col: (lower, upper)}; empty when not bounded.
        column_bounds: Dict[str, Tuple[Optional[float], Optional[float]]] = {}
        if cfg.bounded:
            column_bounds = self._resolve_bounds(df, target_columns, cfg.lower_bound, cfg.upper_bound)

        # Sample once for all columns. For adaptive sampling the first column is
        # used as the representative for skew analysis.
        df_sample = self._apply_sampling(
            df,
            row_count,
            cfg.enable_sampling,
            cfg.sample_fraction,
            cfg.max_sample_size,
            cfg.sample_threshold,
            column=target_columns[0] if target_columns else None,
            adaptive_sampling=cfg.adaptive_sampling,
            sampling_mode=cfg.sampling_mode,
            skew_threshold_mild=cfg.skew_threshold_mild,
            skew_threshold_high=cfg.skew_threshold_high,
        )

        # The same distribution list is fitted to every column.
        distributions = self._registry.get_distributions(
            additional_exclusions=list(self.excluded_distributions),
        )
        if cfg.max_distributions is not None and cfg.max_distributions > 0:
            distributions = distributions[: cfg.max_distributions]

        all_results_dfs = []
        lazy_contexts: Dict[str, LazyMetricsContext] = {}
        cached_samples: Dict[str, np.ndarray] = {}
        for col in target_columns:
            col_lower, col_upper = column_bounds.get(col, (None, None))
            logger.info(f"Fitting discrete column '{col}'...")

            # Build the fitting sample once; it is reused for fitting, lazy
            # metrics and instant plotting (v2.10.0). Non-finite values are
            # dropped before the integer cast so they cannot corrupt the PMF.
            raw_sample = self._create_fitting_sample(df_sample, col, row_count)
            data_sample = create_discrete_sample_data(
                self._to_int_sample(raw_sample), sample_size=FITTING_SAMPLE_SIZE
            )
            cached_samples[col] = data_sample

            results_df = self._fit_single_column(
                df_sample=df_sample,
                column=col,
                row_count=row_count,
                distributions=distributions,
                num_partitions=cfg.num_partitions,
                lower_bound=col_lower,
                upper_bound=col_upper,
                lazy_metrics=cfg.lazy_metrics,
                progress_callback=cfg.progress_callback,
                data_sample=data_sample,
            )
            all_results_dfs.append(results_df)

            # Context for computing deferred (lazy) metrics on demand.
            if cfg.lazy_metrics:
                lazy_contexts[col] = LazyMetricsContext(
                    source_df=df_sample,
                    column=col,
                    random_seed=self.random_seed,
                    row_count=row_count,
                    lower_bound=col_lower,
                    upper_bound=col_upper,
                    is_discrete=True,  # Discrete distributions
                    cached_sample=data_sample,
                )

        # Combine per-column results: Spark union when a session is attached,
        # otherwise pandas concatenation.
        if self.spark is not None:
            combined_df = reduce(DataFrame.union, all_results_dfs).cache()
            total_results = combined_df.count()
        else:
            import pandas as pd

            combined_df = pd.concat(all_results_dfs, ignore_index=True)
            total_results = len(combined_df)

        logger.info(
            f"Total results: {total_results} ({len(target_columns)} columns × ~{len(distributions)} distributions)"
        )

        # Lazy contexts and cached samples enable deferred metrics and instant plotting.
        from spark_bestfit.results import create_fit_results

        return create_fit_results(
            combined_df,
            lazy_contexts=lazy_contexts if cfg.lazy_metrics else None,
            samples=cached_samples,
        )

    @staticmethod
    def _to_int_sample(values) -> np.ndarray:
        """Convert a raw sample to an integer array, dropping NaN/inf first.

        NumPy does not raise when casting NaN or +/-inf to an integer dtype; the
        result is an undefined, platform-dependent integer. Such values would be
        silently counted as real observations, so non-finite entries are removed
        before the cast. Non-float inputs are cast without filtering.

        Args:
            values: Array-like sample returned by the backend.

        Returns:
            Integer NumPy array containing only the finite values.
        """
        arr = np.asarray(values)
        if arr.dtype.kind == "f":  # only floating arrays can hold NaN/inf
            arr = arr[np.isfinite(arr)]
        return arr.astype(int)

    def _empty_results_df(self):
        """Return an empty results frame for the active backend.

        Returns:
            An empty Spark DataFrame with DISCRETE_FIT_RESULT_SCHEMA when a
            SparkSession is attached, otherwise an empty pandas DataFrame with
            the columns in ``_EMPTY_RESULT_COLUMNS``.
        """
        if self.spark is not None:
            return self.spark.createDataFrame([], schema=DISCRETE_FIT_RESULT_SCHEMA)
        import pandas as pd

        return pd.DataFrame(columns=list(self._EMPTY_RESULT_COLUMNS))

    def _fit_single_column(
        self,
        df_sample: DataFrame,
        column: str,
        row_count: int,
        distributions: List[str],
        num_partitions: Optional[int],
        lower_bound: Optional[float] = None,
        upper_bound: Optional[float] = None,
        lazy_metrics: bool = False,
        progress_callback: Optional[Callable[[int, int, float], None]] = None,
        data_sample: Optional[np.ndarray] = None,
    ) -> DataFrame:
        """Fit discrete distributions to a single column (internal method).

        Args:
            df_sample: Sampled DataFrame
            column: Column name
            row_count: Original row count
            distributions: List of distribution names to fit
            num_partitions: Number of Spark partitions
            lower_bound: Optional lower bound for truncated distribution
            upper_bound: Optional upper bound for truncated distribution
            lazy_metrics: If True, skip KS computation for performance (v1.5.0)
            progress_callback: Optional callback for progress updates (v2.0.0)
            data_sample: Optional pre-computed sample data (v2.10.0). When None,
                a sample is drawn from ``df_sample`` via the backend.

        Returns:
            Fit results for this column: a Spark DataFrame when a SparkSession
            is attached, otherwise a pandas DataFrame. Empty (but with the
            expected columns) when no data or no fits are available.
        """
        if data_sample is None:
            # Draw an integer sample for fitting; the backend's sample_column
            # handles both Spark and pandas inputs.
            sample_size = min(FITTING_SAMPLE_SIZE, row_count)
            # Guard against division by zero when called with an empty frame.
            fraction = min(sample_size / row_count, 1.0) if row_count > 0 else 1.0
            raw_sample = self._backend.sample_column(df_sample, column, fraction=fraction, seed=self.random_seed)
            data_sample = create_discrete_sample_data(
                self._to_int_sample(raw_sample), sample_size=FITTING_SAMPLE_SIZE
            )

        # Nothing left to fit (e.g. all values were NaN/inf).
        if len(data_sample) == 0:
            logger.warning(f" No valid data for '{column}' after filtering NaN/inf values")
            return self._empty_results_df()

        logger.info(f" Data sample for '{column}': {len(data_sample)} values")

        # Empirical PMF over the observed integer support (used for PMF-based fitting).
        x_values, empirical_pmf = compute_discrete_histogram(data_sample)
        logger.info(f" PMF for '{column}': {len(x_values)} unique values (range: {x_values.min()}-{x_values.max()})")

        # Data stats for provenance, computed once per column.
        data_stats = compute_data_stats(data_sample.astype(float))

        # Interleave slow distributions for better partition balance (no slow
        # discrete distributions today, but keeps parity with the continuous
        # path). Lazy import avoids a circular dependency with core.py.
        from spark_bestfit.core import _interleave_distributions

        distributions = _interleave_distributions(distributions)

        # The backend handles broadcast, partitioning, UDF application and collection.
        results = self._backend.parallel_fit(
            distributions=distributions,
            histogram=(x_values, empirical_pmf),
            data_sample=data_sample,
            fit_func=fit_single_discrete_distribution,
            column_name=column,
            data_stats=data_stats,
            num_partitions=num_partitions,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            lazy_metrics=lazy_metrics,
            is_discrete=True,
            progress_callback=progress_callback,
        )

        # Convert collected results to the backend's DataFrame type.
        if not results:
            results_df = self._empty_results_df()
        elif self.spark is not None:
            results_df = self.spark.createDataFrame(results, schema=DISCRETE_FIT_RESULT_SCHEMA)
        else:
            import pandas as pd

            results_df = pd.DataFrame(results)

        logger.info(f" Fit {len(results)}/{len(distributions)} distributions for '{column}'")
        return results_df

    @staticmethod
    def _validate_inputs(
        df: DataFrame,
        column: str,
        max_distributions: Optional[int],
        sample_fraction: Optional[float],
    ) -> None:
        """Validate input parameters for discrete distribution fitting.

        Args:
            df: Spark DataFrame containing data
            column: Column name to validate
            max_distributions: Maximum distributions to fit (0 is invalid)
            sample_fraction: Sampling fraction (must be in (0, 1] if provided)

        Raises:
            ValueError: If max_distributions is 0, column not found,
                or sample_fraction out of range
            TypeError: If column is not numeric
        """
        # Reuse base class validation (no bins parameter for discrete fitting).
        BaseFitter._validate_max_distributions(max_distributions)
        BaseFitter._validate_column_exists(df, column)
        BaseFitter._validate_column_numeric(df, column)
        BaseFitter._validate_sample_fraction(sample_fraction)

    # _validate_bounds inherited from BaseFitter
    # _resolve_bounds inherited from BaseFitter
    # _apply_sampling inherited from BaseFitter
    # _calculate_partitions inherited from BaseFitter

    def plot(
        self,
        result: DistributionFitResult,
        df: Optional[DataFrame] = None,
        column: Optional[str] = None,
        title: str = "",
        xlabel: str = "Value",
        ylabel: str = "Probability",
        figsize: Tuple[int, int] = (12, 8),
        dpi: int = 100,
        show_histogram: bool = True,
        histogram_alpha: float = 0.7,
        pmf_linewidth: int = 2,
        title_fontsize: int = 14,
        label_fontsize: int = 12,
        legend_fontsize: int = 10,
        grid_alpha: float = 0.3,
        save_path: Optional[str] = None,
        save_format: str = "png",
        force_recompute: bool = False,
    ):
        """Plot fitted discrete distribution against data histogram.

        Args:
            result: DistributionFitResult to plot
            df: DataFrame with data. If None, uses cached sample from result (v2.10.0).
                When a cached sample exists and ``force_recompute`` is False, the
                cached sample is used and *df* is ignored (a warning is emitted).
            column: Column name. If None, uses column_name from result.
            title: Plot title
            xlabel: X-axis label
            ylabel: Y-axis label
            figsize: Figure size (width, height)
            dpi: Dots per inch for saved figures
            show_histogram: Show data histogram
            histogram_alpha: Histogram transparency (0-1)
            pmf_linewidth: Line width for PMF curve
            title_fontsize: Title font size
            label_fontsize: Axis label font size
            legend_fontsize: Legend font size
            grid_alpha: Grid transparency (0-1)
            save_path: Path to save figure (optional)
            save_format: Save format (png, pdf, svg)
            force_recompute: If True, ignore cached sample and recompute from
                *df*. Default False (v3.0.2).

        Returns:
            Tuple of (figure, axis) from matplotlib

        Raises:
            ValueError: If neither *df* nor a cached sample is available, if
                ``force_recompute`` is True without *df*, if no column name can
                be determined, if *df* is empty, or if *df* has no finite values
                in the column.

        Example:
            >>> best = results.best(n=1)[0]
            >>> # v3.0.2: instant plotting using cached sample (default)
            >>> fitter.plot(best, title='Instant Plot')
            >>> # Force recompute from DataFrame
            >>> fitter.plot(best, df, 'value', title='Recomputed', force_recompute=True)
        """
        from spark_bestfit.plotting import plot_discrete_distribution

        # Cache takes priority over df (v3.0.2: cache-first mode).
        if result.cached_sample is not None and not force_recompute:
            if df is not None:
                self._warn_df_with_cache("plot")
            data = result.cached_sample
        elif df is not None:
            col = column or result.column_name
            if col is None:
                raise ValueError("column must be provided if result.column_name is None")

            # Spark DataFrames and Ray Datasets expose count(); pandas uses len().
            if hasattr(df, "sparkSession") or (hasattr(df, "select_columns") and hasattr(df, "count")):
                row_count = df.count()
            else:
                row_count = len(df)
            # Guard the sampling-fraction division below.
            if row_count == 0:
                raise ValueError("DataFrame is empty")

            # Sample roughly 10,000 points for plotting.
            fraction = min(10000 / row_count, 1.0)
            raw_sample = self._backend.sample_column(df, col, fraction=fraction, seed=self.random_seed)
            data = self._to_int_sample(raw_sample)
            if data.size == 0:
                raise ValueError(f"No finite values in column '{col}' to plot")
        elif force_recompute:
            raise ValueError("force_recompute=True requires df to be provided, since the cached sample is bypassed.")
        else:
            raise ValueError("Either df must be provided or result must contain a cached sample")

        return plot_discrete_distribution(
            result=result,
            data=data,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            figsize=figsize,
            dpi=dpi,
            show_histogram=show_histogram,
            histogram_alpha=histogram_alpha,
            pmf_linewidth=pmf_linewidth,
            title_fontsize=title_fontsize,
            label_fontsize=label_fontsize,
            legend_fontsize=legend_fontsize,
            grid_alpha=grid_alpha,
            save_path=save_path,
            save_format=save_format,
        )