"""Apache Spark backend for distributed distribution fitting.
This module provides the SparkBackend class that implements the ExecutionBackend
protocol using Apache Spark's Pandas UDFs for parallel processing.
Example:
>>> from pyspark.sql import SparkSession
>>> from spark_bestfit.backends.spark import SparkBackend
>>> from spark_bestfit import DistributionFitter
>>>
>>> spark = SparkSession.builder.getOrCreate()
>>> backend = SparkBackend(spark)
>>> fitter = DistributionFitter(backend=backend)
"""
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import pyspark.sql.functions as F
from pyspark.sql import DataFrame, SparkSession
from spark_bestfit.utils import get_spark_session
[docs]
class SparkBackend:
"""Apache Spark backend using Pandas UDFs for parallel distribution fitting.
This is the default backend for spark-bestfit. It uses Spark's broadcast
variables for efficient data sharing and Pandas UDFs for vectorized
distribution fitting across the cluster.
Attributes:
spark: The SparkSession instance used for distributed operations
"""
def __init__(self, spark: Optional[SparkSession] = None):
"""Initialize SparkBackend.
Args:
spark: SparkSession instance. If None, attempts to get the active
session or create a new one.
Raises:
RuntimeError: If no SparkSession provided and no active session exists
"""
self.spark = get_spark_session(spark)
[docs]
def broadcast(self, data: Any) -> Any:
"""Broadcast data to all Spark executors.
Creates a read-only variable cached on each worker node. This is
essential for sharing histogram and sample data efficiently without
sending copies with each task.
Args:
data: Data to broadcast (numpy arrays, tuples, etc.)
Returns:
Spark Broadcast object wrapping the data
"""
return self.spark.sparkContext.broadcast(data)
[docs]
@staticmethod
def destroy_broadcast(handle: Any) -> None:
"""Release broadcast variable from executor memory.
Uses unpersist() rather than destroy() because Spark's lazy evaluation
may still reference the broadcast in pending operations.
Args:
handle: Broadcast variable returned by broadcast()
"""
if handle is not None:
handle.unpersist()
[docs]
def parallel_fit(
self,
distributions: List[str],
histogram: Tuple[np.ndarray, np.ndarray],
data_sample: np.ndarray,
fit_func: Callable[..., Dict[str, Any]],
column_name: str,
data_stats: Optional[Dict[str, float]] = None,
num_partitions: Optional[int] = None,
lower_bound: Optional[float] = None,
upper_bound: Optional[float] = None,
lazy_metrics: bool = False,
is_discrete: bool = False,
progress_callback: Optional[Callable[[int, int, float], None]] = None,
custom_distributions: Optional[Dict[str, Any]] = None,
estimation_method: str = "mle",
censoring_indicator: Optional[np.ndarray] = None,
) -> List[Dict[str, Any]]:
"""Execute distribution fitting in parallel using Pandas UDFs.
This method encapsulates all Spark-specific operations for fitting:
1. Broadcasts histogram and sample data to executors
2. Creates a DataFrame of distribution names
3. Applies the fitting UDF to compute results in parallel
4. Collects and returns results
Args:
distributions: List of scipy distribution names to fit
histogram: Tuple of (y_hist, bin_edges) for continuous or
(x_values, pmf) for discrete distributions
data_sample: Sample data array for MLE fitting
fit_func: Pure Python fitting function (not used directly here;
we use the Pandas UDF factories instead)
column_name: Name of the source column
data_stats: Optional dict with data_min, data_max, etc.
num_partitions: Number of partitions (None = auto)
lower_bound: Lower bound for truncated fitting
upper_bound: Upper bound for truncated fitting
lazy_metrics: If True, skip expensive KS/AD computation
is_discrete: If True, use discrete distribution fitting
progress_callback: Optional callback for progress updates.
Called with (completed_tasks, total_tasks, percent) at the
Spark task level via StatusTracker polling.
custom_distributions: Dict mapping custom distribution names to
rv_continuous objects. These are broadcasted to executors
for fitting custom distributions. (v2.4.0)
estimation_method: Parameter estimation method (v2.5.0):
- "mle": Maximum Likelihood Estimation (default)
- "mse": Maximum Spacing Estimation (robust for heavy-tailed data)
censoring_indicator: Boolean array where True=observed event,
False=censored. When provided, uses censored MLE. (v2.9.0)
Returns:
List of fit result dicts
"""
# Handle empty distribution list
if not distributions:
return []
# Start progress tracking if callback provided
tracker = None
if progress_callback is not None:
from spark_bestfit.progress import ProgressTracker
tracker = ProgressTracker(self.spark, progress_callback)
tracker.start()
# Broadcast data to executors
histogram_bc = self.broadcast(histogram)
data_sample_bc = self.broadcast(data_sample)
custom_dist_bc = self.broadcast(custom_distributions) if custom_distributions else None
censoring_bc = self.broadcast(censoring_indicator) if censoring_indicator is not None else None
try:
# Create DataFrame of distributions
dist_df = self.create_dataframe(
data=[(d,) for d in distributions],
columns=["distribution_name"],
)
# Repartition for optimal parallelism
n_partitions = num_partitions or self._calculate_partitions(distributions)
dist_df = dist_df.repartition(n_partitions)
# Create and apply appropriate fitting UDF
if is_discrete:
from spark_bestfit.discrete_fitting import create_discrete_fitting_udf
fitting_udf = create_discrete_fitting_udf(
histogram_bc,
data_sample_bc,
column_name=column_name,
data_stats=data_stats,
lower_bound=lower_bound,
upper_bound=upper_bound,
lazy_metrics=lazy_metrics,
)
else:
from spark_bestfit.fitting import create_fitting_udf
fitting_udf = create_fitting_udf(
histogram_bc,
data_sample_bc,
column_name=column_name,
data_stats=data_stats,
lower_bound=lower_bound,
upper_bound=upper_bound,
lazy_metrics=lazy_metrics,
custom_distributions_broadcast=custom_dist_bc,
estimation_method=estimation_method,
censoring_indicator_broadcast=censoring_bc,
)
# Apply UDF and expand struct
results_df = dist_df.select(fitting_udf(F.col("distribution_name")).alias("result")).select("result.*")
# Filter failed fits (SSE = infinity)
results_df = results_df.filter(F.col("sse") < float(np.inf))
# Collect results to driver
return [row.asDict() for row in results_df.collect()]
finally:
# Stop progress tracking
if tracker is not None:
tracker.stop()
# Always clean up broadcast variables
self.destroy_broadcast(histogram_bc)
self.destroy_broadcast(data_sample_bc)
if custom_dist_bc is not None:
self.destroy_broadcast(custom_dist_bc)
if censoring_bc is not None:
self.destroy_broadcast(censoring_bc)
[docs]
def get_parallelism(self) -> int:
"""Get the default parallelism from Spark configuration.
Returns the total number of cores available across the cluster,
which is used to determine optimal partition counts.
Returns:
Number of available parallel execution slots
"""
return self.spark.sparkContext.defaultParallelism
[docs]
@staticmethod
def collect_column(df: DataFrame, column: str) -> np.ndarray:
"""Collect a single column from Spark DataFrame as numpy array.
Warning: This collects data to the driver node. Use sparingly
for large datasets.
Args:
df: Spark DataFrame
column: Column name to collect
Returns:
Numpy array of column values
"""
return df.select(column).toPandas()[column].values
[docs]
@staticmethod
def get_column_stats(df: DataFrame, column: str) -> Dict[str, float]:
"""Compute min, max, and count for a column in a single pass.
Uses Spark aggregations to compute statistics efficiently without
collecting all data to the driver.
Args:
df: Spark DataFrame
column: Column name
Returns:
Dict with keys: 'min', 'max', 'count'. Values are NaN for
empty DataFrames or columns with all null values, ensuring
consistent return type with LocalBackend and RayBackend.
"""
stats = df.agg(
F.min(column).alias("min"),
F.max(column).alias("max"),
F.count(column).alias("count"),
).first()
return {
"min": float(stats["min"]) if stats["min"] is not None else float("nan"),
"max": float(stats["max"]) if stats["max"] is not None else float("nan"),
"count": int(stats["count"]),
}
[docs]
@staticmethod
def sample_column(
df: DataFrame,
column: str,
fraction: float,
seed: int,
) -> np.ndarray:
"""Sample a column and collect as numpy array.
Performs distributed sampling before collection, reducing the amount
of data transferred to the driver.
Args:
df: Spark DataFrame
column: Column name
fraction: Fraction to sample (0 < fraction <= 1)
seed: Random seed for reproducibility
Returns:
Numpy array of sampled values
"""
sample_df = df.select(column).sample(fraction=fraction, seed=seed)
return sample_df.toPandas()[column].values
[docs]
def create_dataframe(
self,
data: List[Tuple[Any, ...]],
columns: List[str],
) -> DataFrame:
"""Create a Spark DataFrame from local data.
Used internally to create the distribution name DataFrame for
parallel fitting.
Args:
data: List of row tuples
columns: Column names
Returns:
Spark DataFrame
"""
return self.spark.createDataFrame(data, columns)
def _calculate_partitions(self, distributions: List[str]) -> int:
"""Calculate optimal partition count for distribution fitting.
Uses distribution-aware weighting where slow distributions count
as 3x for partition calculation to reduce straggler effects.
Args:
distributions: List of distribution names to fit
Returns:
Optimal partition count
"""
from spark_bestfit.distributions import DistributionRegistry
slow_set = DistributionRegistry.SLOW_DISTRIBUTIONS
slow_count = sum(1 for d in distributions if d in slow_set)
# Slow distributions count 3x (1 base + 2 extra)
effective_count = len(distributions) + slow_count * 2
total_cores = self.get_parallelism()
return min(effective_count, total_cores * 2)
# =========================================================================
# Copula and Histogram Methods (v2.0)
# =========================================================================
[docs]
@staticmethod
def compute_correlation(
df: DataFrame,
columns: List[str],
method: str = "spearman",
) -> np.ndarray:
"""Compute correlation matrix using Spark ML.
Uses distributed computation via Spark ML's Correlation, enabling
correlation computation on DataFrames with billions of rows without
collecting data to the driver.
Args:
df: Spark DataFrame
columns: List of column names to compute correlation for
method: Correlation method ('spearman' or 'pearson')
Returns:
Correlation matrix as numpy array of shape (n_columns, n_columns)
"""
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
# Assemble columns into a vector
assembler = VectorAssembler(
inputCols=columns,
outputCol="_corr_features",
handleInvalid="skip", # Skip rows with nulls
)
vector_df = assembler.transform(df).select("_corr_features")
# Compute correlation using Spark ML
corr_result = Correlation.corr(vector_df, "_corr_features", method=method)
# Extract correlation matrix from result
corr_matrix = corr_result.head()[0].toArray()
return corr_matrix
[docs]
@staticmethod
def compute_histogram(
df: DataFrame,
column: str,
bin_edges: np.ndarray,
) -> Tuple[np.ndarray, int]:
"""Compute histogram using distributed Bucketizer and groupBy.
This is the key optimization: uses Spark ML's Bucketizer to assign
each row to a bin, then uses groupBy to count rows per bin. All
computation happens in the cluster without collecting data.
Args:
df: Spark DataFrame
column: Column to histogram
bin_edges: Array of bin edge values (n_bins + 1 values)
Returns:
Tuple of (bin_counts, total_count) where bin_counts is an array
of counts for each bin
"""
from pyspark.ml.feature import Bucketizer
# Create temp column name to avoid conflicts
temp_col = f"__{column}_bin_temp__"
# Use Bucketizer to assign bin IDs
bucketizer = Bucketizer(
splits=bin_edges.tolist(),
inputCol=column,
outputCol=temp_col,
handleInvalid="keep", # Keep invalid values in a special bin
)
# Transform and aggregate
bucketed = bucketizer.transform(df)
histogram = bucketed.groupBy(temp_col).count().withColumnRenamed(temp_col, "bin_id")
# Collect ONLY the aggregated histogram (small data)
hist_data = histogram.orderBy("bin_id").collect()
# Extract counts (fill missing bins with zeros)
bin_counts = np.zeros(len(bin_edges) - 1)
total_count = 0
for row in hist_data:
bin_id = row["bin_id"]
count = row["count"]
# Skip None bin_id (can occur with handleInvalid="keep" for out-of-range values)
if bin_id is not None:
bin_id = int(bin_id)
if 0 <= bin_id < len(bin_counts):
bin_counts[bin_id] = count
total_count += count
return bin_counts, total_count
[docs]
def generate_samples(
self,
n: int,
generator_func: Callable[[int, int, Optional[int]], Dict[str, np.ndarray]],
column_names: List[str],
num_partitions: Optional[int] = None,
random_seed: Optional[int] = None,
) -> DataFrame:
"""Generate samples distributed across Spark partitions.
Uses mapInPandas to generate samples in each partition, enabling
generation of millions of samples distributed across the cluster.
Args:
n: Total number of samples to generate
generator_func: Function(n_samples, partition_id, seed) -> Dict[col, array]
that generates samples for one partition
column_names: Names of columns in output
num_partitions: Number of partitions (None = default parallelism)
random_seed: Base random seed (partition_id added for uniqueness)
Returns:
Spark DataFrame with generated samples
"""
import pandas as pd
from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType
if num_partitions is None:
num_partitions = self.get_parallelism()
# Calculate samples per partition
base_samples = n // num_partitions
remainder = n % num_partitions
# Create partition info DataFrame
partition_data = []
for i in range(num_partitions):
samples_for_partition = base_samples + (1 if i < remainder else 0)
if samples_for_partition > 0:
partition_data.append((i, samples_for_partition))
partition_df = self.spark.createDataFrame(
partition_data,
StructType(
[
StructField("partition_id", IntegerType(), False),
StructField("n_samples", IntegerType(), False),
]
),
)
# Repartition to ensure parallelism
partition_df = partition_df.repartition(len(partition_data))
# Define output schema
output_fields = [StructField(col, DoubleType(), False) for col in column_names]
output_schema = StructType(output_fields)
# Create the mapInPandas function
def generate_partition_samples(iterator):
"""Generate samples for each partition."""
for pdf in iterator:
if len(pdf) == 0:
continue
for idx in range(len(pdf)):
n_samples = int(pdf.iloc[idx]["n_samples"])
partition_id = int(pdf.iloc[idx]["partition_id"])
# Compute seed for this partition
seed = None
if random_seed is not None:
seed = random_seed + partition_id
# Generate samples using the provided function
samples = generator_func(n_samples, partition_id, seed)
yield pd.DataFrame(samples)
# Apply the generator
result_df = partition_df.mapInPandas(
generate_partition_samples,
schema=output_schema,
)
return result_df