Source code for spark_bestfit.sampling

"""Distributed sampling for fitted distributions.

This module provides functions for generating samples from fitted distributions
using the backend abstraction for distributed or local execution.
"""

from typing import TYPE_CHECKING, Any, Dict, List, Optional

import numpy as np
import scipy.stats as st

if TYPE_CHECKING:
    from spark_bestfit.protocols import ExecutionBackend


def sample_distributed(
    distribution: str,
    parameters: List[float],
    n: int,
    backend: "ExecutionBackend",
    num_partitions: Optional[int] = None,
    random_seed: Optional[int] = None,
    column_name: str = "sample",
) -> Any:
    """Generate samples from a fitted distribution using backend abstraction.

    Builds a per-partition sample generator for the requested scipy.stats
    distribution and hands it to ``backend.generate_samples``, which is
    responsible for splitting the work and assembling the result.

    Args:
        distribution: scipy.stats distribution name (e.g., "norm", "expon").
            Looked up with ``getattr`` on ``scipy.stats``.
        parameters: Distribution parameters (shape, loc, scale), passed
            positionally to ``rvs``. A snapshot is taken at call time, so
            mutating the list afterwards does not affect the generated samples.
        n: Total number of samples to generate (forwarded to the backend).
        backend: Execution backend (SparkBackend, LocalBackend, etc.).
        num_partitions: Number of partitions to use, forwarded to the backend.
            ``None`` leaves the choice to the backend.
        random_seed: Random seed for reproducibility, forwarded to the backend.
            The backend is expected to derive the per-partition seed that it
            passes to the generator (described upstream as
            seed + partition_id); that derivation is not done in this function.
        column_name: Name for the output column (default: "sample").

    Returns:
        Whatever ``backend.generate_samples`` returns: a backend-specific
        DataFrame with a single column containing samples (Spark DataFrame
        for SparkBackend, pandas DataFrame for LocalBackend).

    Raises:
        AttributeError: If ``distribution`` is not an attribute of scipy.stats.

    Example:
        >>> from spark_bestfit.backends.spark import SparkBackend
        >>> backend = SparkBackend(spark)
        >>> df = sample_distributed("norm", [0.0, 1.0], n=1_000_000, backend=backend)
        >>> df.show(5)
        +-------------------+
        |             sample|
        +-------------------+
        | 0.4691122931291924|
        |-0.2828633018445851|
        | 1.0093545783546243|
        +-------------------+
    """
    # Get distribution object from scipy.
    dist = getattr(st, distribution)

    # Snapshot the parameters into an immutable tuple. The generator below is
    # a closure that the backend may invoke after this function has returned
    # (the call timing is backend-specific and not visible here); capturing
    # the caller's list by reference would let later mutations silently
    # change the distribution being sampled.
    frozen_parameters = tuple(parameters)

    def generate_distribution_samples(
        n_samples: int,
        partition_id: int,
        seed: Optional[int],
    ) -> Dict[str, np.ndarray]:
        """Generate one partition's worth of samples from the distribution.

        This function is passed to the backend's generate_samples method.

        Args:
            n_samples: Number of samples to draw for this partition.
            partition_id: Accepted to match the backend's generator signature;
                not used here.
            seed: Seed for this partition, or None for unseeded sampling.

        Returns:
            Mapping of ``column_name`` to the array of drawn samples.
        """
        if seed is not None:
            # Dedicated generator so the draw is reproducible for this seed.
            rng = np.random.default_rng(seed)
            samples = dist.rvs(*frozen_parameters, size=n_samples, random_state=rng)
        else:
            # No seed: fall back to scipy's default random state.
            samples = dist.rvs(*frozen_parameters, size=n_samples)
        return {column_name: samples}

    return backend.generate_samples(
        n=n,
        generator_func=generate_distribution_samples,
        column_names=[column_name],
        num_partitions=num_partitions,
        random_seed=random_seed,
    )