#!/usr/bin/env python
"""
NB-Transformer Example Usage Script
This script demonstrates the basic usage of NB-Transformer for fast
Negative Binomial GLM parameter estimation.
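Under the standard two-group NB GLM parameterization (assumed here, going by the
parameter names reported below), counts y_ij in group x_i in {0, 1} follow
    y_ij ~ NB(mean m_i, dispersion exp(alpha)),   log(m_i) = mu + beta * x_i,
so exp(beta) is the treatment-vs-control fold change.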
Run this script to see NB-Transformer in action:
python example_usage.py
"""
import numpy as np
from nb_transformer import load_pretrained_model, quick_inference_example
def basic_example():
"""Basic parameter estimation example."""
print("🚀 NB-TRANSFORMER BASIC EXAMPLE")
print("=" * 50)
# Load the pre-trained model
print("Loading pre-trained NB-Transformer model...")
model = load_pretrained_model()
print("✅ Model loaded successfully!")
# Example data (log10(CPM + 1) transformed)
control_samples = [2.1, 1.8, 2.3, 2.0, 1.9] # 5 control samples
treatment_samples = [1.5, 1.2, 1.7, 1.4, 1.6] # 5 treatment samples
print(f"\n📊 INPUT DATA")
print(f"Control samples (n={len(control_samples)}): {control_samples}")
print(f"Treatment samples (n={len(treatment_samples)}): {treatment_samples}")
# Predict NB GLM parameters
print(f"\n⚡ RUNNING INFERENCE...")
params = model.predict_parameters(control_samples, treatment_samples)
# Display results
print(f"\n📈 RESULTS")
print(f"μ̂ (base mean, log scale): {params['mu']:.3f}")
print(f"β̂ (log fold change): {params['beta']:.3f}")
print(f"α̂ (log dispersion): {params['alpha']:.3f}")
# Interpret results
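    # 'beta' is treated as a natural-log fold change, so exp(beta) recovers the
    # linear fold change; values below 1 indicate downregulation.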
fold_change = np.exp(params['beta'])
if fold_change > 1:
direction = "upregulated"
magnitude = f"{fold_change:.2f}x"
else:
direction = "downregulated"
magnitude = f"{1/fold_change:.2f}x"
print(f"\n🧬 BIOLOGICAL INTERPRETATION")
print(f"Fold change: {fold_change:.2f}x")
print(f"Gene appears to be {direction} ({magnitude})")
print(f"Base expression level: {np.exp(params['mu']):.2f}")
print(f"Dispersion parameter: {np.exp(params['alpha']):.3f}")
return params
def statistical_inference_example():
"""Complete statistical inference example with p-values."""
print(f"\n\n🔬 COMPLETE STATISTICAL INFERENCE EXAMPLE")
print("=" * 50)
from nb_transformer.inference import compute_nb_glm_inference
# Load model
model = load_pretrained_model()
# Simulate realistic RNA-seq data
print("📊 SIMULATING REALISTIC RNA-SEQ DATA")
# Control condition
control_counts = np.array([1520, 1280, 1650, 1400, 1350])
control_lib_sizes = np.array([1e6, 1.1e6, 0.9e6, 1.05e6, 0.95e6])
# Treatment condition (downregulated gene)
treatment_counts = np.array([980, 890, 1100, 950, 850])
treatment_lib_sizes = np.array([1e6, 1.0e6, 1.1e6, 0.95e6, 1.02e6])
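    # Sanity check: treatment counts average ~954 vs ~1440 for control,
    # i.e. roughly a 1.5x reduction before library-size normalization.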
print(f"Control counts: {control_counts}")
print(f"Treatment counts: {treatment_counts}")
print(f"Control library sizes: {np.mean(control_lib_sizes)/1e6:.2f}M (avg)")
print(f"Treatment library sizes: {np.mean(treatment_lib_sizes)/1e6:.2f}M (avg)")
    # Normalize to counts per 10k and transform to log10(x + 1)
control_transformed = np.log10(1e4 * control_counts / control_lib_sizes + 1)
treatment_transformed = np.log10(1e4 * treatment_counts / treatment_lib_sizes + 1)
print(f"\n⚡ PARAMETER ESTIMATION")
params = model.predict_parameters(control_transformed, treatment_transformed)
print(f"\n🧮 STATISTICAL INFERENCE")
# Complete statistical analysis with p-values
results = compute_nb_glm_inference(
params['mu'], params['beta'], params['alpha'],
control_counts, treatment_counts,
control_lib_sizes, treatment_lib_sizes
)
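    # Assumed convention for the test below: wald_stat = beta_hat / se_beta, with a
    # two-sided p-value from the standard normal (matching the z = 1.96 interval used later).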
print(f"Parameter estimates:")
print(f" μ̂ = {results['mu']:.3f} (base mean)")
print(f" β̂ = {results['beta']:.3f} ± {results['se_beta']:.3f} (log fold change)")
print(f" α̂ = {results['alpha']:.3f} (log dispersion)")
print(f"\nStatistical test results:")
print(f" Wald statistic: {results['wald_stat']:.3f}")
print(f" P-value: {results['pvalue']:.2e}")
print(f" Significant (α=0.05): {'✅ Yes' if results['pvalue'] < 0.05 else '❌ No'}")
    # 95% Wald confidence interval: beta_hat ± 1.96 * se_beta
z_alpha = 1.96 # 95% CI
ci_lower = results['beta'] - z_alpha * results['se_beta']
ci_upper = results['beta'] + z_alpha * results['se_beta']
print(f"\n📊 95% CONFIDENCE INTERVAL")
print(f"Log fold change: [{ci_lower:.3f}, {ci_upper:.3f}]")
print(f"Fold change: [{np.exp(ci_lower):.3f}x, {np.exp(ci_upper):.3f}x]")
return results
def speed_comparison_example():
"""Demonstrate speed advantage over classical methods."""
print(f"\n\n⚡ SPEED COMPARISON EXAMPLE")
print("=" * 50)
import time
# Load model
model = load_pretrained_model()
# Generate test data
n_tests = 100
print(f"Running {n_tests} parameter estimation tests...")
test_cases = []
for _ in range(n_tests):
control = np.random.lognormal(0, 0.5, 5)
treatment = np.random.lognormal(0, 0.5, 5)
test_cases.append((control, treatment))
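    # These lognormal draws are just synthetic inputs on roughly the same scale as the
    # log-transformed expression values used in the examples above.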
# Time NB-Transformer
print(f"\n🚀 Testing NB-Transformer speed...")
start_time = time.perf_counter()
for control, treatment in test_cases:
params = model.predict_parameters(control, treatment)
transformer_time = time.perf_counter() - start_time
transformer_avg = (transformer_time / n_tests) * 1000 # ms per test
print(f"NB-Transformer: {transformer_time:.3f}s total, {transformer_avg:.3f}ms per test")
# Compare with Method of Moments (fastest baseline)
print(f"\n📊 Testing Method of Moments speed...")
from nb_transformer import estimate_batch_parameters_vectorized
start_time = time.perf_counter()
control_batch = [case[0] for case in test_cases]
treatment_batch = [case[1] for case in test_cases]
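    # One vectorized call over the whole batch; the result is not inspected here,
    # it is only timed for comparison.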
results = estimate_batch_parameters_vectorized(control_batch, treatment_batch)
mom_time = time.perf_counter() - start_time
mom_avg = (mom_time / n_tests) * 1000 # ms per test
print(f"Method of Moments: {mom_time:.3f}s total, {mom_avg:.3f}ms per test")
# Speed comparison
if mom_avg > 0:
speedup = mom_avg / transformer_avg
print(f"\n🏃 SPEED COMPARISON")
print(f"NB-Transformer vs Method of Moments: {speedup:.1f}x {'faster' if speedup > 1 else 'slower'}")
print(f"\n💡 Note: Classical GLM is typically ~15x slower than NB-Transformer")
print(f"Expected classical GLM time: ~{transformer_avg * 15:.1f}ms per test")
def main():
"""Run all examples."""
print("🧬 NB-TRANSFORMER DEMONSTRATION")
print("=" * 60)
print("Fast Negative Binomial GLM Parameter Estimation")
print("A modern replacement for DESeq2 statistical analysis")
print("=" * 60)
try:
# Run examples
basic_example()
statistical_inference_example()
speed_comparison_example()
print(f"\n\n✨ QUICK INFERENCE EXAMPLE")
print("=" * 50)
quick_inference_example()
print(f"\n\n🎉 ALL EXAMPLES COMPLETED SUCCESSFULLY!")
print("=" * 50)
print("🚀 Ready to use NB-Transformer in your research!")
print("📚 See examples/ directory for validation scripts")
print("🔗 Visit https://huggingface.co/valsv/nb-transformer for more info")
except Exception as e:
print(f"\n❌ Error running examples: {e}")
print("Please ensure nb-transformer is properly installed:")
print(" pip install nb-transformer")
raise
if __name__ == '__main__':
main()