Source code for spark_bestfit.backends.local

"""Local backend for testing and development without Spark.

This module provides the LocalBackend class that implements the ExecutionBackend
protocol using Python's concurrent.futures for parallel processing.

This backend is useful for:
- Unit testing without Spark dependency
- Development and debugging on small datasets
- Environments where Spark is not available

Example:
    >>> import pandas as pd
    >>> from spark_bestfit.backends.local import LocalBackend
    >>> from spark_bestfit import DistributionFitter
    >>>
    >>> backend = LocalBackend(max_workers=4)
    >>> fitter = DistributionFitter(backend=backend)
    >>> # Note: DataFrames are pandas DataFrames with LocalBackend
    >>> df = pd.DataFrame({'value': [1.0, 2.0, 3.0, ...]})
    >>> results = fitter.fit(df, column='value')
"""

import logging
import multiprocessing
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd


class LocalBackend:
    """Local backend using ThreadPoolExecutor for parallel distribution fitting.

    This backend runs distribution fitting locally using Python threads. It's
    primarily useful for testing and development without requiring a Spark
    cluster. All DataFrames handled by this backend are pandas DataFrames.

    Attributes:
        max_workers: Number of worker threads for parallel execution
    """

    def __init__(self, max_workers: Optional[int] = None):
        """Initialize LocalBackend.

        Args:
            max_workers: Maximum number of worker threads. If None (or any
                other falsy value such as 0), uses the number of CPU cores.
        """
        # ``or`` makes every falsy value (None, 0) fall back to the CPU count.
        self.max_workers = max_workers or multiprocessing.cpu_count()

    @staticmethod
    def broadcast(data: Any) -> Any:
        """No-op broadcast for local execution.

        In local mode, data is already accessible to all threads, so we
        simply return the data as-is.

        Args:
            data: Data to "broadcast"

        Returns:
            The same data (no transformation needed)
        """
        return data

    def destroy_broadcast(self, handle: Any) -> None:
        """No-op cleanup for local execution.

        Args:
            handle: Data reference (ignored)
        """
        # Nothing to clean up: broadcast() handed back the original object.

    def parallel_fit(
        self,
        distributions: List[str],
        histogram: Tuple[np.ndarray, np.ndarray],
        data_sample: np.ndarray,
        fit_func: Callable[..., Dict[str, Any]],
        column_name: str,
        data_stats: Optional[Dict[str, float]] = None,
        num_partitions: Optional[int] = None,
        lower_bound: Optional[float] = None,
        upper_bound: Optional[float] = None,
        lazy_metrics: bool = False,
        is_discrete: bool = False,
        progress_callback: Optional[Callable[[int, int, float], None]] = None,
        custom_distributions: Optional[Dict[str, Any]] = None,
        estimation_method: str = "mle",
        censoring_indicator: Optional[np.ndarray] = None,
    ) -> List[Dict[str, Any]]:
        """Execute distribution fitting in parallel using threads.

        Uses ThreadPoolExecutor to fit distributions concurrently. Each
        distribution is fitted independently.

        Args:
            distributions: List of scipy distribution names to fit
            histogram: Tuple of (y_hist, bin_edges) for continuous or
                (x_values, pmf) for discrete distributions
            data_sample: Sample data array for MLE fitting
            fit_func: Accepted for interface compatibility with other
                backends. NOTE: this backend does not call it; it dispatches
                to ``spark_bestfit.fitting.fit_single_distribution`` or
                ``spark_bestfit.discrete_fitting.fit_single_discrete_distribution``
                depending on ``is_discrete``.
            column_name: Name of the source column
            data_stats: Optional dict with data_min, data_max, etc.
            num_partitions: Ignored (uses max_workers instead)
            lower_bound: Lower bound for truncated fitting
            upper_bound: Upper bound for truncated fitting
            lazy_metrics: If True, skip expensive KS/AD computation
            is_discrete: If True, use discrete distribution fitting
            progress_callback: Optional callback for progress updates.
                Called with (completed, total, percent) after each
                distribution. Exceptions it raises are ignored.
            custom_distributions: Dict mapping custom distribution names to
                rv_continuous objects (v2.4.0). Only used for continuous fits.
            estimation_method: Parameter estimation method (v2.5.0); only
                used for continuous fits:
                - "mle": Maximum Likelihood Estimation (default)
                - "mse": Maximum Spacing Estimation (robust for heavy-tailed data)
            censoring_indicator: Boolean array where True=observed event,
                False=censored. When provided, uses censored MLE (v2.9.0).
                Only used for continuous fits.

        Returns:
            List of fit result dicts for successful fits only (SSE < inf),
            in the same order as ``distributions``. Distributions whose fit
            raises are skipped.
        """
        # Nothing to fit, or no data to fit against.
        if not distributions or len(data_sample) == 0:
            return []

        # Unpack histogram based on distribution type.
        if is_discrete:
            x_values, empirical_pmf = histogram
            y_hist = bin_edges = None
            # Convert once here rather than once per distribution in each worker.
            int_sample = data_sample.astype(int)
        else:
            y_hist, bin_edges = histogram
            x_values = empirical_pmf = None
            int_sample = None

        def fit_one_distribution(dist_name: str) -> Dict[str, Any]:
            """Fit a single distribution (runs in the thread pool)."""
            if is_discrete:
                # Imported lazily to avoid circular imports.
                from spark_bestfit.discrete_fitting import fit_single_discrete_distribution
                from spark_bestfit.distributions import DiscreteDistributionRegistry

                return fit_single_discrete_distribution(
                    dist_name=dist_name,
                    data_sample=int_sample,
                    x_values=x_values,
                    empirical_pmf=empirical_pmf,
                    registry=DiscreteDistributionRegistry(),
                    column_name=column_name,
                    data_stats=data_stats,
                    lower_bound=lower_bound,
                    upper_bound=upper_bound,
                    lazy_metrics=lazy_metrics,
                )

            from spark_bestfit.fitting import fit_single_distribution

            return fit_single_distribution(
                dist_name=dist_name,
                data_sample=data_sample,
                bin_edges=bin_edges,
                y_hist=y_hist,
                column_name=column_name,
                data_stats=data_stats,
                lower_bound=lower_bound,
                upper_bound=upper_bound,
                lazy_metrics=lazy_metrics,
                custom_distributions=custom_distributions,
                estimation_method=estimation_method,
                censoring_indicator=censoring_indicator,
            )

        log = logging.getLogger(__name__)
        total = len(distributions)
        # Keyed by position in ``distributions`` so the returned order is
        # deterministic regardless of which thread finishes first.
        fitted: Dict[int, Dict[str, Any]] = {}

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(fit_one_distribution, name): index
                for index, name in enumerate(distributions)
            }
            # as_completed yields in this (calling) thread, so the progress
            # counter is not shared between threads and needs no lock.
            for completed, future in enumerate(as_completed(futures), start=1):
                index = futures[future]
                try:
                    result = future.result()
                    # Failed fits are reported with SSE = infinity; drop them.
                    if result["sse"] < float(np.inf):
                        fitted[index] = result
                except Exception:
                    # Best-effort: skip a distribution that fails outright,
                    # but keep a record of why instead of discarding it.
                    log.debug(
                        "Fitting distribution %r failed; skipping",
                        distributions[index],
                        exc_info=True,
                    )

                if progress_callback is not None:
                    percent = (completed / total) * 100.0
                    try:
                        progress_callback(completed, total, percent)
                    except Exception:
                        # Don't let callback errors break fitting.
                        log.debug("progress_callback raised; ignoring", exc_info=True)

        return [fitted[index] for index in sorted(fitted)]

    def get_parallelism(self) -> int:
        """Get the number of worker threads.

        Returns:
            Number of parallel execution slots (max_workers)
        """
        return self.max_workers

    @staticmethod
    def collect_column(df: pd.DataFrame, column: str) -> np.ndarray:
        """Extract a column from pandas DataFrame as numpy array.

        Args:
            df: Pandas DataFrame
            column: Column name to extract

        Returns:
            Numpy array of column values (missing values are not removed)
        """
        return df[column].to_numpy()

    @staticmethod
    def get_column_stats(df: pd.DataFrame, column: str) -> Dict[str, float]:
        """Compute min, max, and count for a column.

        Args:
            df: Pandas DataFrame
            column: Column name

        Returns:
            Dict with keys: 'min', 'max', 'count'. 'min'/'max' skip NaN
            (pandas default); 'count' is the number of rows.
        """
        series = df[column]
        # NOTE(review): 'count' includes NaN rows while min/max skip them;
        # confirm this matches the semantics expected from other backends.
        return {
            "min": float(series.min()),
            "max": float(series.max()),
            "count": len(series),
        }

    @staticmethod
    def sample_column(
        df: pd.DataFrame,
        column: str,
        fraction: float,
        seed: int,
    ) -> np.ndarray:
        """Sample a column and return as numpy array.

        Filters out NaN and infinite values before sampling to ensure clean
        data for distribution fitting.

        Args:
            df: Pandas DataFrame
            column: Column name
            fraction: Fraction to sample (0 < fraction <= 1)
            seed: Random seed for reproducibility

        Returns:
            Numpy array of sampled values (NaN/inf filtered); empty if no
            finite values remain.
        """
        # Treat +/-inf as missing so a single dropna() removes both.
        clean_df = df[[column]].replace([np.inf, -np.inf], np.nan).dropna()
        if len(clean_df) == 0:
            return np.array([])
        sample_df = clean_df.sample(frac=fraction, random_state=seed)
        return sample_df[column].to_numpy()

    @staticmethod
    def create_dataframe(
        data: List[Tuple[Any, ...]],
        columns: List[str],
    ) -> pd.DataFrame:
        """Create a pandas DataFrame from local data.

        Args:
            data: List of row tuples
            columns: Column names

        Returns:
            Pandas DataFrame
        """
        return pd.DataFrame(data, columns=columns)

    # =========================================================================
    # Copula and Histogram Methods (v2.0)
    # =========================================================================

    @staticmethod
    def compute_correlation(
        df: pd.DataFrame,
        columns: List[str],
        method: str = "spearman",
    ) -> np.ndarray:
        """Compute correlation matrix using pandas.

        Args:
            df: Pandas DataFrame
            columns: List of column names to compute correlation for
            method: Correlation method ('spearman' or 'pearson')

        Returns:
            Correlation matrix as numpy array of shape (n_columns, n_columns)
        """
        return df[columns].corr(method=method).to_numpy()

    @staticmethod
    def compute_histogram(
        df: pd.DataFrame,
        column: str,
        bin_edges: np.ndarray,
    ) -> Tuple[np.ndarray, int]:
        """Compute histogram bin counts using numpy.

        Args:
            df: Pandas DataFrame
            column: Column to histogram
            bin_edges: Array of bin edge values (n_bins + 1 values)

        Returns:
            Tuple of (bin_counts, total_count) where bin_counts is a float
            array of counts for each bin and total_count is the number of
            values that fell inside the bins (values outside the edges are
            not counted by np.histogram).
        """
        data = df[column].dropna().to_numpy()
        bin_counts, _ = np.histogram(data, bins=bin_edges)
        total_count = int(bin_counts.sum())
        return bin_counts.astype(float), total_count

    def generate_samples(
        self,
        n: int,
        generator_func: Callable[[int, int, Optional[int]], Dict[str, np.ndarray]],
        column_names: List[str],
        num_partitions: Optional[int] = None,
        random_seed: Optional[int] = None,
    ) -> pd.DataFrame:
        """Generate samples locally.

        Unlike SparkBackend, this generates all samples in a single call
        since there's no distributed cluster to leverage.

        Args:
            n: Total number of samples to generate
            generator_func: Function(n_samples, partition_id, seed) ->
                Dict[col, array] that generates samples for one partition
            column_names: Names of columns in output (for interface
                compatibility; not used here)
            num_partitions: Ignored (no partitioning in local mode)
            random_seed: Random seed for reproducibility

        Returns:
            Pandas DataFrame with generated samples
        """
        # Generate all samples in one call, as partition 0.
        samples = generator_func(n, 0, random_seed)
        return pd.DataFrame(samples)