"""Discrete distribution fitting using MLE optimization and Pandas UDFs."""
from __future__ import annotations
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import scipy.optimize as opt
import scipy.stats as st
# PySpark is optional - only import if available
try:
from pyspark import Broadcast
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, FloatType, StringType, StructField, StructType
_PYSPARK_AVAILABLE = True
except ImportError:
Broadcast = None # type: ignore[assignment]
pandas_udf = None # type: ignore[assignment]
ArrayType = None # type: ignore[assignment]
FloatType = None # type: ignore[assignment]
StringType = None # type: ignore[assignment]
StructField = None # type: ignore[assignment]
StructType = None # type: ignore[assignment]
_PYSPARK_AVAILABLE = False
from spark_bestfit.distributions import DiscreteDistributionRegistry
from spark_bestfit.fitting import _filter_bootstrap_outliers
# Result schema for discrete fitting; it is only built when PySpark imported
# successfully, otherwise the name is bound to None.
# ad_statistic / ad_pvalue are carried for schema compatibility with FitResults
# but are always None for discrete distributions (the Anderson-Darling test is
# for continuous distributions).
if not _PYSPARK_AVAILABLE:
    DISCRETE_FIT_RESULT_SCHEMA = None  # type: ignore[assignment]
else:
    DISCRETE_FIT_RESULT_SCHEMA = StructType(
        [
            StructField("column_name", StringType(), True),  # Column being fitted
            StructField("distribution", StringType(), True),
            StructField("parameters", ArrayType(FloatType()), True),
            # Every remaining field is a nullable float, declared in schema order.
            *(
                StructField(float_field, FloatType(), True)
                for float_field in (
                    # Goodness-of-fit metrics
                    "sse",
                    "aic",
                    "bic",
                    "ks_statistic",
                    "pvalue",
                    "ad_statistic",
                    "ad_pvalue",
                    # Flat data summary columns for provenance
                    # (v2.0: replaced MapType for ~20% perf)
                    "data_min",
                    "data_max",
                    "data_mean",
                    "data_stddev",
                    "data_count",
                    # Heavy-tail detection stats (v2.3.0)
                    "data_kurtosis",
                    "data_skewness",
                    # Bounded distribution support
                    "lower_bound",
                    "upper_bound",
                )
            ),
        ]
    )
[docs]
def fit_discrete_mle(
    dist_name: str,
    data: np.ndarray,
    initial_params: List[float],
    bounds: List[Tuple[float, float]],
) -> Tuple[np.ndarray, float]:
    """Fit a discrete distribution using maximum likelihood estimation.

    Since scipy discrete distributions don't have a fit() method, the negative
    log-likelihood of ``data`` is minimised with bounded L-BFGS-B through
    ``scipy.optimize.minimize``.

    For binom, betabinom, hypergeom, nhypergeom, boltzmann and zipfian the
    first parameter is rounded to the nearest integer before each likelihood
    evaluation. The returned parameter vector is the optimizer's raw result
    and is not rounded.

    Args:
        dist_name: Name of the scipy.stats discrete distribution
        data: Integer data to fit
        initial_params: Initial parameter guesses
        bounds: Parameter bounds as list of (min, max) tuples

    Returns:
        Tuple of (fitted_params, negative_log_likelihood)

    Raises:
        ValueError: If the optimizer finishes at a point whose negative
            log-likelihood is not finite, i.e. no usable fit was found. A run
            that stops early (for example at the iteration limit) but reaches
            a finite likelihood is still returned.
    """
    dist = getattr(st, dist_name)

    # Decided once here instead of on every objective evaluation.
    round_first_param = dist_name in {
        "binom",
        "betabinom",
        "hypergeom",
        "nhypergeom",
        "boltzmann",
        "zipfian",
    }

    def neg_log_likelihood(params: np.ndarray) -> float:
        """Compute negative log-likelihood for optimization."""
        try:
            args = list(params)
            if round_first_param:
                # First parameter is the integer one (e.g., n in binomial).
                # int(round(...)) raises ValueError for NaN and OverflowError
                # for infinity; both are treated as an infeasible point.
                args[0] = int(round(args[0]))
            ll = np.sum(dist.logpmf(data, *args))
        except (ValueError, RuntimeError, ZeroDivisionError, OverflowError):
            return np.inf
        if not np.isfinite(ll):
            return np.inf
        return float(-ll)

    # Run optimization; numerical warnings from infeasible trial points are noise here.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        result = opt.minimize(
            neg_log_likelihood,
            initial_params,
            bounds=bounds,
            method="L-BFGS-B",
            options={"maxiter": 200, "ftol": 1e-8},
        )

    # A non-finite objective means the data has zero likelihood at the returned
    # point, whatever the optimizer's own success flag says.
    if not np.isfinite(result.fun):
        raise ValueError(f"Optimization failed for {dist_name}: {result.message}")
    return result.x, float(result.fun)
