"""Data storage classes for distribution fitting results.
This module contains the core data classes for storing individual distribution
fit results and related context objects. These are the fundamental building
blocks used throughout spark_bestfit.
Classes:
DistributionFitResult: Stores a single distribution's fitted parameters and metrics.
LazyMetricsContext: Context for deferred KS/AD metric computation.
Type Aliases:
MetricName: Valid metric names for sorting/filtering.
ContinuousHistogram: Tuple type for continuous distribution histograms.
DiscreteHistogram: Tuple type for discrete distribution histograms.
HistogramBins: Array of bin edges (len = n_bins + 1).
HistogramCounts: Array of counts/density per bin.
HistogramResult: Tuple type for HistogramComputer results (counts, bins).
Constants:
FITTING_SAMPLE_SIZE: Default sample size for fitting (10000).
DEFAULT_PVALUE_THRESHOLD: Default p-value threshold (0.05).
DEFAULT_KS_THRESHOLD: Default KS statistic threshold (0.10).
DEFAULT_AD_THRESHOLD: Default AD statistic threshold (2.0).
DEFAULT_BINS: Default number of histogram bins (50).
DEFAULT_BOOTSTRAP_SAMPLES: Default bootstrap iterations (1000).
DEFAULT_MAX_SAMPLES: Default maximum samples to collect (10000).
DEFAULT_DPI: Default plot DPI (100).
DEFAULT_SAMPLE_SIZE: Default sample size for sampling methods (1000).
"""
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, TypeAlias, Union
import numpy as np
import scipy.stats as st
from spark_bestfit.truncated import TruncatedFrozenDist
# PySpark is optional - only import if available
try:
from pyspark.sql import DataFrame
_PYSPARK_AVAILABLE = True
except ImportError:
DataFrame = None # type: ignore[assignment,misc]
_PYSPARK_AVAILABLE = False
if TYPE_CHECKING:
pass
# =============================================================================
# Constants
# =============================================================================
# Default sample size for fitting operations
FITTING_SAMPLE_SIZE = 10000
# Default threshold values for quality assessment
DEFAULT_PVALUE_THRESHOLD = 0.05
DEFAULT_KS_THRESHOLD = 0.10
DEFAULT_AD_THRESHOLD = 2.0
# Default histogram and plotting parameters
DEFAULT_BINS = 50
DEFAULT_DPI = 100
# Default bootstrap and sampling parameters
DEFAULT_BOOTSTRAP_SAMPLES = 1000
DEFAULT_MAX_SAMPLES = 10000
DEFAULT_SAMPLE_SIZE = 1000
# =============================================================================
# Type Aliases
# =============================================================================
# Valid metric names for sorting/filtering (for IDE autocomplete and type checking)
MetricName = Literal["sse", "aic", "bic", "ks_statistic", "ad_statistic"]
# Histogram type aliases - distinguish continuous from discrete semantically
# ContinuousHistogram: (density_values, bin_edges) where len(edges) = len(density) + 1
ContinuousHistogram = Tuple[np.ndarray, np.ndarray]
# DiscreteHistogram: (x_values, empirical_pmf) where both arrays have same length
DiscreteHistogram = Tuple[np.ndarray, np.ndarray]
# HistogramBins: Array of bin edges (len = n_bins + 1)
HistogramBins: TypeAlias = np.ndarray
# HistogramCounts: Array of counts/density per bin
HistogramCounts: TypeAlias = np.ndarray
# HistogramResult: Complete histogram from HistogramComputer.compute_histogram()
# Order is (counts, bins) where counts has n_bins elements and bins has n_bins + 1 edges
HistogramResult: TypeAlias = Tuple[np.ndarray, np.ndarray]
# =============================================================================
# DataFrame Utilities (Multi-Backend Support)
# =============================================================================
def _is_spark_dataframe(df) -> bool:
"""Check if df is a Spark DataFrame.
Uses duck typing to detect Spark DataFrames without requiring pyspark import.
Args:
df: DataFrame to check
Returns:
True if df is a Spark DataFrame, False otherwise
"""
return hasattr(df, "toPandas") and hasattr(df, "select")
def _is_ray_dataset(df) -> bool:
"""Check if df is a Ray Dataset (not a pandas DataFrame).
Uses duck typing to detect Ray Datasets. Note that pandas DataFrames
don't have select_columns(), so this won't match them.
Args:
df: DataFrame to check
Returns:
True if df is a Ray Dataset, False otherwise
"""
return hasattr(df, "select_columns") and hasattr(df, "to_pandas")
def _get_dataframe_row_count(df) -> int:
"""Get total row count from any supported DataFrame type.
Supports Spark DataFrames, Ray Datasets, and pandas DataFrames.
Args:
df: DataFrame (Spark, Ray Dataset, or pandas)
Returns:
Total number of rows in the DataFrame
"""
if _is_spark_dataframe(df):
return df.count()
elif _is_ray_dataset(df):
return df.count()
else: # pandas DataFrame
return len(df)
def _collect_dataframe_column(df, column: str) -> np.ndarray:
"""Extract a column from any supported DataFrame type as numpy array.
Supports Spark DataFrames, Ray Datasets, and pandas DataFrames.
Args:
df: DataFrame (Spark, Ray Dataset, or pandas)
column: Column name to extract
Returns:
Numpy array of column values
"""
if _is_spark_dataframe(df):
return np.array(df.select(column).toPandas()[column].values)
elif _is_ray_dataset(df):
return df.select_columns([column]).to_pandas()[column].values
else: # pandas DataFrame
return df[column].values
def _sample_dataframe_column(df, column: str, fraction: float, seed: Optional[int]) -> np.ndarray:
"""Sample a column from any supported DataFrame type and return as numpy array.
Supports Spark DataFrames, Ray Datasets, and pandas DataFrames.
Args:
df: DataFrame (Spark, Ray Dataset, or pandas)
column: Column name to sample
fraction: Fraction of rows to sample (0 < fraction <= 1)
seed: Random seed for reproducibility (can be None)
Returns:
Numpy array of sampled column values
"""
if _is_spark_dataframe(df):
if seed is not None:
sampled = df.sample(withReplacement=False, fraction=fraction, seed=seed)
else:
sampled = df.sample(withReplacement=False, fraction=fraction)
return np.array(sampled.select(column).toPandas()[column].values)
elif _is_ray_dataset(df):
sampled = df.random_sample(fraction, seed=seed)
return sampled.select_columns([column]).to_pandas()[column].values
else: # pandas DataFrame
sample_df = df[[column]].sample(frac=fraction, random_state=seed)
return sample_df[column].values
# =============================================================================
# Data Classes
# =============================================================================
@dataclass(slots=True)
class LazyMetricsContext:
"""Context for deferred KS/AD metric computation.
When lazy_metrics=True during fitting, this context stores everything
needed to compute KS/AD metrics on-demand later. The key insight is that
with the same (DataFrame, column, seed), we can recreate the exact sample.
Attributes:
source_df: Reference to the source DataFrame for sampling
column: Column name to sample from
random_seed: Seed used for reproducible sampling
row_count: Total row count for calculating sample fraction
lower_bound: Optional lower bound for truncated distributions
upper_bound: Optional upper bound for truncated distributions
is_discrete: Whether this is discrete distribution fitting
Note:
The source_df reference must remain valid (not unpersisted) for lazy
metric computation to work. Call materialize() before unpersisting
if you need the metrics.
"""
source_df: DataFrame
column: str
random_seed: int
row_count: int
lower_bound: Optional[float] = None
upper_bound: Optional[float] = None
is_discrete: bool = False
# Cached sample data for instant plotting (v2.10.0)
cached_sample: Optional[np.ndarray] = None
[docs]
@dataclass(slots=True)
class DistributionFitResult:
"""Result from fitting a single distribution.
Attributes:
distribution: Name of the scipy.stats distribution
parameters: Fitted parameters (shape params + loc + scale)
sse: Sum of Squared Errors
column_name: Name of the column that was fitted (for multi-column support)
aic: Akaike Information Criterion (lower is better)
bic: Bayesian Information Criterion (lower is better)
ks_statistic: Kolmogorov-Smirnov statistic (lower is better)
pvalue: P-value from KS test (higher indicates better fit)
ad_statistic: Anderson-Darling statistic (lower is better)
ad_pvalue: P-value from A-D test (only for norm, expon, logistic, gumbel_r, gumbel_l)
data_min: Minimum value in the data used for fitting
data_max: Maximum value in the data used for fitting
data_mean: Mean of the data used for fitting
data_stddev: Standard deviation of the data used for fitting
data_count: Number of samples in the data used for fitting
data_kurtosis: Excess kurtosis of the data used for fitting (v2.3.0)
data_skewness: Skewness of the data used for fitting (v2.3.0)
cached_sample: Cached sample data for instant plotting (v2.10.0)
lower_bound: Lower bound for truncated distribution fitting (v1.4.0).
When set, the distribution is truncated at this lower limit.
upper_bound: Upper bound for truncated distribution fitting (v1.4.0).
When set, the distribution is truncated at this upper limit.
Note:
The p-value from the KS test is approximate when parameters are
estimated from the same data being tested. It tends to be conservative
(larger than it should be). Use it for rough guidance, not strict
hypothesis testing. The ks_statistic is valid for ranking fits.
The ad_pvalue is only available for 5 distributions (norm, expon,
logistic, gumbel_r, gumbel_l) where scipy has critical value tables.
For other distributions, ad_pvalue will be None but ad_statistic
is still valid for ranking fits.
When bounds are set (lower_bound and/or upper_bound), methods like
sample(), pdf(), cdf(), and ppf() automatically use scipy.stats.truncate()
to return values respecting the bounded domain.
"""
distribution: str
parameters: List[float]
sse: float
column_name: Optional[str] = None
aic: Optional[float] = None
bic: Optional[float] = None
ks_statistic: Optional[float] = None
pvalue: Optional[float] = None
ad_statistic: Optional[float] = None
ad_pvalue: Optional[float] = None
# Flat data stats (v2.0: replaced data_summary MapType for ~20% perf)
data_min: Optional[float] = None
data_max: Optional[float] = None
data_mean: Optional[float] = None
data_stddev: Optional[float] = None
data_count: Optional[float] = None
data_kurtosis: Optional[float] = None
data_skewness: Optional[float] = None
# Cached sample data for instant plotting (v2.10.0)
cached_sample: Optional[np.ndarray] = None
# Bounds for truncated distribution fitting (v1.4.0)
lower_bound: Optional[float] = None
upper_bound: Optional[float] = None
[docs]
def to_dict(self) -> dict:
"""Convert result to dictionary.
Returns:
Dictionary representation of the result
"""
return {
"column_name": self.column_name,
"distribution": self.distribution,
"parameters": self.parameters,
"sse": self.sse,
"aic": self.aic,
"bic": self.bic,
"ks_statistic": self.ks_statistic,
"pvalue": self.pvalue,
"ad_statistic": self.ad_statistic,
"ad_pvalue": self.ad_pvalue,
"data_min": self.data_min,
"data_max": self.data_max,
"data_mean": self.data_mean,
"data_stddev": self.data_stddev,
"data_count": self.data_count,
"data_kurtosis": self.data_kurtosis,
"data_skewness": self.data_skewness,
"lower_bound": self.lower_bound,
"upper_bound": self.upper_bound,
}
[docs]
def get_scipy_dist(self, frozen: bool = True):
"""Get scipy distribution object.
Args:
frozen: If True (default), return a frozen distribution with parameters applied.
If False, return the unfrozen distribution class.
Returns:
scipy.stats distribution object. If bounds are set and frozen=True,
returns a TruncatedFrozenDist wrapper that handles truncation.
Note:
When bounds are set (lower_bound and/or upper_bound), the returned
distribution is truncated. This ensures that sampling and PDF/CDF
evaluation respect the bounds.
"""
dist_class = getattr(st, self.distribution)
if not frozen:
return dist_class
# Create frozen distribution with parameters
frozen_dist = dist_class(*self.parameters)
# Apply truncation if bounds are set
if self.lower_bound is not None or self.upper_bound is not None:
lb = self.lower_bound if self.lower_bound is not None else -np.inf
ub = self.upper_bound if self.upper_bound is not None else np.inf
return TruncatedFrozenDist(frozen_dist, lb, ub)
return frozen_dist
[docs]
def sample(self, size: int = DEFAULT_SAMPLE_SIZE, random_state: Optional[int] = None) -> np.ndarray:
"""Generate random samples from the fitted distribution.
Args:
size: Number of samples to generate
random_state: Random seed for reproducibility
Returns:
Array of random samples. If bounds are set, samples are
guaranteed to be within [lower_bound, upper_bound].
Example:
>>> result = fitter.fit(df, 'value').best(n=1)[0]
>>> samples = result.sample(size=10000, random_state=42)
"""
# get_scipy_dist() returns a frozen distribution, optionally truncated
frozen_dist = self.get_scipy_dist()
return frozen_dist.rvs(size=size, random_state=random_state)
[docs]
def pdf(self, x: np.ndarray) -> np.ndarray:
"""Evaluate probability density function at given points.
Args:
x: Points at which to evaluate PDF
Returns:
PDF values at x. If bounds are set, the PDF is normalized
to integrate to 1 over the bounded domain.
Example:
>>> result = fitter.fit(df, 'value').best(n=1)[0]
>>> x = np.linspace(0, 10, 100)
>>> y = result.pdf(x)
"""
# get_scipy_dist() returns a frozen distribution, optionally truncated
frozen_dist = self.get_scipy_dist()
return frozen_dist.pdf(x)
[docs]
def cdf(self, x: np.ndarray) -> np.ndarray:
"""Evaluate cumulative distribution function at given points.
Args:
x: Points at which to evaluate CDF
Returns:
CDF values at x. If bounds are set, the CDF is adjusted
for the truncated domain (0 at lower_bound, 1 at upper_bound).
"""
# get_scipy_dist() returns a frozen distribution, optionally truncated
frozen_dist = self.get_scipy_dist()
return frozen_dist.cdf(x)
[docs]
def ppf(self, q: np.ndarray) -> np.ndarray:
"""Evaluate percent point function (inverse CDF) at given quantiles.
Args:
q: Quantiles at which to evaluate PPF (0 to 1)
Returns:
PPF values at q. If bounds are set, values are guaranteed
to be within [lower_bound, upper_bound].
"""
# get_scipy_dist() returns a frozen distribution, optionally truncated
frozen_dist = self.get_scipy_dist()
return frozen_dist.ppf(q)
[docs]
def save(
self,
path: Union[str, Path],
format: Optional[Literal["json", "pickle"]] = None,
indent: Optional[int] = 2,
) -> None:
"""Save fitted distribution to file.
Serializes the distribution parameters and metrics to JSON or pickle format.
JSON is recommended for human-readable, version-safe output. Pickle is
available for faster serialization when human-readability is not required.
Args:
path: File path. Format is detected from extension if not specified.
format: Output format - 'json' (human-readable) or 'pickle'.
If None, detected from file extension (.json, .pkl, .pickle).
indent: JSON indentation level (default 2). Use None for compact output.
Ignored for pickle format.
Raises:
SerializationError: If format cannot be determined or write fails.
Example:
>>> best = results.best(n=1)[0]
>>> best.save("model.json")
>>> best.save("model.pkl", format="pickle")
>>> best.save("compact.json", indent=None)
"""
from spark_bestfit.serialization import detect_format, save_json, save_pickle, serialize_to_dict
path = Path(path)
file_format = format or detect_format(path)
if file_format == "json":
data = serialize_to_dict(self)
save_json(data, path, indent)
else: # pickle
save_pickle(self, path)
[docs]
@classmethod
def load(cls, path: Union[str, Path]) -> "DistributionFitResult":
"""Load fitted distribution from file.
Reconstructs a DistributionFitResult from a previously saved file.
The loaded result can be used for sampling, PDF/CDF evaluation, etc.
Args:
path: File path. Format is detected from extension (.json, .pkl, .pickle).
Returns:
Reconstructed DistributionFitResult
Raises:
SerializationError: If file format is invalid or distribution is unknown.
FileNotFoundError: If file does not exist.
Example:
>>> loaded = DistributionFitResult.load("model.json")
>>> samples = loaded.sample(n=1000)
>>> pdf_values = loaded.pdf(np.linspace(0, 100, 100))
Warning:
Only load pickle files from trusted sources.
"""
from spark_bestfit.serialization import deserialize_from_dict, detect_format, load_json, load_pickle
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
file_format = detect_format(path)
if file_format == "json":
data = load_json(path)
return deserialize_from_dict(data)
else: # pickle
return load_pickle(path)
[docs]
def get_param_names(self) -> List[str]:
"""Get parameter names for this distribution.
Returns:
List of parameter names in order matching self.parameters
Example:
>>> result = fitter.fit(df, 'value').best(n=1)[0]
>>> print(result.distribution)
'gamma'
>>> print(result.get_param_names())
['a', 'loc', 'scale']
>>> print(dict(zip(result.get_param_names(), result.parameters)))
{'a': 2.5, 'loc': 0.0, 'scale': 3.2}
"""
from spark_bestfit.distributions import DiscreteDistributionRegistry
from spark_bestfit.fitting import get_continuous_param_names
# Check if this is a discrete distribution
registry = DiscreteDistributionRegistry()
if self.distribution in registry.get_distributions():
config = registry.get_param_config(self.distribution)
return config["param_names"]
else:
# Continuous distribution
return get_continuous_param_names(self.distribution)
[docs]
def confidence_intervals(
self,
df,
column: str,
alpha: float = DEFAULT_PVALUE_THRESHOLD,
n_bootstrap: int = DEFAULT_BOOTSTRAP_SAMPLES,
max_samples: int = DEFAULT_MAX_SAMPLES,
random_seed: Optional[int] = None,
) -> Dict[str, Tuple[float, float]]:
"""Compute bootstrap confidence intervals for fitted parameters.
Uses the percentile bootstrap method: resample data with replacement,
refit the distribution, and compute confidence intervals from the
empirical distribution of fitted parameters.
Args:
df: DataFrame containing the data (Spark DataFrame, pandas DataFrame,
or Ray Dataset)
column: Column name containing the data
alpha: Significance level (default 0.05 for 95% CI)
n_bootstrap: Number of bootstrap samples (default 1000)
max_samples: Maximum rows to collect from DataFrame (default 10000)
random_seed: Random seed for reproducibility
Returns:
Dictionary mapping parameter names to (lower, upper) bounds
Example:
>>> result = fitter.fit(df, 'value').best(n=1)[0]
>>> ci = result.confidence_intervals(df, 'value', alpha=0.05, random_seed=42)
>>> print(result.distribution)
'gamma'
>>> for param, (lower, upper) in ci.items():
... print(f" {param}: [{lower:.4f}, {upper:.4f}]")
a: [2.35, 2.65]
loc: [-0.12, 0.08]
scale: [3.05, 3.35]
Note:
Bootstrap computation can be slow for large n_bootstrap values.
The default 1000 iterations provides reasonable precision.
"""
from spark_bestfit.discrete_fitting import bootstrap_discrete_confidence_intervals
from spark_bestfit.distributions import DiscreteDistributionRegistry
from spark_bestfit.fitting import bootstrap_confidence_intervals
# Sample data from DataFrame (supports Spark, pandas, and Ray backends)
total_rows = _get_dataframe_row_count(df)
if total_rows <= max_samples:
# Collect all rows
data = _collect_dataframe_column(df, column)
else:
# Sample rows
fraction = max_samples / total_rows
data = _sample_dataframe_column(df, column, fraction, random_seed)
# Check if this is a discrete distribution
registry = DiscreteDistributionRegistry()
if self.distribution in registry.get_distributions():
return bootstrap_discrete_confidence_intervals(
dist_name=self.distribution,
data=data.astype(int),
alpha=alpha,
n_bootstrap=n_bootstrap,
random_seed=random_seed,
)
else:
return bootstrap_confidence_intervals(
dist_name=self.distribution,
data=data,
alpha=alpha,
n_bootstrap=n_bootstrap,
random_seed=random_seed,
)
[docs]
def diagnostics(
self,
data: np.ndarray,
y_hist: Optional[np.ndarray] = None,
x_hist: Optional[np.ndarray] = None,
bins: int = DEFAULT_BINS,
title: str = "",
figsize: Tuple[int, int] = (14, 12),
dpi: int = DEFAULT_DPI,
title_fontsize: int = 16,
subplot_title_fontsize: int = 12,
label_fontsize: int = 10,
grid_alpha: float = 0.3,
save_path: Optional[str] = None,
save_format: str = "png",
):
"""Create a 2x2 diagnostic plot panel for assessing distribution fit quality.
Generates four diagnostic plots:
- Q-Q Plot (top-left): Compares sample quantiles vs theoretical quantiles
- P-P Plot (top-right): Compares empirical vs theoretical probabilities
- Residual Histogram (bottom-left): Distribution of fit residuals
- CDF Comparison (bottom-right): Empirical vs theoretical CDF overlay
Args:
data: Sample data array (1D numpy array)
y_hist: Optional pre-computed histogram density values. If None,
computed from data using specified bins.
x_hist: Optional pre-computed histogram bin centers. If None,
computed from data using specified bins.
bins: Number of histogram bins (used if y_hist/x_hist not provided)
title: Overall figure title
figsize: Figure size (width, height)
dpi: Dots per inch for saved figures
title_fontsize: Main title font size
subplot_title_fontsize: Subplot title font size
label_fontsize: Axis label font size
grid_alpha: Grid transparency (0-1)
save_path: Optional path to save figure
save_format: Save format (png, pdf, svg)
Returns:
Tuple of (figure, array of axes)
Example:
>>> result = fitter.fit(df, 'value').best(n=1)[0]
>>> fig, axes = result.diagnostics(data, title='Fit Diagnostics')
>>> plt.show()
"""
from spark_bestfit.plotting import plot_diagnostics
return plot_diagnostics(
result=self,
data=data,
y_hist=y_hist,
x_hist=x_hist,
bins=bins,
title=title,
figsize=figsize,
dpi=dpi,
title_fontsize=title_fontsize,
subplot_title_fontsize=subplot_title_fontsize,
label_fontsize=label_fontsize,
grid_alpha=grid_alpha,
save_path=save_path,
save_format=save_format,
)
def __repr__(self) -> str:
"""String representation of the result."""
param_str = ", ".join([f"{p:.4f}" for p in self.parameters])
aic_str = f"{self.aic:.2f}" if self.aic is not None else "None"
bic_str = f"{self.bic:.2f}" if self.bic is not None else "None"
ks_str = f"{self.ks_statistic:.6f}" if self.ks_statistic is not None else "None"
pval_str = f"{self.pvalue:.4f}" if self.pvalue is not None else "None"
ad_str = f"{self.ad_statistic:.6f}" if self.ad_statistic is not None else "None"
ad_pval_str = f"{self.ad_pvalue:.4f}" if self.ad_pvalue is not None else "None"
col_str = f"column_name='{self.column_name}', " if self.column_name else ""
# Build bounds string if set
bounds_parts = []
if self.lower_bound is not None:
bounds_parts.append(f"lower_bound={self.lower_bound:.4f}")
if self.upper_bound is not None:
bounds_parts.append(f"upper_bound={self.upper_bound:.4f}")
bounds_str = ", ".join(bounds_parts)
bounds_suffix = f", {bounds_str}" if bounds_str else ""
return (
f"DistributionFitResult({col_str}distribution='{self.distribution}', "
f"sse={self.sse:.6f}, aic={aic_str}, bic={bic_str}, "
f"ks_statistic={ks_str}, pvalue={pval_str}, "
f"ad_statistic={ad_str}, ad_pvalue={ad_pval_str}, "
f"parameters=[{param_str}]{bounds_suffix})"
)