"""Distribution registry and management for scipy.stats distributions."""
from typing import Any, Dict, List, Optional, Set, Tuple
import numpy as np
import scipy.stats as st
from scipy.stats import rv_continuous
[docs]
class DistributionRegistry:
"""Registry for managing scipy.stats continuous distributions.
Handles filtering of distributions based on exclusions and support constraints.
All scipy.stats continuous distributions are available by default, with
sensible exclusions for slow-computing distributions.
Example:
>>> registry = DistributionRegistry()
>>> distributions = registry.get_distributions()
>>> len(distributions)
~100
>>> # Only non-negative distributions
>>> pos_distributions = registry.get_distributions(support_at_zero=True)
>>> # Add custom exclusions
>>> distributions = registry.get_distributions(
... additional_exclusions=["ncf", "ncx2"]
... )
"""
# Default exclusions: distributions that are very slow or numerically unstable
DEFAULT_EXCLUSIONS = {
"levy_stable", # Extremely slow - MLE doesn't always converge
"kappa4", # Extremely slow
"ncx2", # Slow - non-central chi-squared
"ksone", # Slow - Kolmogorov-Smirnov one-sided
"ncf", # Slow - non-central F
"wald", # Sometimes numerically unstable
"mielke", # Slow
"exponpow", # Slow - exponential power
"studentized_range", # Very slow - scipy docs recommend approximation
"gausshyper", # Very slow - Gauss hypergeometric
"geninvgauss", # Can hang - generalized inverse Gaussian
"genhyperbolic", # Slow - generalized hyperbolic
"kstwo", # Slow - Kolmogorov-Smirnov two-sided
"kstwobign", # Slow - KS limit distribution
"recipinvgauss", # Can be slow
"vonmises", # Can be slow on fitting
"vonmises_line", # Can be slow on fitting
"tukeylambda", # Extremely slow (~7s) - ill-conditioned optimization
"nct", # Very slow (~1.4s) - non-central t distribution
"dpareto_lognorm", # Slow (~0.5s) - double Pareto-lognormal
}
# Distributions that are included but noticeably slower (~3-5x average).
# Used for partition weighting when fitting with all distributions.
SLOW_DISTRIBUTIONS: Set[str] = {
"powerlognorm", # ~160ms
"norminvgauss", # ~150ms
"t", # ~144ms
"pearson3", # ~141ms
"exponweib", # ~136ms
"johnsonsb", # ~133ms
"jf_skew_t", # ~125ms
"fisk", # ~120ms
"gengamma", # ~120ms
"johnsonsu", # ~106ms
"burr", # ~105ms
"burr12", # ~105ms
"truncweibull_min", # ~104ms
"invweibull", # ~92ms
"rice", # ~91ms
"genexpon", # ~89ms
}
# All scipy continuous distributions
ALL_DISTRIBUTIONS = [name for name in dir(st) if isinstance(getattr(st, name), st.rv_continuous)]
def __init__(self, custom_exclusions: Optional[Set[str]] = None):
"""Initialize the distribution registry.
Args:
custom_exclusions: Optional set of distribution names to exclude
(replaces default exclusions if provided)
"""
self._excluded = custom_exclusions if custom_exclusions is not None else self.DEFAULT_EXCLUSIONS.copy()
self._custom_distributions: Dict[str, rv_continuous] = {}
[docs]
def get_distributions(
self,
support_at_zero: bool = False,
additional_exclusions: Optional[List[str]] = None,
include_custom: bool = True,
) -> List[str]:
"""Get filtered list of distributions based on criteria.
Args:
support_at_zero: If True, only include distributions with support at zero
(non-negative distributions)
additional_exclusions: Additional distribution names to exclude
include_custom: If True, include registered custom distributions (default True)
Returns:
List of distribution names meeting the criteria
Example:
>>> registry = DistributionRegistry()
>>> # Get all non-excluded distributions
>>> dists = registry.get_distributions()
>>> # Get only non-negative distributions
>>> pos_dists = registry.get_distributions(support_at_zero=True)
>>> # Exclude more distributions
>>> filtered = registry.get_distributions(
... additional_exclusions=["norm", "expon"]
... )
>>> # Register and include custom distributions
>>> registry.register_distribution("my_custom", MyCustomDistribution())
>>> dists = registry.get_distributions() # Includes "my_custom"
"""
# Start with excluded set
excluded = self._excluded.copy()
# Add any additional exclusions
if additional_exclusions:
excluded.update(additional_exclusions)
# Filter out excluded distributions
distributions = [d for d in self.ALL_DISTRIBUTIONS if d not in excluded]
# Add custom distributions (exclude if in exclusion list)
if include_custom:
for name in self._custom_distributions:
if name not in excluded:
distributions.append(name)
# Filter by support if requested
if support_at_zero:
distributions = [d for d in distributions if self._has_support_at_zero(d)]
return distributions
def _has_support_at_zero(self, dist_name: str) -> bool:
"""Check if a distribution has support at zero (non-negative).
Args:
dist_name: Name of the scipy distribution or custom distribution
Returns:
True if distribution support starts at 0 or greater
"""
try:
# Check custom distributions first
if dist_name in self._custom_distributions:
dist = self._custom_distributions[dist_name]
else:
dist = getattr(st, dist_name)
return dist.a >= 0
except (AttributeError, TypeError):
# If we can't determine, exclude it to be safe
return False
[docs]
def add_exclusion(self, dist_name: str) -> None:
"""Add a distribution to the exclusion list.
Args:
dist_name: Name of the distribution to exclude
"""
self._excluded.add(dist_name)
[docs]
def remove_exclusion(self, dist_name: str) -> None:
"""Remove a distribution from the exclusion list.
Args:
dist_name: Name of the distribution to include
"""
self._excluded.discard(dist_name)
[docs]
def get_exclusions(self) -> Set[str]:
"""Get current set of excluded distributions.
Returns:
Set of excluded distribution names
"""
return self._excluded.copy()
[docs]
def reset_exclusions(self) -> None:
"""Reset exclusions to default set."""
self._excluded = self.DEFAULT_EXCLUSIONS.copy()
# =========================================================================
# Custom Distribution Support (v2.4.0)
# =========================================================================
[docs]
def register_distribution(
self,
name: str,
distribution: rv_continuous,
overwrite: bool = False,
) -> None:
"""Register a custom distribution for fitting.
Custom distributions must implement the scipy rv_continuous interface,
specifically the fit(), pdf(), and cdf() methods. The distribution
will be included in fitting alongside scipy.stats distributions.
Args:
name: Unique name for the distribution (used in results)
distribution: scipy rv_continuous instance or subclass.
Must implement fit(), pdf(), cdf() methods.
overwrite: If True, overwrite existing distribution with same name.
Default False raises ValueError if name exists.
Raises:
ValueError: If name already exists (and overwrite=False) or
conflicts with a scipy.stats distribution name
TypeError: If distribution doesn't implement required interface
Example:
>>> from scipy.stats import rv_continuous
>>>
>>> class PowerDistribution(rv_continuous):
... def _pdf(self, x, alpha):
... return alpha * x ** (alpha - 1)
... def _cdf(self, x, alpha):
... return x ** alpha
>>>
>>> registry = DistributionRegistry()
>>> registry.register_distribution("power", PowerDistribution(a=0, b=1))
>>> distributions = registry.get_distributions()
>>> "power" in distributions
True
"""
# Validate name doesn't conflict with scipy.stats
if name in self.ALL_DISTRIBUTIONS:
raise ValueError(
f"Cannot register '{name}': conflicts with scipy.stats distribution. " f"Use a different name."
)
# Check for existing registration
if not overwrite and name in self._custom_distributions:
raise ValueError(f"Distribution '{name}' already registered. " f"Use overwrite=True to replace it.")
# Validate distribution interface
required_methods = ["fit", "pdf", "cdf"]
missing = [m for m in required_methods if not hasattr(distribution, m)]
if missing:
raise TypeError(f"Distribution must implement {required_methods}. " f"Missing: {missing}")
# Validate it's an rv_continuous-like object
if not isinstance(distribution, rv_continuous):
raise TypeError(
f"Distribution must be a scipy rv_continuous subclass. " f"Got: {type(distribution).__name__}"
)
self._custom_distributions[name] = distribution
[docs]
def unregister_distribution(self, name: str) -> None:
"""Remove a custom distribution from the registry.
Args:
name: Name of the custom distribution to remove
Raises:
KeyError: If distribution not found in registry
"""
if name not in self._custom_distributions:
raise KeyError(f"Custom distribution '{name}' not found in registry")
del self._custom_distributions[name]
[docs]
def get_distribution_object(self, name: str) -> rv_continuous:
"""Get a distribution object by name.
Looks up both scipy.stats built-in distributions and registered
custom distributions.
Args:
name: Distribution name (scipy.stats name or custom registered name)
Returns:
scipy rv_continuous distribution object
Raises:
ValueError: If distribution not found
Example:
>>> registry = DistributionRegistry()
>>> norm_dist = registry.get_distribution_object("norm")
>>> # Also works for custom distributions
>>> registry.register_distribution("custom", MyDist())
>>> my_dist = registry.get_distribution_object("custom")
"""
# Check custom distributions first (allows overriding built-ins conceptually)
if name in self._custom_distributions:
return self._custom_distributions[name]
# Check scipy.stats
if hasattr(st, name) and isinstance(getattr(st, name), rv_continuous):
return getattr(st, name)
raise ValueError(
f"Distribution '{name}' not found. " f"Not a scipy.stats distribution or registered custom distribution."
)
[docs]
def get_custom_distributions(self) -> Dict[str, rv_continuous]:
"""Get a copy of all registered custom distributions.
Returns:
Dict mapping distribution names to rv_continuous objects
Note:
Returns a shallow copy - modifying the dict won't affect the registry,
but modifying distribution objects will.
"""
return self._custom_distributions.copy()
[docs]
def has_custom_distributions(self) -> bool:
"""Check if any custom distributions are registered.
Returns:
True if at least one custom distribution is registered
"""
return len(self._custom_distributions) > 0
[docs]
class DiscreteDistributionRegistry:
"""Registry for managing scipy.stats discrete distributions.
Unlike continuous distributions, discrete distributions in scipy do not have
a built-in fit() method. This registry provides parameter configuration
(initial values, bounds, estimation functions) needed for MLE fitting via
optimization.
Example:
>>> registry = DiscreteDistributionRegistry()
>>> distributions = registry.get_distributions()
>>> len(distributions)
~15
>>> # Get parameter config for fitting
>>> config = registry.get_param_config("poisson")
>>> initial = config["initial"](data)
>>> bounds = config["bounds"](data)
"""
# Default exclusions: distributions that are slow, require special handling,
# or have complex parameter constraints
DEFAULT_EXCLUSIONS = {
"nchypergeom_fisher", # Very slow - non-central hypergeometric
"nchypergeom_wallenius", # Very slow - non-central hypergeometric
"randint", # Discrete uniform - trivial, not useful for fitting
"bernoulli", # Special case of binomial with n=1
"poisson_binom", # Poisson binomial - complex, requires list of probabilities
}
# All scipy discrete distributions
ALL_DISTRIBUTIONS = [name for name in dir(st) if isinstance(getattr(st, name), st.rv_discrete)]
def __init__(self, custom_exclusions: Optional[Set[str]] = None):
"""Initialize the discrete distribution registry.
Args:
custom_exclusions: Optional set of distribution names to exclude
(replaces default exclusions if provided)
"""
self._excluded = custom_exclusions if custom_exclusions is not None else self.DEFAULT_EXCLUSIONS.copy()
self._param_configs = self._build_param_configs()
@staticmethod
def _build_param_configs() -> Dict[str, Dict[str, Any]]:
"""Build parameter configurations for each discrete distribution.
Each config contains:
- initial: Callable[[np.ndarray], List[float]] - initial parameter estimates
- bounds: Callable[[np.ndarray], List[Tuple[float, float]]] - parameter bounds
- param_names: List[str] - names of parameters
Returns:
Dictionary mapping distribution names to their parameter configs
"""
configs: Dict[str, Dict[str, Any]] = {}
# Poisson: λ (mu) - rate parameter
# MLE: λ = mean(data)
configs["poisson"] = {
"param_names": ["mu"],
"initial": lambda data: [max(np.mean(data), 0.1)],
"bounds": lambda data: [(1e-6, max(np.mean(data) * 10, 100))],
}
# Geometric: p - probability of success
# MLE: p = 1 / mean(data) for scipy's parameterization (starting from 1)
configs["geom"] = {
"param_names": ["p"],
"initial": lambda data: [min(1 / max(np.mean(data), 1), 0.99)],
"bounds": lambda data: [(1e-6, 1 - 1e-6)],
}
# Binomial: n (trials), p (probability)
# Estimation: n ≈ max(data), p ≈ mean(data) / n
configs["binom"] = {
"param_names": ["n", "p"],
"initial": lambda data: [
max(int(np.max(data)) + 5, int(np.mean(data) + 3 * np.std(data))),
np.clip(np.mean(data) / max(np.max(data), 1), 0.01, 0.99),
],
"bounds": lambda data: [
(max(int(np.max(data)), 1), max(int(np.max(data)) * 3, 100)),
(1e-6, 1 - 1e-6),
],
}
# Negative Binomial: n (successes), p (probability)
# Method of moments: p = mean / (mean + var/mean), n = mean * p / (1-p)
def nbinom_initial(data: np.ndarray) -> List[float]:
mean_val = max(np.mean(data), 0.1)
var_val = max(np.var(data), mean_val + 0.1) # Ensure overdispersion
p = np.clip(mean_val / var_val, 0.01, 0.99)
n = max(mean_val * p / (1 - p), 0.1)
return [n, p]
configs["nbinom"] = {
"param_names": ["n", "p"],
"initial": nbinom_initial,
"bounds": lambda data: [(1e-2, 1000), (1e-6, 1 - 1e-6)],
}
# Zipf: a - shape parameter (a > 1)
# Initial estimate based on log-log regression slope
configs["zipf"] = {
"param_names": ["a"],
"initial": lambda data: [2.0], # Common default
"bounds": lambda data: [(1.0 + 1e-6, 10.0)],
}
# Zipfian: a, n - generalized Zipf
configs["zipfian"] = {
"param_names": ["a", "n"],
"initial": lambda data: [1.5, int(np.max(data)) + 1],
"bounds": lambda data: [(0.0, 10.0), (int(np.max(data)), int(np.max(data)) * 2 + 10)],
}
# Hypergeometric: M (population), n (success states), N (draws)
# Complex constraints: n <= M, N <= M, max(data) <= min(n, N)
def hypergeom_initial(data: np.ndarray) -> List[float]:
max_val = int(np.max(data))
mean_val = np.mean(data)
# Rough estimates
N = max(max_val + 5, int(mean_val * 2)) # draws
n = max(max_val + 10, N) # success states
M = max(n + N, int(n * 2)) # population
return [M, n, N]
def hypergeom_bounds(data: np.ndarray) -> List[Tuple[float, float]]:
max_val = int(np.max(data))
return [
(max_val + 10, 10000), # M (population)
(max_val + 1, 5000), # n (success states)
(max_val + 1, 5000), # N (draws)
]
configs["hypergeom"] = {
"param_names": ["M", "n", "N"],
"initial": hypergeom_initial,
"bounds": hypergeom_bounds,
}
# Beta-Binomial: n, a, b
configs["betabinom"] = {
"param_names": ["n", "a", "b"],
"initial": lambda data: [int(np.max(data)) + 5, 1.0, 1.0],
"bounds": lambda data: [
(int(np.max(data)), int(np.max(data)) * 3 + 10),
(1e-2, 100),
(1e-2, 100),
],
}
# Beta-Negative Binomial: n, a, b
configs["betanbinom"] = {
"param_names": ["n", "a", "b"],
"initial": lambda data: [max(np.mean(data), 1), 1.0, 1.0],
"bounds": lambda data: [(1e-2, 1000), (1e-2, 100), (1e-2, 100)],
}
# Boltzmann: lambda, N
configs["boltzmann"] = {
"param_names": ["lambda", "N"],
"initial": lambda data: [1.0, int(np.max(data)) + 1],
"bounds": lambda data: [(1e-6, 100), (int(np.max(data)) + 1, int(np.max(data)) * 2 + 10)],
}
# Discrete Laplacian: a
configs["dlaplace"] = {
"param_names": ["a"],
"initial": lambda data: [0.5],
"bounds": lambda data: [(1e-6, 10.0)],
}
# Logarithmic (Log-Series): p
configs["logser"] = {
"param_names": ["p"],
"initial": lambda data: [0.5],
"bounds": lambda data: [(1e-6, 1 - 1e-6)],
}
# Planck: lambda
configs["planck"] = {
"param_names": ["lambda"],
"initial": lambda data: [1.0],
"bounds": lambda data: [(1e-6, 100)],
}
# Skellam: mu1, mu2 (difference of two Poissons)
def skellam_initial(data: np.ndarray) -> List[float]:
mean_val = np.mean(data)
var_val = np.var(data)
# mu1 + mu2 = var, mu1 - mu2 = mean
mu1 = max((var_val + mean_val) / 2, 0.1)
mu2 = max((var_val - mean_val) / 2, 0.1)
return [mu1, mu2]
configs["skellam"] = {
"param_names": ["mu1", "mu2"],
"initial": skellam_initial,
"bounds": lambda data: [(1e-6, 1000), (1e-6, 1000)],
}
# Yule-Simon: alpha
configs["yulesimon"] = {
"param_names": ["alpha"],
"initial": lambda data: [2.0],
"bounds": lambda data: [(1e-6, 20.0)],
}
# Non-central hypergeometric (nhypergeom): M, n, N, odds (4 params - complex)
configs["nhypergeom"] = {
"param_names": ["M", "n", "r"],
"initial": lambda data: [int(np.max(data)) * 2 + 20, int(np.max(data)) + 10, int(np.max(data)) + 5],
"bounds": lambda data: [
(int(np.max(data)) + 10, 10000),
(int(np.max(data)) + 1, 5000),
(1, 5000),
],
}
return configs
[docs]
def get_distributions(
self,
additional_exclusions: Optional[List[str]] = None,
) -> List[str]:
"""Get filtered list of discrete distributions.
Only returns distributions that have parameter configurations defined.
Args:
additional_exclusions: Additional distribution names to exclude
Returns:
List of distribution names that can be fitted
"""
excluded = self._excluded.copy()
if additional_exclusions:
excluded.update(additional_exclusions)
# Only include distributions we have configs for
return [d for d in self.ALL_DISTRIBUTIONS if d not in excluded and d in self._param_configs]
[docs]
def get_param_config(self, dist_name: str) -> Dict[str, Any]:
"""Get parameter configuration for a distribution.
Args:
dist_name: Name of the scipy discrete distribution
Returns:
Dictionary with 'param_names', 'initial', and 'bounds' keys
Raises:
ValueError: If distribution is not supported
"""
if dist_name not in self._param_configs:
raise ValueError(
f"Distribution '{dist_name}' is not supported. " f"Supported: {list(self._param_configs.keys())}"
)
return self._param_configs[dist_name]
[docs]
def add_exclusion(self, dist_name: str) -> None:
"""Add a distribution to the exclusion list."""
self._excluded.add(dist_name)
[docs]
def remove_exclusion(self, dist_name: str) -> None:
"""Remove a distribution from the exclusion list."""
self._excluded.discard(dist_name)
[docs]
def get_exclusions(self) -> Set[str]:
"""Get current set of excluded distributions."""
return self._excluded.copy()
[docs]
def reset_exclusions(self) -> None:
"""Reset exclusions to default set."""
self._excluded = self.DEFAULT_EXCLUSIONS.copy()