"""
Statistical Inference Module for Negative Binomial GLM
This module implements closed-form standard error calculations and statistical
inference for negative binomial GLM parameters, following the mathematical
derivation in methods/closed_form_standard_errors.md.
Key functions:
- compute_fisher_weights: Calculate Fisher information weights
- compute_standard_errors: Closed-form standard errors for binary predictor
- compute_wald_statistics: Wald test statistics and p-values
- compute_nb_glm_inference: Combined estimates, standard errors, and Wald p-values
- validate_calibration: QQ plots for p-value calibration assessment
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from typing import Dict, Optional
import warnings
def compute_fisher_weights(mu_hat: float,
beta_hat: float,
alpha_hat: float,
x_indicators: np.ndarray,
lib_sizes: np.ndarray) -> np.ndarray:
"""
Compute Fisher information weights for negative binomial GLM.
For each observation i, the Fisher weight is:
W_i = m_i / (1 + φ * m_i)
where:
- m_i = ℓ_i * exp(μ̂ + x_i * β̂) is the fitted mean
- φ = exp(α̂) is the dispersion parameter
- ℓ_i is the library size (exposure)
- x_i ∈ {0,1} is the treatment indicator
Args:
mu_hat: Fitted intercept parameter (log scale)
beta_hat: Fitted slope parameter (log fold change)
alpha_hat: Fitted dispersion parameter (log scale)
x_indicators: Binary treatment indicators (0 = control, 1 = treatment)
lib_sizes: Library sizes (exposures) for each observation
Returns:
Array of Fisher weights W_i for each observation
References:
methods/closed_form_standard_errors.md
"""
# Convert parameters to natural scale
phi = np.exp(alpha_hat) # Dispersion parameter
# Compute fitted means: m_i = ℓ_i * exp(μ̂ + x_i * β̂)
linear_predictor = mu_hat + x_indicators * beta_hat
fitted_means = lib_sizes * np.exp(linear_predictor)
# Compute Fisher weights: W_i = m_i / (1 + φ * m_i)
weights = fitted_means / (1.0 + phi * fitted_means)
return weights
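# Illustrative usage sketch (hypothetical parameter values, chosen only for
# demonstration): as phi = exp(alpha_hat) -> 0 the weights approach the Poisson
# limit W_i -> m_i, while for phi * m_i >> 1 they flatten out near 1/phi.
def _example_fisher_weights() -> None:
    x = np.array([0, 0, 1, 1])
    libs = np.array([1.0e6, 1.2e6, 0.9e6, 1.1e6])
    w = compute_fisher_weights(mu_hat=-11.0, beta_hat=0.5, alpha_hat=-1.0,
                               x_indicators=x, lib_sizes=libs)
    print("Fisher weights:", np.round(w, 3))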
def compute_standard_errors(mu_hat: float,
beta_hat: float,
alpha_hat: float,
x_indicators: np.ndarray,
lib_sizes: np.ndarray) -> Dict[str, float]:
"""
Compute closed-form standard errors for negative binomial GLM with binary predictor.
For a binary predictor x ∈ {0,1}, the standard errors are:
- SE(β̂) = √(1/S₀ + 1/S₁) [slope/treatment effect]
- SE(μ̂) = 1/√S₀ [intercept]
where:
- S₀ = Σ W_i for observations with x_i = 0 (control group)
- S₁ = Σ W_i for observations with x_i = 1 (treatment group)
Args:
mu_hat: Fitted intercept parameter (log scale)
beta_hat: Fitted slope parameter (log fold change)
alpha_hat: Fitted dispersion parameter (log scale)
x_indicators: Binary treatment indicators (0 = control, 1 = treatment)
lib_sizes: Library sizes (exposures) for each observation
Returns:
Dictionary with standard errors:
- 'se_beta': Standard error of treatment effect (slope)
- 'se_mu': Standard error of intercept
- 'S0': Sum of weights for control group
- 'S1': Sum of weights for treatment group
References:
methods/closed_form_standard_errors.md, Section 5
"""
# Input validation
x_indicators = np.asarray(x_indicators)
lib_sizes = np.asarray(lib_sizes)
if len(x_indicators) != len(lib_sizes):
raise ValueError("x_indicators and lib_sizes must have same length")
if not np.all(np.isin(x_indicators, [0, 1])):
raise ValueError("x_indicators must contain only 0s and 1s")
if np.any(lib_sizes <= 0):
raise ValueError("lib_sizes must be positive")
# Compute Fisher weights
weights = compute_fisher_weights(mu_hat, beta_hat, alpha_hat, x_indicators, lib_sizes)
# Compute group-wise weight sums
S0 = np.sum(weights[x_indicators == 0]) # Control group
S1 = np.sum(weights[x_indicators == 1]) # Treatment group
# Handle edge cases
if S0 <= 0 or S1 <= 0:
warnings.warn("One or both groups have zero weight sum. Standard errors may be unreliable.")
se_beta = np.inf
se_mu = np.inf
else:
# Closed-form standard errors
se_beta = np.sqrt(1.0/S0 + 1.0/S1) # Treatment effect standard error
se_mu = 1.0 / np.sqrt(S0) # Intercept standard error
return {
'se_beta': se_beta,
'se_mu': se_mu,
'S0': S0,
'S1': S1
}
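# Numerical cross-check (illustrative, with made-up parameter values): for a binary
# predictor the Fisher information is X^T W X with design rows [1, x_i], i.e. the
# 2x2 matrix [[S0 + S1, S1], [S1, S1]]; inverting it reproduces the closed forms
# Var(mu_hat) = 1/S0 and Var(beta_hat) = 1/S0 + 1/S1 used above.
def _check_closed_form_against_matrix_inverse() -> None:
    x = np.array([0, 0, 0, 1, 1, 1])
    libs = np.full(6, 1.0e6)
    mu, beta, alpha = -11.0, 0.7, -1.5
    w = compute_fisher_weights(mu, beta, alpha, x, libs)
    X = np.column_stack([np.ones_like(x, dtype=float), x.astype(float)])
    cov = np.linalg.inv(X.T @ (w[:, None] * X))  # inverse Fisher information
    se = compute_standard_errors(mu, beta, alpha, x, libs)
    assert np.isclose(np.sqrt(cov[0, 0]), se['se_mu'])
    assert np.isclose(np.sqrt(cov[1, 1]), se['se_beta'])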
def compute_wald_statistics(beta_hat: float, se_beta: float) -> Dict[str, float]:
"""
Compute Wald test statistics and p-values for treatment effect.
The Wald statistic for testing H₀: β = 0 vs H₁: β ≠ 0 is:
z = β̂ / SE(β̂)
Under the null hypothesis, z ~ N(0,1) asymptotically.
Two-sided p-value: p = 2 * (1 - Φ(|z|))
Args:
beta_hat: Fitted treatment effect (log fold change)
se_beta: Standard error of treatment effect
Returns:
Dictionary with test statistics:
- 'z_stat': Wald z-statistic
- 'p_value': Two-sided p-value
- 'chi2_stat': Chi-squared statistic (z²)
References:
methods/closed_form_standard_errors.md, Section 6
"""
# Handle edge cases
if se_beta <= 0 or np.isinf(se_beta):
return {
'z_stat': np.nan,
'p_value': np.nan,
'chi2_stat': np.nan
}
# Compute Wald statistic
z_stat = beta_hat / se_beta
# Two-sided p-value via the normal survival function (avoids underflow for large |z|)
p_value = 2.0 * stats.norm.sf(np.abs(z_stat))
# Chi-squared statistic (equivalent test)
chi2_stat = z_stat ** 2
return {
'z_stat': z_stat,
'p_value': p_value,
'chi2_stat': chi2_stat
}
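# Illustrative usage sketch (hypothetical numbers): a log fold change of 0.7 with a
# standard error of 0.25 gives z = 2.8 and a two-sided p-value of roughly 0.005.
def _example_wald_statistics() -> None:
    res = compute_wald_statistics(beta_hat=0.7, se_beta=0.25)
    print(f"z = {res['z_stat']:.2f}, p = {res['p_value']:.4g}")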
def compute_nb_glm_inference(mu_hat: float,
beta_hat: float,
alpha_hat: float,
x_indicators: np.ndarray,
lib_sizes: np.ndarray) -> Dict[str, float]:
"""
Complete statistical inference for negative binomial GLM with binary predictor.
Combines parameter estimates with closed-form standard errors and test statistics
to provide full statistical inference equivalent to classical GLM software.
Args:
mu_hat: Fitted intercept parameter (log scale)
beta_hat: Fitted slope parameter (log fold change)
alpha_hat: Fitted dispersion parameter (log scale)
x_indicators: Binary treatment indicators (0 = control, 1 = treatment)
lib_sizes: Library sizes (exposures) for each observation
Returns:
Dictionary with complete inference results:
- Parameter estimates: mu_hat, beta_hat, alpha_hat
- Standard errors: se_mu, se_beta
- Test statistics: z_stat, chi2_stat
- P-value: p_value (two-sided test of H₀: β = 0)
- Fisher information: S0, S1 (group weight sums)
"""
# Compute standard errors
se_results = compute_standard_errors(mu_hat, beta_hat, alpha_hat, x_indicators, lib_sizes)
# Compute test statistics
test_results = compute_wald_statistics(beta_hat, se_results['se_beta'])
# Combine all results
inference_results = {
# Parameter estimates
'mu_hat': mu_hat,
'beta_hat': beta_hat,
'alpha_hat': alpha_hat,
# Standard errors
'se_mu': se_results['se_mu'],
'se_beta': se_results['se_beta'],
# Test statistics
'z_stat': test_results['z_stat'],
'chi2_stat': test_results['chi2_stat'],
'p_value': test_results['p_value'],
# Fisher information
'S0': se_results['S0'],
'S1': se_results['S1']
}
return inference_results
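# Illustrative end-to-end sketch (all parameter values are hypothetical; in practice
# they come from the fitted negative binomial GLM): standard errors, Wald statistic,
# and p-value for a toy two-group design.
def _example_full_inference() -> None:
    x = np.array([0, 0, 0, 1, 1, 1])
    libs = np.array([1.0e6, 1.1e6, 0.9e6, 1.2e6, 1.0e6, 0.95e6])
    results = compute_nb_glm_inference(mu_hat=-10.5, beta_hat=0.6, alpha_hat=-2.0,
                                       x_indicators=x, lib_sizes=libs)
    print(f"LFC = {results['beta_hat']:.3f} +/- {results['se_beta']:.3f}, "
          f"p = {results['p_value']:.3g}")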
def validate_calibration(p_values: np.ndarray,
title: str = "P-value Calibration",
output_path: Optional[str] = None,
alpha: float = 0.05) -> Dict[str, float]:
"""
Validate statistical calibration using QQ plots and uniformity tests.
Under correct calibration, p-values from null data should follow Uniform(0,1).
This function creates QQ plots and performs statistical tests to assess calibration.
Args:
p_values: Array of p-values to test for uniformity
title: Title for the QQ plot
output_path: Optional path to save the plot
alpha: Significance level for statistical tests
Returns:
Dictionary with calibration metrics:
- 'ks_statistic': Kolmogorov-Smirnov test statistic
- 'ks_pvalue': KS test p-value
- 'ad_statistic': Anderson-Darling test statistic
- 'ad_pvalue': AD test p-value (approximate)
- 'is_calibrated_ks': Boolean, True if KS test is non-significant
- 'is_calibrated_ad': Boolean, True if AD test is non-significant
References:
Statistical calibration assessment for hypothesis testing
"""
# Remove NaN values
p_values = p_values[~np.isnan(p_values)]
if len(p_values) == 0:
raise ValueError("No valid p-values provided")
# Kolmogorov-Smirnov test for uniformity
ks_stat, ks_pval = stats.kstest(p_values, 'uniform')
# Anderson-Darling test for uniformity using manual calculation
# Since scipy doesn't support uniform dist directly, we use the formula
# for uniform distribution on [0,1]
n = len(p_values)
p_sorted = np.sort(p_values)
# Anderson-Darling statistic for uniform distribution
i = np.arange(1, n + 1)
# Clip to the open interval (0, 1) so log() stays finite if any p-value is exactly 0 or 1
p_clipped = np.clip(p_sorted, 1e-12, 1.0 - 1e-12)
ad_stat = -n - np.sum((2*i - 1) * (np.log(p_clipped) + np.log(1.0 - p_clipped[::-1]))) / n
# Critical values for uniform distribution (approximate)
# These are rough approximations based on simulation studies
if n >= 25:
ad_critical_05 = 2.492 # 5% critical value for large n
ad_pval_approx = 0.05 if ad_stat > ad_critical_05 else 0.1
else:
# For small samples, use more conservative threshold
ad_critical_05 = 2.0
ad_pval_approx = 0.05 if ad_stat > ad_critical_05 else 0.1
# Create QQ plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# QQ plot against uniform distribution
expected_quantiles = (np.arange(1, len(p_values) + 1) - 0.5) / len(p_values)
observed_quantiles = np.sort(p_values)
ax1.scatter(expected_quantiles, observed_quantiles, alpha=0.6, s=20)
ax1.plot([0, 1], [0, 1], 'r--', label='Perfect calibration')
ax1.set_xlabel('Expected quantiles (Uniform)')
ax1.set_ylabel('Observed quantiles (P-values)')
ax1.set_title(f'{title}\nQQ Plot vs Uniform(0,1)')
ax1.legend()
ax1.grid(True, alpha=0.3)
# Histogram of p-values
ax2.hist(p_values, bins=20, density=True, alpha=0.7, color='skyblue',
edgecolor='black', label='Observed')
ax2.axhline(y=1.0, color='red', linestyle='--', label='Expected (Uniform)')
ax2.set_xlabel('P-value')
ax2.set_ylabel('Density')
ax2.set_title(f'{title}\nP-value Histogram')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
# Add statistical test results as text
textstr = f'KS test: D={ks_stat:.4f}, p={ks_pval:.4f}\nAD test: A²={ad_stat:.4f}'
fig.text(0.02, 0.02, textstr, fontsize=10,
bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
if output_path:
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"Calibration plot saved to: {output_path}")
else:
plt.show()
# Return calibration metrics
calibration_metrics = {
'ks_statistic': ks_stat,
'ks_pvalue': ks_pval,
'ad_statistic': ad_stat,
'ad_pvalue': ad_pval_approx,
'is_calibrated_ks': ks_pval > alpha,
'is_calibrated_ad': ad_pval_approx > alpha,
'n_tests': len(p_values)
}
return calibration_metrics
def summarize_calibration_results(calibration_metrics: Dict[str, float]) -> str:
"""
Generate a human-readable summary of calibration results.
Args:
calibration_metrics: Output from validate_calibration()
Returns:
Formatted string summary
"""
ks_result = "✓ Well-calibrated" if calibration_metrics['is_calibrated_ks'] else "✗ Poorly calibrated"
ad_result = "✓ Well-calibrated" if calibration_metrics['is_calibrated_ad'] else "✗ Poorly calibrated"
summary = f"""
Calibration Assessment Summary (n = {calibration_metrics['n_tests']:,})
=========================================
Kolmogorov-Smirnov Test:
Statistic: {calibration_metrics['ks_statistic']:.4f}
P-value: {calibration_metrics['ks_pvalue']:.4f}
Result: {ks_result}
Anderson-Darling Test:
Statistic: {calibration_metrics['ad_statistic']:.4f}
P-value: ~{calibration_metrics['ad_pvalue']:.3f}
Result: {ad_result}
Interpretation:
- Well-calibrated methods should show p-values ~ Uniform(0,1) under null hypothesis
- Significant test results (p < 0.05) indicate poor calibration
- QQ plot should follow diagonal line for good calibration
"""
return summary
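# Illustrative calibration sketch (the output filename is hypothetical): under the
# null the Wald z-statistics are approximately N(0, 1), so standard-normal draws
# converted to two-sided p-values should pass the uniformity checks above.
def _example_calibration_check(n_tests: int = 2000, seed: int = 0) -> None:
    rng = np.random.default_rng(seed)
    p_null = 2.0 * stats.norm.sf(np.abs(rng.standard_normal(n_tests)))
    metrics = validate_calibration(p_null, title="Null simulation",
                                   output_path="calibration_null.png")
    print(summarize_calibration_results(metrics))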
def load_pretrained_model(checkpoint_path: Optional[str] = None, device: Optional[str] = None):
"""
Load the pre-trained NB-Transformer model.
Args:
checkpoint_path: Path to checkpoint file. If None, uses bundled v13 model.
device: Device to load model on ('cpu', 'cuda', 'mps'). If None, auto-detects.
Returns:
Loaded DispersionTransformer model ready for inference
Example:
>>> from nb_transformer import load_pretrained_model
>>> model = load_pretrained_model()
>>> params = model.predict_parameters([2.1, 1.8, 2.3], [1.5, 1.2, 1.7])
"""
import torch
import os
from .model import DispersionTransformer
from .train import DispersionLightningModule
# Auto-detect device if not specified
if device is None:
if torch.cuda.is_available():
device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'
# Use bundled checkpoint if none specified
if checkpoint_path is None:
package_dir = os.path.dirname(__file__)
checkpoint_path = os.path.join(package_dir, '..', 'model_checkpoint', 'last-v13.ckpt')
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(
f"Bundled model checkpoint not found at {checkpoint_path}. "
"Please provide checkpoint_path explicitly."
)
# Load checkpoint
try:
lightning_module = DispersionLightningModule.load_from_checkpoint(
checkpoint_path,
map_location=device
)
model = lightning_module.model
model.to(device)
model.eval()
return model
except Exception as e:
raise RuntimeError(f"Failed to load model from {checkpoint_path}: {e}") from e
def quick_inference_example():
"""
Demonstrate quick inference with the pre-trained model.
Returns:
Dictionary with example parameters
"""
# Load model
model = load_pretrained_model()
# Example data: two conditions with different sample sizes
condition_1 = [2.1, 1.8, 2.3, 2.0] # 4 samples from control
condition_2 = [1.5, 1.2, 1.7, 1.4, 1.6] # 5 samples from treatment
# Predict parameters
params = model.predict_parameters(condition_1, condition_2)
print("NB-Transformer Quick Inference Example")
print("=====================================")
print(f"Control samples: {condition_1}")
print(f"Treatment samples: {condition_2}")
print(f"μ̂ (base mean): {params['mu']:.3f}")
print(f"β̂ (log fold change): {params['beta']:.3f}")
print(f"α̂ (log dispersion): {params['alpha']:.3f}")
print(f"Fold change: {np.exp(params['beta']):.2f}x")
return params