"""
Statistical Inference Module for Negative Binomial GLM

This module implements closed-form standard error calculations and statistical
inference for negative binomial GLM parameters, following the mathematical
derivation in methods/closed_form_standard_errors.md.

Key functions:
- compute_fisher_weights: Calculate Fisher information weights
- compute_standard_errors: Closed-form standard errors for binary predictor
- compute_wald_statistics: Wald test statistics and p-values
- compute_nb_glm_inference: Complete inference pipeline for binary designs
- validate_calibration: QQ plots for p-value calibration assessment
"""

import warnings
from typing import Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats


def compute_fisher_weights(mu_hat: float,
                           beta_hat: float,
                           alpha_hat: float,
                           x_indicators: np.ndarray,
                           lib_sizes: np.ndarray) -> np.ndarray:
    """
    Compute Fisher information weights for negative binomial GLM.

    For each observation i, the Fisher weight is:
        W_i = m_i / (1 + φ * m_i)

    where:
    - m_i = ℓ_i * exp(μ̂ + x_i * β̂) is the fitted mean
    - φ = exp(α̂) is the dispersion parameter
    - ℓ_i is the library size (exposure)
    - x_i ∈ {0, 1} is the treatment indicator

    Args:
        mu_hat: Fitted intercept parameter (log scale)
        beta_hat: Fitted slope parameter (log fold change)
        alpha_hat: Fitted dispersion parameter (log scale)
        x_indicators: Binary treatment indicators (0 = control, 1 = treatment)
        lib_sizes: Library sizes (exposures) for each observation

    Returns:
        Array of Fisher weights W_i for each observation

    References:
        methods/closed_form_standard_errors.md
    """
    # Dispersion on the natural scale: φ = exp(α̂)
    phi = np.exp(alpha_hat)

    # Fitted means m_i = ℓ_i * exp(μ̂ + x_i * β̂)
    linear_predictor = mu_hat + x_indicators * beta_hat
    fitted_means = lib_sizes * np.exp(linear_predictor)

    # Fisher weights W_i = m_i / (1 + φ * m_i)
    weights = fitted_means / (1.0 + phi * fitted_means)

    return weights

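
# Illustrative sketch (not part of the original module): toy values showing
# that Fisher weights shrink toward 1/φ as the dispersion grows, which in
# turn inflates the closed-form standard errors. All numbers are made up.
def _demo_fisher_weights() -> None:
    """Toy example for compute_fisher_weights."""
    x = np.array([0, 0, 0, 1, 1, 1])
    lib = np.full(6, 1.0e4)
    # Same mean structure, low vs. high dispersion
    w_low = compute_fisher_weights(-6.0, 0.7, alpha_hat=-4.0,
                                   x_indicators=x, lib_sizes=lib)
    w_high = compute_fisher_weights(-6.0, 0.7, alpha_hat=-1.0,
                                    x_indicators=x, lib_sizes=lib)
    print("weights, φ=exp(-4):", np.round(w_low, 2))
    print("weights, φ=exp(-1):", np.round(w_high, 2))
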
def compute_standard_errors(mu_hat: float,
                            beta_hat: float,
                            alpha_hat: float,
                            x_indicators: np.ndarray,
                            lib_sizes: np.ndarray) -> Dict[str, float]:
    """
    Compute closed-form standard errors for negative binomial GLM with binary predictor.

    For a binary predictor x ∈ {0, 1}, the standard errors are:
    - SE(β̂) = √(1/S₀ + 1/S₁)   [slope / treatment effect]
    - SE(μ̂) = 1/√S₀            [intercept]

    where:
    - S₀ = Σ W_i for observations with x_i = 0 (control group)
    - S₁ = Σ W_i for observations with x_i = 1 (treatment group)

    Args:
        mu_hat: Fitted intercept parameter (log scale)
        beta_hat: Fitted slope parameter (log fold change)
        alpha_hat: Fitted dispersion parameter (log scale)
        x_indicators: Binary treatment indicators (0 = control, 1 = treatment)
        lib_sizes: Library sizes (exposures) for each observation

    Returns:
        Dictionary with standard errors:
        - 'se_beta': Standard error of treatment effect (slope)
        - 'se_mu': Standard error of intercept
        - 'S0': Sum of weights for control group
        - 'S1': Sum of weights for treatment group

    References:
        methods/closed_form_standard_errors.md, Section 5
    """
    x_indicators = np.asarray(x_indicators)
    lib_sizes = np.asarray(lib_sizes)

    if len(x_indicators) != len(lib_sizes):
        raise ValueError("x_indicators and lib_sizes must have the same length")

    if not np.all(np.isin(x_indicators, [0, 1])):
        raise ValueError("x_indicators must contain only 0s and 1s")

    if np.any(lib_sizes <= 0):
        raise ValueError("lib_sizes must be positive")

    weights = compute_fisher_weights(mu_hat, beta_hat, alpha_hat, x_indicators, lib_sizes)

    # Group-wise sums of Fisher weights
    S0 = np.sum(weights[x_indicators == 0])
    S1 = np.sum(weights[x_indicators == 1])

    if S0 <= 0 or S1 <= 0:
        warnings.warn("One or both groups have zero weight sum. "
                      "Standard errors may be unreliable.")
        se_beta = np.inf
        se_mu = np.inf
    else:
        # Closed-form standard errors for the binary design
        se_beta = np.sqrt(1.0 / S0 + 1.0 / S1)
        se_mu = 1.0 / np.sqrt(S0)

    return {
        'se_beta': se_beta,
        'se_mu': se_mu,
        'S0': S0,
        'S1': S1
    }

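
# Illustrative sketch (not in the original module): for a balanced design in
# which every observation shares the same weight W, S₀ = n₀·W and S₁ = n₁·W,
# so SE(β̂) = √(1/(n₀W) + 1/(n₁W)). The demo checks this identity with β̂ = 0
# and equal library sizes; all numeric values are made up.
def _demo_standard_errors() -> None:
    """Toy balanced-design check for compute_standard_errors."""
    x = np.array([0, 0, 0, 1, 1, 1])
    lib = np.full(6, 1.0e4)
    res = compute_standard_errors(mu_hat=-6.0, beta_hat=0.0, alpha_hat=-2.0,
                                  x_indicators=x, lib_sizes=lib)
    # With β̂ = 0 and equal library sizes, all weights are identical
    w = compute_fisher_weights(-6.0, 0.0, -2.0, x, lib)[0]
    expected = np.sqrt(1.0 / (3 * w) + 1.0 / (3 * w))
    print(f"se_beta = {res['se_beta']:.4f} (expected {expected:.4f})")
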
def compute_wald_statistics(beta_hat: float, se_beta: float) -> Dict[str, float]:
    """
    Compute Wald test statistics and p-values for treatment effect.

    The Wald statistic for testing H₀: β = 0 vs H₁: β ≠ 0 is:
        z = β̂ / SE(β̂)

    Under the null hypothesis, z ~ N(0,1) asymptotically.
    Two-sided p-value: p = 2 * (1 - Φ(|z|))

    Args:
        beta_hat: Fitted treatment effect (log fold change)
        se_beta: Standard error of treatment effect

    Returns:
        Dictionary with test statistics:
        - 'z_stat': Wald z-statistic
        - 'p_value': Two-sided p-value
        - 'chi2_stat': Chi-squared statistic (z²)

    References:
        methods/closed_form_standard_errors.md, Section 6
    """
    if se_beta <= 0 or np.isinf(se_beta):
        return {
            'z_stat': np.nan,
            'p_value': np.nan,
            'chi2_stat': np.nan
        }

    z_stat = beta_hat / se_beta

    # Use the survival function rather than 1 - cdf to avoid underflow to
    # zero for large |z|.
    p_value = 2.0 * stats.norm.sf(np.abs(z_stat))

    # Equivalent chi-squared statistic on 1 degree of freedom
    chi2_stat = z_stat ** 2

    return {
        'z_stat': z_stat,
        'p_value': p_value,
        'chi2_stat': chi2_stat
    }

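
# Illustrative sketch (not in the original module): why the survival-function
# form matters numerically. For z = 20, 2 * (1 - Φ(|z|)) underflows to
# exactly 0.0 in double precision, while 2 * sf(|z|) ≈ 5.5e-89.
def _demo_wald_statistics() -> None:
    """Toy comparison of the two tail-probability formulas."""
    z = 20.0
    naive = 2.0 * (1.0 - stats.norm.cdf(z))
    stable = 2.0 * stats.norm.sf(z)
    print(f"naive 1-cdf: {naive:.3e}, survival function: {stable:.3e}")
    print(compute_wald_statistics(beta_hat=2.0, se_beta=0.1))
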
def compute_nb_glm_inference(mu_hat: float,
                             beta_hat: float,
                             alpha_hat: float,
                             x_indicators: np.ndarray,
                             lib_sizes: np.ndarray) -> Dict[str, float]:
    """
    Complete statistical inference for negative binomial GLM with binary predictor.

    Combines parameter estimates with closed-form standard errors and test statistics
    to provide full statistical inference equivalent to classical GLM software.

    Args:
        mu_hat: Fitted intercept parameter (log scale)
        beta_hat: Fitted slope parameter (log fold change)
        alpha_hat: Fitted dispersion parameter (log scale)
        x_indicators: Binary treatment indicators (0 = control, 1 = treatment)
        lib_sizes: Library sizes (exposures) for each observation

    Returns:
        Dictionary with complete inference results:
        - Parameter estimates: mu_hat, beta_hat, alpha_hat
        - Standard errors: se_mu, se_beta
        - Test statistics: z_stat, chi2_stat
        - P-value: p_value (two-sided test of H₀: β = 0)
        - Fisher information: S0, S1 (group weight sums)
    """
    se_results = compute_standard_errors(mu_hat, beta_hat, alpha_hat, x_indicators, lib_sizes)

    test_results = compute_wald_statistics(beta_hat, se_results['se_beta'])

    return {
        # Parameter estimates
        'mu_hat': mu_hat,
        'beta_hat': beta_hat,
        'alpha_hat': alpha_hat,
        # Standard errors
        'se_mu': se_results['se_mu'],
        'se_beta': se_results['se_beta'],
        # Test statistics
        'z_stat': test_results['z_stat'],
        'chi2_stat': test_results['chi2_stat'],
        'p_value': test_results['p_value'],
        # Group weight sums (Fisher information components)
        'S0': se_results['S0'],
        'S1': se_results['S1']
    }

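
# Illustrative sketch (not in the original module): end-to-end inference on a
# toy unbalanced design. All parameter values and library sizes are made up.
def _demo_nb_glm_inference() -> None:
    """Toy end-to-end run of compute_nb_glm_inference."""
    x = np.array([0, 0, 0, 0, 1, 1, 1])
    lib = np.array([1.0e4, 1.2e4, 0.9e4, 1.1e4, 1.0e4, 1.3e4, 0.8e4])
    res = compute_nb_glm_inference(mu_hat=-6.0, beta_hat=0.7, alpha_hat=-2.0,
                                   x_indicators=x, lib_sizes=lib)
    for key in ('beta_hat', 'se_beta', 'z_stat', 'p_value'):
        print(f"{key:>8}: {res[key]:.4g}")
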
def validate_calibration(p_values: np.ndarray,
                         title: str = "P-value Calibration",
                         output_path: Optional[str] = None,
                         alpha: float = 0.05) -> Dict[str, float]:
    """
    Validate statistical calibration using QQ plots and uniformity tests.

    Under correct calibration, p-values from null data should follow Uniform(0,1).
    This function creates QQ plots and performs statistical tests to assess calibration.

    Args:
        p_values: Array of p-values to test for uniformity
        title: Title for the QQ plot
        output_path: Optional path to save the plot
        alpha: Significance level for statistical tests

    Returns:
        Dictionary with calibration metrics:
        - 'ks_statistic': Kolmogorov-Smirnov test statistic
        - 'ks_pvalue': KS test p-value
        - 'ad_statistic': Anderson-Darling test statistic
        - 'ad_pvalue': AD test p-value (approximate)
        - 'is_calibrated_ks': Boolean, True if KS test is non-significant
        - 'is_calibrated_ad': Boolean, True if AD test is non-significant
        - 'n_tests': Number of valid p-values tested

    References:
        Statistical calibration assessment for hypothesis testing
    """
    p_values = np.asarray(p_values, dtype=float)
    p_values = p_values[~np.isnan(p_values)]

    if len(p_values) == 0:
        raise ValueError("No valid p-values provided")

    # Kolmogorov-Smirnov test against Uniform(0,1)
    ks_stat, ks_pval = stats.kstest(p_values, 'uniform')

    # Anderson-Darling statistic against Uniform(0,1):
    # A² = -n - (1/n) Σ (2i - 1) [ln F(X_(i)) + ln(1 - F(X_(n+1-i)))],
    # with F the identity for the uniform CDF. Clip away from 0 and 1 so the
    # logarithms stay finite when p-values hit the boundaries.
    n = len(p_values)
    p_sorted = np.clip(np.sort(p_values), 1e-15, 1.0 - 1e-15)
    i = np.arange(1, n + 1)
    ad_stat = -n - np.sum((2 * i - 1) * (np.log(p_sorted) + np.log(1 - p_sorted[::-1]))) / n

    # Coarse p-value approximation: compare against the asymptotic 5% critical
    # value for a fully specified distribution (2.492); for small n a rougher
    # cutoff is used.
    ad_critical_05 = 2.492 if n >= 25 else 2.0
    ad_pval_approx = 0.05 if ad_stat > ad_critical_05 else 0.1

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # QQ plot: plotting positions (i - 0.5)/n avoid expected quantiles of
    # exactly 0 and 1.
    expected_quantiles = (np.arange(1, n + 1) - 0.5) / n
    observed_quantiles = np.sort(p_values)

    ax1.scatter(expected_quantiles, observed_quantiles, alpha=0.6, s=20)
    ax1.plot([0, 1], [0, 1], 'r--', label='Perfect calibration')
    ax1.set_xlabel('Expected quantiles (Uniform)')
    ax1.set_ylabel('Observed quantiles (P-values)')
    ax1.set_title(f'{title}\nQQ Plot vs Uniform(0,1)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Histogram: under the null the density should be flat at 1
    ax2.hist(p_values, bins=20, density=True, alpha=0.7, color='skyblue',
             edgecolor='black', label='Observed')
    ax2.axhline(y=1.0, color='red', linestyle='--', label='Expected (Uniform)')
    ax2.set_xlabel('P-value')
    ax2.set_ylabel('Density')
    ax2.set_title(f'{title}\nP-value Histogram')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()

    textstr = f'KS test: D={ks_stat:.4f}, p={ks_pval:.4f}\nAD test: A²={ad_stat:.4f}'
    fig.text(0.02, 0.02, textstr, fontsize=10,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))

    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Calibration plot saved to: {output_path}")
    else:
        plt.show()

    return {
        'ks_statistic': ks_stat,
        'ks_pvalue': ks_pval,
        'ad_statistic': ad_stat,
        'ad_pvalue': ad_pval_approx,
        'is_calibrated_ks': ks_pval > alpha,
        'is_calibrated_ad': ad_pval_approx > alpha,
        'n_tests': n
    }

def summarize_calibration_results(calibration_metrics: Dict[str, float]) -> str:
    """
    Generate a human-readable summary of calibration results.

    Args:
        calibration_metrics: Output from validate_calibration()

    Returns:
        Formatted string summary
    """
    ks_result = "✓ Well-calibrated" if calibration_metrics['is_calibrated_ks'] else "✗ Poorly calibrated"
    ad_result = "✓ Well-calibrated" if calibration_metrics['is_calibrated_ad'] else "✗ Poorly calibrated"

    summary = f"""
Calibration Assessment Summary (n = {calibration_metrics['n_tests']:,})
=========================================

Kolmogorov-Smirnov Test:
  Statistic: {calibration_metrics['ks_statistic']:.4f}
  P-value: {calibration_metrics['ks_pvalue']:.4f}
  Result: {ks_result}

Anderson-Darling Test:
  Statistic: {calibration_metrics['ad_statistic']:.4f}
  P-value: ~{calibration_metrics['ad_pvalue']:.3f}
  Result: {ad_result}

Interpretation:
- Well-calibrated methods should show p-values ~ Uniform(0,1) under the null hypothesis
- Significant test results (p < 0.05) indicate poor calibration
- The QQ plot should follow the diagonal line for good calibration
"""

    return summary

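
# Illustrative sketch (not in the original module): calibration check on
# synthetic null p-values. True uniform draws should pass both tests most of
# the time; the seed, sample size, and output filename are all arbitrary.
def _demo_calibration() -> None:
    """Toy calibration run on simulated null p-values."""
    rng = np.random.default_rng(0)
    null_p = rng.uniform(size=1000)
    metrics = validate_calibration(null_p, title="Simulated null",
                                   output_path="calibration_demo.png")
    print(summarize_calibration_results(metrics))
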
def load_pretrained_model(checkpoint_path: Optional[str] = None, device: Optional[str] = None):
    """
    Load the pre-trained NB-Transformer model.

    Args:
        checkpoint_path: Path to checkpoint file. If None, uses the bundled v13 model.
        device: Device to load the model on ('cpu', 'cuda', 'mps'). If None, auto-detects.

    Returns:
        Loaded DispersionTransformer model ready for inference

    Example:
        >>> from nb_transformer import load_pretrained_model
        >>> model = load_pretrained_model()
        >>> params = model.predict_parameters([2.1, 1.8, 2.3], [1.5, 1.2, 1.7])
    """
    import os

    import torch

    from .model import DispersionTransformer
    from .train import DispersionLightningModule

    # Auto-detect the best available device
    if device is None:
        if torch.cuda.is_available():
            device = 'cuda'
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            device = 'mps'
        else:
            device = 'cpu'

    # Fall back to the bundled checkpoint
    if checkpoint_path is None:
        package_dir = os.path.dirname(__file__)
        checkpoint_path = os.path.join(package_dir, '..', 'model_checkpoint', 'last-v13.ckpt')

        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Bundled model checkpoint not found at {checkpoint_path}. "
                "Please provide checkpoint_path explicitly."
            )

    try:
        lightning_module = DispersionLightningModule.load_from_checkpoint(
            checkpoint_path,
            map_location=device
        )
        model = lightning_module.model
        model.to(device)
        model.eval()
        return model
    except Exception as e:
        raise RuntimeError(f"Failed to load model from {checkpoint_path}: {e}") from e

def quick_inference_example():
    """
    Demonstrate quick inference with the pre-trained model.

    Returns:
        Dictionary with example parameters
    """
    model = load_pretrained_model()

    # Two small groups of expression values
    condition_1 = [2.1, 1.8, 2.3, 2.0]
    condition_2 = [1.5, 1.2, 1.7, 1.4, 1.6]

    params = model.predict_parameters(condition_1, condition_2)

    print("NB-Transformer Quick Inference Example")
    print("=====================================")
    print(f"Control samples: {condition_1}")
    print(f"Treatment samples: {condition_2}")
    print(f"μ̂ (base mean): {params['mu']:.3f}")
    print(f"β̂ (log fold change): {params['beta']:.3f}")
    print(f"α̂ (log dispersion): {params['alpha']:.3f}")
    print(f"Fold change: {np.exp(params['beta']):.2f}x")

    return params
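

# Assumed entry point (not in the original module): running the file directly
# exercises the closed-form inference demos above. The pre-trained model demo
# is left out because it needs torch and a checkpoint on disk.
if __name__ == "__main__":
    _demo_fisher_weights()
    _demo_standard_errors()
    _demo_wald_statistics()
    _demo_nb_glm_inference()
    _demo_calibration()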