""" Statistical Inference Module for Negative Binomial GLM This module implements closed-form standard error calculations and statistical inference for negative binomial GLM parameters, following the mathematical derivation in methods/closed_form_standard_errors.md. Key functions: - compute_fisher_weights: Calculate Fisher information weights - compute_standard_errors: Closed-form standard errors for binary predictor - compute_wald_statistics: Wald test statistics and p-values - validate_calibration: QQ plots for p-value calibration assessment """ import numpy as np import matplotlib.pyplot as plt from scipy import stats from scipy.stats import uniform from typing import Tuple, Dict, Optional, Union import warnings def compute_fisher_weights(mu_hat: float, beta_hat: float, alpha_hat: float, x_indicators: np.ndarray, lib_sizes: np.ndarray) -> np.ndarray: """ Compute Fisher information weights for negative binomial GLM. For each observation i, the Fisher weight is: W_i = m_i / (1 + φ * m_i) where: - m_i = ℓ_i * exp(μ̂ + x_i * β̂) is the fitted mean - φ = exp(α̂) is the dispersion parameter - ℓ_i is the library size (exposure) - x_i ∈ {0,1} is the treatment indicator Args: mu_hat: Fitted intercept parameter (log scale) beta_hat: Fitted slope parameter (log fold change) alpha_hat: Fitted dispersion parameter (log scale) x_indicators: Binary treatment indicators (0 = control, 1 = treatment) lib_sizes: Library sizes (exposures) for each observation Returns: Array of Fisher weights W_i for each observation References: methods/closed_form_standard_errors.md """ # Convert parameters to natural scale phi = np.exp(alpha_hat) # Dispersion parameter # Compute fitted means: m_i = ℓ_i * exp(μ̂ + x_i * β̂) linear_predictor = mu_hat + x_indicators * beta_hat fitted_means = lib_sizes * np.exp(linear_predictor) # Compute Fisher weights: W_i = m_i / (1 + φ * m_i) weights = fitted_means / (1.0 + phi * fitted_means) return weights def compute_standard_errors(mu_hat: float, beta_hat: float, alpha_hat: float, x_indicators: np.ndarray, lib_sizes: np.ndarray) -> Dict[str, float]: """ Compute closed-form standard errors for negative binomial GLM with binary predictor. 

def compute_standard_errors(mu_hat: float,
                            beta_hat: float,
                            alpha_hat: float,
                            x_indicators: np.ndarray,
                            lib_sizes: np.ndarray) -> Dict[str, float]:
    """
    Compute closed-form standard errors for a negative binomial GLM with a
    binary predictor.

    For a binary predictor x ∈ {0, 1}, the standard errors are:

        SE(β̂) = √(1/S₀ + 1/S₁)   [slope / treatment effect]
        SE(μ̂) = 1/√S₀             [intercept]

    where:
    - S₀ = Σ W_i over observations with x_i = 0 (control group)
    - S₁ = Σ W_i over observations with x_i = 1 (treatment group)

    Args:
        mu_hat: Fitted intercept parameter (log scale)
        beta_hat: Fitted slope parameter (log fold change)
        alpha_hat: Fitted dispersion parameter (log scale)
        x_indicators: Binary treatment indicators (0 = control, 1 = treatment)
        lib_sizes: Library sizes (exposures) for each observation

    Returns:
        Dictionary with standard errors:
        - 'se_beta': Standard error of treatment effect (slope)
        - 'se_mu': Standard error of intercept
        - 'S0': Sum of weights for control group
        - 'S1': Sum of weights for treatment group

    References:
        methods/closed_form_standard_errors.md, Section 5
    """
    # Input validation
    x_indicators = np.asarray(x_indicators)
    lib_sizes = np.asarray(lib_sizes)

    if len(x_indicators) != len(lib_sizes):
        raise ValueError("x_indicators and lib_sizes must have same length")
    if not np.all(np.isin(x_indicators, [0, 1])):
        raise ValueError("x_indicators must contain only 0s and 1s")
    if np.any(lib_sizes <= 0):
        raise ValueError("lib_sizes must be positive")

    # Compute Fisher weights
    weights = compute_fisher_weights(mu_hat, beta_hat, alpha_hat,
                                     x_indicators, lib_sizes)

    # Compute group-wise weight sums
    S0 = np.sum(weights[x_indicators == 0])  # Control group
    S1 = np.sum(weights[x_indicators == 1])  # Treatment group

    # Handle edge cases
    if S0 <= 0 or S1 <= 0:
        warnings.warn("One or both groups have zero weight sum. "
                      "Standard errors may be unreliable.")
        se_beta = np.inf
        se_mu = np.inf
    else:
        # Closed-form standard errors
        se_beta = np.sqrt(1.0 / S0 + 1.0 / S1)  # Treatment effect standard error
        se_mu = 1.0 / np.sqrt(S0)               # Intercept standard error

    return {
        'se_beta': se_beta,
        'se_mu': se_mu,
        'S0': S0,
        'S1': S1
    }


def compute_wald_statistics(beta_hat: float, se_beta: float) -> Dict[str, float]:
    """
    Compute Wald test statistics and p-values for the treatment effect.

    The Wald statistic for testing H₀: β = 0 vs H₁: β ≠ 0 is:

        z = β̂ / SE(β̂)

    Under the null hypothesis, z ~ N(0, 1) asymptotically, so the two-sided
    p-value is p = 2 * (1 - Φ(|z|)).

    Args:
        beta_hat: Fitted treatment effect (log fold change)
        se_beta: Standard error of treatment effect

    Returns:
        Dictionary with test statistics:
        - 'z_stat': Wald z-statistic
        - 'p_value': Two-sided p-value
        - 'chi2_stat': Chi-squared statistic (z²)

    References:
        methods/closed_form_standard_errors.md, Section 6
    """
    # Handle edge cases
    if se_beta <= 0 or np.isinf(se_beta):
        return {
            'z_stat': np.nan,
            'p_value': np.nan,
            'chi2_stat': np.nan
        }

    # Compute Wald statistic
    z_stat = beta_hat / se_beta

    # Two-sided p-value; the survival function is numerically more stable
    # than 1 - cdf for large |z|
    p_value = 2.0 * stats.norm.sf(np.abs(z_stat))

    # Chi-squared statistic (equivalent test with 1 degree of freedom)
    chi2_stat = z_stat ** 2

    return {
        'z_stat': z_stat,
        'p_value': p_value,
        'chi2_stat': chi2_stat
    }
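
# Illustrative sketch with hypothetical numbers: the Wald z-test and the
# chi-squared test on z² with one degree of freedom are the same test, so their
# p-values agree. β̂ = 0.8 and SE(β̂) = 0.3 below are made up for the demo.
def _example_wald_chi2_equivalence() -> None:
    """Show that the two-sided normal p-value matches the χ²(1) p-value."""
    res = compute_wald_statistics(beta_hat=0.8, se_beta=0.3)
    p_chi2 = stats.chi2.sf(res['chi2_stat'], df=1)
    assert np.isclose(res['p_value'], p_chi2)
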

def compute_nb_glm_inference(mu_hat: float,
                             beta_hat: float,
                             alpha_hat: float,
                             x_indicators: np.ndarray,
                             lib_sizes: np.ndarray) -> Dict[str, float]:
    """
    Complete statistical inference for a negative binomial GLM with a binary
    predictor.

    Combines parameter estimates with closed-form standard errors and test
    statistics to provide full statistical inference equivalent to classical
    GLM software.

    Args:
        mu_hat: Fitted intercept parameter (log scale)
        beta_hat: Fitted slope parameter (log fold change)
        alpha_hat: Fitted dispersion parameter (log scale)
        x_indicators: Binary treatment indicators (0 = control, 1 = treatment)
        lib_sizes: Library sizes (exposures) for each observation

    Returns:
        Dictionary with complete inference results:
        - Parameter estimates: mu_hat, beta_hat, alpha_hat
        - Standard errors: se_mu, se_beta
        - Test statistics: z_stat, chi2_stat
        - P-value: p_value (two-sided test of H₀: β = 0)
        - Fisher information: S0, S1 (group weight sums)
    """
    # Compute standard errors
    se_results = compute_standard_errors(mu_hat, beta_hat, alpha_hat,
                                         x_indicators, lib_sizes)

    # Compute test statistics
    test_results = compute_wald_statistics(beta_hat, se_results['se_beta'])

    # Combine all results
    inference_results = {
        # Parameter estimates
        'mu_hat': mu_hat,
        'beta_hat': beta_hat,
        'alpha_hat': alpha_hat,
        # Standard errors
        'se_mu': se_results['se_mu'],
        'se_beta': se_results['se_beta'],
        # Test statistics
        'z_stat': test_results['z_stat'],
        'chi2_stat': test_results['chi2_stat'],
        'p_value': test_results['p_value'],
        # Fisher information
        'S0': se_results['S0'],
        'S1': se_results['S1']
    }

    return inference_results
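
# Illustrative sketch: an end-to-end call with made-up fitted parameters and a
# toy two-group design. In practice the parameter estimates would come from a
# fitted model (e.g. the NB-Transformer loaded below); here they are hypothetical.
def _example_nb_glm_inference() -> Dict[str, float]:
    """Run the full closed-form inference pipeline on toy inputs."""
    x = np.array([0, 0, 0, 0, 1, 1, 1, 1])  # balanced two-group design
    lib = np.ones(8)                         # unit library sizes for simplicity
    results = compute_nb_glm_inference(mu_hat=1.5, beta_hat=-0.7,
                                       alpha_hat=np.log(0.2),
                                       x_indicators=x, lib_sizes=lib)
    # results['p_value'] is the two-sided test of H0: beta = 0 (no treatment effect)
    return results
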

def validate_calibration(p_values: np.ndarray,
                         title: str = "P-value Calibration",
                         output_path: Optional[str] = None,
                         alpha: float = 0.05) -> Dict[str, float]:
    """
    Validate statistical calibration using QQ plots and uniformity tests.

    Under correct calibration, p-values from null data should follow
    Uniform(0, 1). This function creates QQ plots and performs statistical
    tests to assess calibration.

    Args:
        p_values: Array of p-values to test for uniformity
        title: Title for the QQ plot
        output_path: Optional path to save the plot
        alpha: Significance level for statistical tests

    Returns:
        Dictionary with calibration metrics:
        - 'ks_statistic': Kolmogorov-Smirnov test statistic
        - 'ks_pvalue': KS test p-value
        - 'ad_statistic': Anderson-Darling test statistic
        - 'ad_pvalue': AD test p-value (approximate)
        - 'is_calibrated_ks': True if the KS test is non-significant
        - 'is_calibrated_ad': True if the AD test is non-significant
        - 'n_tests': Number of valid p-values assessed

    References:
        Statistical calibration assessment for hypothesis testing
    """
    # Remove NaN values
    p_values = p_values[~np.isnan(p_values)]
    if len(p_values) == 0:
        raise ValueError("No valid p-values provided")

    # Kolmogorov-Smirnov test for uniformity on [0, 1]
    ks_stat, ks_pval = stats.kstest(p_values, 'uniform')

    # Anderson-Darling test for uniformity, computed manually because scipy's
    # anderson() does not support the uniform distribution. For Uniform(0, 1):
    #   A² = -n - (1/n) Σ (2i - 1) [ln U_(i) + ln(1 - U_(n+1-i))]
    n = len(p_values)
    p_sorted = np.sort(np.clip(p_values, 1e-12, 1 - 1e-12))  # avoid log(0)
    i = np.arange(1, n + 1)
    ad_stat = -n - np.sum((2 * i - 1) *
                          (np.log(p_sorted) + np.log(1 - p_sorted[::-1]))) / n

    # Coarse pass/fail approximation of the AD p-value: 2.492 is the asymptotic
    # 5% critical value for a fully specified null distribution; a more
    # conservative threshold is used for small samples.
    if n >= 25:
        ad_critical_05 = 2.492
    else:
        ad_critical_05 = 2.0
    ad_pval_approx = 0.05 if ad_stat > ad_critical_05 else 0.1

    # Create QQ plot and histogram side by side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # QQ plot against the uniform distribution, using (i - 0.5)/n plotting positions
    expected_quantiles = (np.arange(1, n + 1) - 0.5) / n
    observed_quantiles = np.sort(p_values)
    ax1.scatter(expected_quantiles, observed_quantiles, alpha=0.6, s=20)
    ax1.plot([0, 1], [0, 1], 'r--', label='Perfect calibration')
    ax1.set_xlabel('Expected quantiles (Uniform)')
    ax1.set_ylabel('Observed quantiles (P-values)')
    ax1.set_title(f'{title}\nQQ Plot vs Uniform(0,1)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Histogram of p-values against the flat density expected under the null
    ax2.hist(p_values, bins=20, density=True, alpha=0.7,
             color='skyblue', edgecolor='black', label='Observed')
    ax2.axhline(y=1.0, color='red', linestyle='--', label='Expected (Uniform)')
    ax2.set_xlabel('P-value')
    ax2.set_ylabel('Density')
    ax2.set_title(f'{title}\nP-value Histogram')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()

    # Add statistical test results as text
    textstr = f'KS test: D={ks_stat:.4f}, p={ks_pval:.4f}\nAD test: A²={ad_stat:.4f}'
    fig.text(0.02, 0.02, textstr, fontsize=10,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Calibration plot saved to: {output_path}")
    else:
        plt.show()

    # Return calibration metrics
    calibration_metrics = {
        'ks_statistic': ks_stat,
        'ks_pvalue': ks_pval,
        'ad_statistic': ad_stat,
        'ad_pvalue': ad_pval_approx,
        'is_calibrated_ks': ks_pval > alpha,
        'is_calibrated_ad': ad_pval_approx > alpha,
        'n_tests': len(p_values)
    }

    return calibration_metrics


def summarize_calibration_results(calibration_metrics: Dict[str, float]) -> str:
    """
    Generate a human-readable summary of calibration results.

    Args:
        calibration_metrics: Output from validate_calibration()

    Returns:
        Formatted string summary
    """
    ks_result = ("✓ Well-calibrated" if calibration_metrics['is_calibrated_ks']
                 else "✗ Poorly calibrated")
    ad_result = ("✓ Well-calibrated" if calibration_metrics['is_calibrated_ad']
                 else "✗ Poorly calibrated")

    summary = f"""
Calibration Assessment Summary (n = {calibration_metrics['n_tests']:,})
=========================================

Kolmogorov-Smirnov Test:
    Statistic: {calibration_metrics['ks_statistic']:.4f}
    P-value:   {calibration_metrics['ks_pvalue']:.4f}
    Result:    {ks_result}

Anderson-Darling Test:
    Statistic: {calibration_metrics['ad_statistic']:.4f}
    P-value:   ~{calibration_metrics['ad_pvalue']:.3f}
    Result:    {ad_result}

Interpretation:
- Well-calibrated methods show p-values ~ Uniform(0,1) under the null hypothesis
- Significant test results (p < 0.05) indicate poor calibration
- The QQ plot should follow the diagonal line for good calibration
"""
    return summary
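
# Illustrative sketch: under the null hypothesis the Wald z-statistics are
# approximately N(0, 1), so the resulting p-values should be uniform. The
# simulation below draws hypothetical null z-scores directly rather than
# fitting any model, purely to exercise the calibration checks above.
def _example_null_calibration(n_tests: int = 1000, seed: int = 0) -> str:
    """Generate null p-values and summarize their calibration."""
    rng = np.random.default_rng(seed)
    z_null = rng.standard_normal(n_tests)         # hypothetical null z-statistics
    p_null = 2.0 * stats.norm.sf(np.abs(z_null))  # two-sided p-values
    metrics = validate_calibration(p_null, title="Null simulation")
    return summarize_calibration_results(metrics)
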

def load_pretrained_model(checkpoint_path: Optional[str] = None,
                          device: Optional[str] = None):
    """
    Load the pre-trained NB-Transformer model.

    Args:
        checkpoint_path: Path to checkpoint file. If None, uses the bundled
            v13 model.
        device: Device to load the model on ('cpu', 'cuda', 'mps'). If None,
            auto-detects.

    Returns:
        Loaded DispersionTransformer model ready for inference

    Example:
        >>> from nb_transformer import load_pretrained_model
        >>> model = load_pretrained_model()
        >>> params = model.predict_parameters([2.1, 1.8, 2.3], [1.5, 1.2, 1.7])
    """
    import os

    import torch

    from .model import DispersionTransformer  # noqa: F401 (documents return type)
    from .train import DispersionLightningModule

    # Auto-detect device if not specified
    if device is None:
        if torch.cuda.is_available():
            device = 'cuda'
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            device = 'mps'
        else:
            device = 'cpu'

    # Use the bundled checkpoint if none is specified
    if checkpoint_path is None:
        package_dir = os.path.dirname(__file__)
        checkpoint_path = os.path.join(package_dir, '..',
                                       'model_checkpoint', 'last-v13.ckpt')
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Bundled model checkpoint not found at {checkpoint_path}. "
                "Please provide checkpoint_path explicitly."
            )

    # Load the checkpoint and put the model in evaluation mode
    try:
        lightning_module = DispersionLightningModule.load_from_checkpoint(
            checkpoint_path,
            map_location=device
        )
        model = lightning_module.model
        model.to(device)
        model.eval()
        return model
    except Exception as e:
        raise RuntimeError(f"Failed to load model from {checkpoint_path}: {e}") from e


def quick_inference_example():
    """
    Demonstrate quick inference with the pre-trained model.

    Returns:
        Dictionary with example parameters
    """
    # Load model
    model = load_pretrained_model()

    # Example data: two conditions with different sample sizes
    condition_1 = [2.1, 1.8, 2.3, 2.0]       # 4 samples from control
    condition_2 = [1.5, 1.2, 1.7, 1.4, 1.6]  # 5 samples from treatment

    # Predict parameters
    params = model.predict_parameters(condition_1, condition_2)

    print("NB-Transformer Quick Inference Example")
    print("=====================================")
    print(f"Control samples: {condition_1}")
    print(f"Treatment samples: {condition_2}")
    print(f"μ̂ (base mean): {params['mu']:.3f}")
    print(f"β̂ (log fold change): {params['beta']:.3f}")
    print(f"α̂ (log dispersion): {params['alpha']:.3f}")
    print(f"Fold change: {np.exp(params['beta']):.2f}x")

    return params
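
# Optional entry point (a sketch): running the module directly executes the
# quick demo above. This assumes the bundled checkpoint and the package-relative
# imports used by load_pretrained_model are available in the current environment.
if __name__ == "__main__":
    quick_inference_example()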