[docs]
def compute_discrete_histogram(
    data: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build the empirical probability mass function of integer data.

    The data is cast to ``int`` and each distinct value becomes its own bin;
    no binning width is involved, unlike a continuous histogram.

    Args:
        data: Integer data array

    Returns:
        Tuple of (values, pmf) where:
        - values: sorted unique integer values in data
        - pmf: count of each value divided by ``len(data)``
    """
    as_integers = data.astype(int)
    support, occurrences = np.unique(as_integers, return_counts=True)
    # Normalise the counts into probability mass.
    return support, occurrences / len(data)
[docs]
def evaluate_pmf(
    dist: Any,
    params: Tuple[float, ...],
    x: np.ndarray,
    dist_name: str,
) -> np.ndarray:
    """Evaluate a discrete distribution's PMF at the given integer points.

    Args:
        dist: scipy.stats discrete distribution object
        params: Distribution parameters
        x: Integer points at which to evaluate PMF
        dist_name: Name of distribution; for a fixed set of names the first
            parameter is rounded to the nearest integer before evaluation

    Returns:
        PMF values at x with NaN and infinite entries replaced by 0.0, or an
        all-zero float array shaped like x if the evaluation raises
        ValueError or RuntimeError.
    """
    if dist_name in {"binom", "betabinom", "hypergeom", "nhypergeom", "boltzmann", "zipfian"}:
        # These distributions need an integer first parameter.
        params = (int(round(params[0])),) + tuple(params[1:])

    try:
        raw_pmf = dist.pmf(x, *params)
        return np.nan_to_num(raw_pmf, nan=0.0, posinf=0.0, neginf=0.0)
    except (ValueError, RuntimeError):
        return np.zeros_like(x, dtype=float)
[docs]
def compute_discrete_sse(
    dist: Any,
    params: Tuple[float, ...],
    x_values: np.ndarray,
    empirical_pmf: np.ndarray,
    dist_name: str,
) -> float:
    """Sum the squared differences between the empirical and fitted PMF.

    Args:
        dist: scipy.stats discrete distribution object
        params: Fitted distribution parameters
        x_values: Integer values where PMF is evaluated
        empirical_pmf: Empirical probability mass at each x value
        dist_name: Name of distribution

    Returns:
        Sum of squared errors, or infinity when that sum is not finite
    """
    model_pmf = evaluate_pmf(dist, params, x_values, dist_name)
    residuals = empirical_pmf - model_pmf
    total = np.sum(residuals**2)
    return float(total) if np.isfinite(total) else np.inf
[docs]
def compute_discrete_ks_statistic(
    dist: Any,
    params: Tuple[float, ...],
    data: np.ndarray,
    dist_name: str,
) -> Tuple[float, float]:
    """Compute the two-sided Kolmogorov-Smirnov statistic for a discrete fit.

    The statistic is D_n = max(D+, D-), the largest gap between the empirical
    CDF of ``data`` and the fitted CDF evaluated at the sorted data points.

    Note:
        The standard KS test assumes continuous distributions.
        For discrete distributions, the KS statistic is valid for comparing
        fits, but p-values are conservative and should not be used for
        formal hypothesis testing. Use AIC/BIC for model selection instead.

    Args:
        dist: scipy.stats discrete distribution object
        params: Fitted distribution parameters
        data: Original integer data
        dist_name: Name of distribution; for a fixed set of names the first
            parameter is rounded to the nearest integer first

    Returns:
        Tuple of (ks_statistic, pvalue) where pvalue is approximate only.
        (inf, 0.0) is returned when the statistic is not finite or the
        computation raises ValueError, RuntimeError or FloatingPointError.
    """
    try:
        if dist_name in {"binom", "betabinom", "hypergeom", "nhypergeom", "boltzmann", "zipfian"}:
            # These distributions need an integer first parameter.
            params = (int(round(params[0])),) + tuple(params[1:])

        ordered = np.sort(data)
        n = len(data)
        model_cdf = dist.cdf(ordered, *params)

        # ranks = 0..n-1, so (ranks + 1) / n is the ECDF just after each jump
        # and ranks / n is the ECDF just before it.
        ranks = np.arange(n)
        d_plus = np.max((ranks + 1) / n - model_cdf)  # empirical above theoretical
        d_minus = np.max(model_cdf - ranks / n)  # theoretical above empirical
        ks_stat = max(d_plus, d_minus)

        # Asymptotic approximation: sqrt(n) * D_n follows the Kolmogorov
        # distribution. Only approximate for discrete distributions.
        pvalue = float(st.kstwobign.sf(np.sqrt(n) * ks_stat))

        if not np.isfinite(ks_stat):
            return np.inf, 0.0
        return float(ks_stat), (pvalue if np.isfinite(pvalue) else 0.0)
    except (ValueError, RuntimeError, FloatingPointError):
        return np.inf, 0.0
[docs]
def compute_ks_ad_metrics_discrete(
    dist_name: str,
    params: List[float],
    data_sample: np.ndarray,
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
) -> Tuple[Optional[float], Optional[float], Optional[float], Optional[float]]:
    """Compute the KS metrics of a fitted discrete distribution on demand.

    This is the core computation used for lazy metric evaluation with
    discrete distributions. Anderson-Darling values are never computed here
    (the AD test is for continuous distributions only).

    Args:
        dist_name: Name of scipy.stats discrete distribution
        params: Fitted distribution parameters
        data_sample: Data sample for metric computation; cast to int first
        lower_bound: Optional lower bound (unused for discrete, for API compatibility)
        upper_bound: Optional upper bound (unused for discrete, for API compatibility)

    Returns:
        Tuple of (ks_statistic, pvalue, ad_statistic, ad_pvalue).
        ad_statistic and ad_pvalue are always None. A non-finite ks_statistic
        or pvalue is reported as None. Returns (None, None, None, None) if
        the computation raises ValueError, RuntimeError, FloatingPointError
        or AttributeError (for example an unknown distribution name).
    """

    def _finite_or_none(value: Optional[float]) -> Optional[float]:
        # Non-finite metrics are reported as missing rather than as inf/nan.
        if value is None or not np.isfinite(value):
            return None
        return float(value)

    try:
        dist = getattr(st, dist_name)
        ks_stat, pvalue = compute_discrete_ks_statistic(
            dist=dist,
            params=tuple(params),
            data=data_sample.astype(int),
            dist_name=dist_name,
        )
        return (_finite_or_none(ks_stat), _finite_or_none(pvalue), None, None)
    except (ValueError, RuntimeError, FloatingPointError, AttributeError):
        return (None, None, None, None)
[docs]
def fit_single_discrete_distribution(
    dist_name: str,
    data_sample: np.ndarray,
    x_values: np.ndarray,
    empirical_pmf: np.ndarray,
    registry: DiscreteDistributionRegistry,
    column_name: Optional[str] = None,
    data_stats: Optional[Dict[str, float]] = None,
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
    lazy_metrics: bool = False,
) -> Dict[str, Any]:
    """Fit a single discrete distribution and compute goodness-of-fit metrics.

    This function never raises: any exception, non-finite parameter,
    non-finite SSE, or recorded warning mentioning "convergence" or "nan"
    yields the sentinel from ``_failed_discrete_fit_result``. Every failure
    path reports the same column name, data stats and bounds.

    Args:
        dist_name: Name of scipy.stats discrete distribution
        data_sample: Sample of integer data for parameter fitting
        x_values: Unique integer values in data
        empirical_pmf: Empirical PMF at each x value
        registry: DiscreteDistributionRegistry for parameter configs
        column_name: Name of the column being fitted (for multi-column support)
        data_stats: Pre-computed summary statistics (data_min, data_max, etc.)
        lower_bound: Optional lower bound for truncated distribution
        upper_bound: Optional upper bound for truncated distribution
        lazy_metrics: If True, skip expensive KS computation. These metrics
            will be None in the result and computed on-demand later. (v1.5.0)

    Returns:
        Dictionary with fit result fields including data_min, data_max, etc.
    """

    def _failure() -> Dict[str, Any]:
        # Single construction point so no failure path can drop the bounds.
        return _failed_discrete_fit_result(dist_name, column_name, data_stats, lower_bound, upper_bound)

    try:
        with warnings.catch_warnings(record=True) as caught_warnings:
            warnings.simplefilter("always")

            # Get distribution object
            dist = getattr(st, dist_name)

            # Get parameter configuration
            config = registry.get_param_config(dist_name)
            initial = config["initial"](data_sample)
            bounds = config["bounds"](data_sample)

            # Fit using MLE optimization (the likelihood value itself is not reported)
            params, _ = fit_discrete_mle(dist_name, data_sample, initial, bounds)

            # Check for invalid parameters
            if not all(np.isfinite(p) for p in params):
                return _failure()
            fitted = tuple(params)

            # Compute SSE using PMF
            sse = compute_discrete_sse(dist, fitted, x_values, empirical_pmf, dist_name)
            if not np.isfinite(sse):
                return _failure()

            # Compute information criteria (fast, always computed)
            aic, bic = compute_discrete_information_criteria(dist, fitted, data_sample, dist_name)

            if lazy_metrics:
                # Skip KS computation for performance - will be computed on-demand
                ks_stat, pvalue = None, None
            else:
                ks_stat, pvalue = compute_discrete_ks_statistic(dist, fitted, data_sample, dist_name)

            # Reject the fit if any recorded warning hints at convergence or NaN trouble
            for w in caught_warnings:
                message = str(w.message).lower()
                if "convergence" in message or "nan" in message:
                    return _failure()

            return {
                "column_name": column_name,
                "distribution": dist_name,
                "parameters": [float(p) for p in params],
                "sse": float(sse),
                "aic": float(aic),
                "bic": float(bic),
                "ks_statistic": float(ks_stat) if ks_stat is not None else None,
                "pvalue": float(pvalue) if pvalue is not None else None,
                "ad_statistic": None,  # A-D not computed for discrete distributions
                "ad_pvalue": None,
                **(data_stats or {}),  # Flat data stats: data_min, data_max, etc.
                "lower_bound": float(lower_bound) if lower_bound is not None else None,
                "upper_bound": float(upper_bound) if upper_bound is not None else None,
            }
    except Exception:
        # Catch all exceptions to ensure fitting never crashes the Spark job
        # This matches behavior of LocalBackend and RayBackend which skip failed fits
        return _failure()
def _failed_discrete_fit_result(
    dist_name: str,
    column_name: Optional[str] = None,
    data_stats: Optional[Dict[str, float]] = None,
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
) -> Dict[str, Any]:
    """Build the sentinel result row that marks a failed discrete fit.

    Args:
        dist_name: Name of the distribution that failed
        column_name: Name of the column being fitted (for multi-column support)
        data_stats: Pre-computed summary statistics (data_min, data_max, etc.)
        lower_bound: Optional lower bound for truncated distribution
        upper_bound: Optional upper bound for truncated distribution

    Returns:
        Dictionary with sentinel values indicating fit failure: a single-NaN
        parameter list, infinite sse/aic/bic/ks_statistic and a zero pvalue
    """

    def _optional_float(bound: Optional[float]) -> Optional[float]:
        return None if bound is None else float(bound)

    worst = float(np.inf)
    sentinel_metrics: Dict[str, Any] = {
        "column_name": column_name,
        "distribution": dist_name,
        "parameters": [float(np.nan)],
        "sse": worst,
        "aic": worst,
        "bic": worst,
        "ks_statistic": worst,
        "pvalue": 0.0,
        "ad_statistic": None,  # A-D not computed for discrete distributions
        "ad_pvalue": None,
    }
    flat_stats = data_stats or {}  # data_min, data_max, etc.
    return {
        **sentinel_metrics,
        **flat_stats,
        "lower_bound": _optional_float(lower_bound),
        "upper_bound": _optional_float(upper_bound),
    }
[docs]
def create_discrete_fitting_udf(
    histogram_broadcast: Broadcast[Tuple[np.ndarray, np.ndarray]],
    data_sample_broadcast: Broadcast[np.ndarray],
    column_name: Optional[str] = None,
    data_stats: Optional[Dict[str, float]] = None,
    lower_bound: Optional[float] = None,
    upper_bound: Optional[float] = None,
    lazy_metrics: bool = False,
) -> Callable[[pd.Series], pd.DataFrame]:
    """Factory function to create Pandas UDF for discrete distribution fitting.

    Args:
        histogram_broadcast: Broadcast variable containing (x_values, empirical_pmf)
        data_sample_broadcast: Broadcast variable containing integer data sample
        column_name: Name of the column being fitted (for result tracking)
        data_stats: Pre-computed summary statistics (data_min, data_max, etc.)
        lower_bound: Optional lower bound for truncated distribution
        upper_bound: Optional upper bound for truncated distribution
        lazy_metrics: If True, skip expensive KS computation during fitting.
            These metrics will be computed on-demand when accessed via
            FitResults.best() or DistributionFitResult properties. (v1.5.0)

    Returns:
        Pandas UDF function for fitting discrete distributions

    Raises:
        ImportError: If PySpark is not installed, so no UDF can be built
    """
    if pandas_udf is None or DISCRETE_FIT_RESULT_SCHEMA is None:
        raise ImportError("PySpark is required to create the discrete fitting UDF")

    # Create registry once - will be serialized to workers
    registry = DiscreteDistributionRegistry()
    # The returned frame must expose every schema field by name; capture the
    # names as plain strings so the closure does not carry the schema object.
    result_columns = [field.name for field in DISCRETE_FIT_RESULT_SCHEMA.fields]

    @pandas_udf(DISCRETE_FIT_RESULT_SCHEMA)
    def fit_discrete_distributions_batch(distribution_names: pd.Series) -> pd.DataFrame:
        """Vectorized UDF to fit multiple discrete distributions in a batch.

        Args:
            distribution_names: Series of scipy discrete distribution names

        Returns:
            DataFrame with one row per distribution and one column per field
            of DISCRETE_FIT_RESULT_SCHEMA
        """
        # Get broadcasted data
        x_values, empirical_pmf = histogram_broadcast.value
        data_sample = data_sample_broadcast.value

        # Fit each distribution in the batch
        results = []
        for dist_name in distribution_names:
            try:
                result = fit_single_discrete_distribution(
                    dist_name=dist_name,
                    data_sample=data_sample,
                    x_values=x_values,
                    empirical_pmf=empirical_pmf,
                    registry=registry,
                    column_name=column_name,
                    data_stats=data_stats,
                    lower_bound=lower_bound,
                    upper_bound=upper_bound,
                    lazy_metrics=lazy_metrics,
                )
            except Exception:
                # Safety net: catch any unexpected exceptions to prevent job failure
                result = _failed_discrete_fit_result(dist_name, column_name, data_stats, lower_bound, upper_bound)
            results.append(result)

        # Build the frame against the schema columns: an empty batch still has
        # the expected columns, and fields absent from the result dicts (e.g.
        # data_* when data_stats is None) are present as missing values.
        df = pd.DataFrame(results, columns=result_columns)
        df["distribution"] = df["distribution"].astype(str)
        df["sse"] = df["sse"].astype(float)
        return df

    return fit_discrete_distributions_batch
[docs]
def create_discrete_sample_data(
    data_full: np.ndarray,
    sample_size: int = 10_000,
    random_seed: int = 42,
) -> np.ndarray:
    """Draw a reproducible integer sample for distribution fitting.

    Args:
        data_full: Full integer dataset
        sample_size: Target sample size
        random_seed: Random seed for reproducibility

    Returns:
        ``data_full`` cast to int when it holds at most ``sample_size``
        elements; otherwise ``sample_size`` elements drawn without
        replacement, cast to int
    """
    population = len(data_full)
    if population > sample_size:
        picker = np.random.RandomState(random_seed)
        chosen = picker.choice(population, size=sample_size, replace=False)
        return data_full[chosen].astype(int)
    # Small enough already: use everything.
    return data_full.astype(int)
[docs]
def get_discrete_param_names(dist_name: str) -> List[str]:
    """Look up the parameter names of a discrete scipy distribution.

    The names are the ``"param_names"`` entry of the distribution's
    configuration in a freshly created DiscreteDistributionRegistry.

    Args:
        dist_name: Name of scipy.stats discrete distribution

    Returns:
        List of parameter names

    Example:
        >>> get_discrete_param_names("poisson")
        ['mu']
        >>> get_discrete_param_names("binom")
        ['n', 'p']
        >>> get_discrete_param_names("nbinom")
        ['n', 'p']
    """
    param_config = DiscreteDistributionRegistry().get_param_config(dist_name)
    return param_config["param_names"]
[docs]
def bootstrap_discrete_confidence_intervals(
    dist_name: str,
    data: np.ndarray,
    alpha: float = 0.05,
    n_bootstrap: int = 1000,
    random_seed: Optional[int] = None,
) -> Dict[str, Tuple[float, float]]:
    """Compute bootstrap confidence intervals for discrete distribution parameters.

    Uses the percentile bootstrap method: resample data with replacement,
    refit the distribution using MLE, and compute confidence intervals from
    the empirical distribution of fitted parameters.

    Args:
        dist_name: Name of scipy.stats discrete distribution
        data: Integer data array used for fitting
        alpha: Significance level in [0, 1] (default 0.05 for 95% CI)
        n_bootstrap: Number of bootstrap samples (default 1000)
        random_seed: Random seed for reproducibility

    Returns:
        Dictionary mapping parameter names to (lower, upper) bounds

    Raises:
        ValueError: If alpha is outside [0, 1], if data is empty, or if fewer
            than 10 bootstrap fits are usable (before or after outlier
            filtering).

    Example:
        >>> data = np.random.poisson(lam=7, size=1000)
        >>> ci = bootstrap_discrete_confidence_intervals("poisson", data, alpha=0.05)
        >>> print(ci)
        {'mu': (6.75, 7.25)}

    Note:
        Bootstrap fitting may fail for some resamples. Failed fits are skipped.
    """
    # Validate up front: an alpha above 1 would silently give a reversed
    # interval, and other bad values would only fail after all the refits.
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be between 0 and 1, got {alpha!r}")

    rng = np.random.default_rng(random_seed)
    data = data.astype(int)
    n = len(data)
    if n == 0:
        raise ValueError("Cannot bootstrap confidence intervals from empty data.")

    # Get parameter configuration
    registry = DiscreteDistributionRegistry()
    config = registry.get_param_config(dist_name)
    param_names = config["param_names"]

    # Collect bootstrap parameter estimates
    bootstrap_params: List[Tuple[float, ...]] = []
    for _ in range(n_bootstrap):
        # Resample with replacement
        sample = rng.choice(data, size=n, replace=True)
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                initial = config["initial"](sample)
                bounds = config["bounds"](sample)
                params, _nll = fit_discrete_mle(dist_name, sample, initial, bounds)
                # Skip if any parameter is non-finite
                if all(np.isfinite(p) for p in params):
                    bootstrap_params.append(tuple(params))
        except (ValueError, RuntimeError, FloatingPointError):
            continue  # Skip failed fits

    if len(bootstrap_params) < 10:
        raise ValueError(
            f"Too few successful bootstrap fits ({len(bootstrap_params)}/{n_bootstrap}). "
            "Data may be unsuitable for this distribution."
        )

    # Convert to array for percentile computation
    bootstrap_array = np.array(bootstrap_params)

    # Remove outlier bootstrap estimates using IQR filtering per parameter
    # This prevents extreme outliers from blowing up the CI bounds
    bootstrap_array = _filter_bootstrap_outliers(bootstrap_array)
    if len(bootstrap_array) < 10:
        raise ValueError(
            "Too few bootstrap samples remain after outlier filtering. "
            "Data may be unsuitable for this distribution."
        )

    # Compute percentile confidence intervals
    lower_pct = (alpha / 2) * 100
    upper_pct = (1 - alpha / 2) * 100
    result: Dict[str, Tuple[float, float]] = {}
    for i, name in enumerate(param_names):
        estimates = bootstrap_array[:, i]
        lower = float(np.percentile(estimates, lower_pct))
        upper = float(np.percentile(estimates, upper_pct))
        result[name] = (lower, upper)
    return result