---
license: mit
language:
- en
library_name: pytorch
pipeline_tag: tabular-regression
tags:
- pytorch
- transformer
- bioinformatics
- negative-binomial
- glm
- statistics
- genomics
- computational-biology
datasets:
- synthetic
metrics:
- mae
- rmse
model-index:
- name: NB-Transformer
  results:
  - task:
      type: tabular-regression
      name: Negative Binomial GLM Parameter Estimation
    dataset:
      type: synthetic
      name: Synthetic NB GLM Data
    metrics:
    - type: mae
      value: 0.152
      name: Log Fold Change MAE
    - type: inference_time
      value: 0.076
      name: Inference Time (ms)
---
NB-Transformer: Fast Negative Binomial GLM Parameter Estimation
NB-Transformer is a fast, accurate neural-network approach to Negative Binomial GLM parameter estimation, designed as a modern alternative to classical iterative GLM fitting in the statistical analysis of count data. Using transformer-based attention, it runs 14.8x faster than a classical GLM fit while also achieving lower estimation error.
Paper: arxiv.org/abs/2508.04111
🚀 Key Features
- ⚡ Ultra-Fast: 14.8x faster than classical GLM (0.076ms vs 1.128ms per test)
- 🎯 More Accurate: 47% lower mean absolute error (MAE) on log fold change estimation than classical GLM
- 🔬 Complete Statistical Inference: P-values, confidence intervals, and power analysis
- 📊 Robust: 100% success rate vs 98.7% for classical GLM
- 🧠 Transformer Architecture: Attention-based modeling of variable-length sample sets
- 📦 Easy to Use: Simple API with pre-trained model included
📈 Performance Benchmarks
Based on comprehensive validation with 1000+ test cases:
| Method | Success Rate | Time per test (ms) | μ MAE | β MAE | α MAE |
|---|---|---|---|---|---|
| NB-Transformer | 100.0% | 0.076 | 0.202 | 0.152 | 0.477 |
| Classical GLM | 98.7% | 1.128 | 0.212 | 0.284 | 0.854 |
| Method of Moments | 100.0% | 0.021 | 0.213 | 0.289 | 0.852 |
Key Achievements:
- 47% lower MAE on β (log fold change), the critical parameter for differential expression
- 44% lower MAE on α (dispersion), which enters the Fisher weights and therefore the standard errors and p-values
- 100% success rate with no numerical instabilities: there is no iterative fitting that can fail to converge
🛠️ Installation
```bash
pip install nb-transformer
```
Or install from source:
```bash
git clone https://huggingface.co/valsv/nb-transformer
cd nb-transformer
pip install -e .
```
🎯 Quick Start
Basic Usage
```python
import numpy as np
from nb_transformer import load_pretrained_model

# Load the pre-trained model (downloads automatically)
model = load_pretrained_model()

# Your data: log10(CPM + 1) transformed counts
control_samples = [2.1, 1.8, 2.3, 2.0]    # 4 control samples
treatment_samples = [1.5, 1.2, 1.7, 1.4]  # 4 treatment samples

# Get NB GLM parameters instantly
params = model.predict_parameters(control_samples, treatment_samples)

print(f"μ̂ (base mean): {params['mu']:.3f}")            # -0.245
print(f"β̂ (log fold change): {params['beta']:.3f}")    # -0.421
print(f"α̂ (log dispersion): {params['alpha']:.3f}")    # -1.832
print(f"Fold change: {np.exp(params['beta']):.2f}x")   # 0.66x (downregulated)
```
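All three outputs are on the log scale. The sketch below assumes β̂ is a natural-log fold change (consistent with the `np.exp` call above) and α̂ is the natural log of the NB dispersion, and converts them into a fold change, the log2 fold change that many differential-expression tools report, and a dispersion. `summarize_effect` is a hypothetical helper, not part of the `nb_transformer` package.

```python
import math


def summarize_effect(beta, alpha):
    """Convert log-scale estimates into familiar quantities (hypothetical helper).

    Assumes beta is a natural-log fold change and alpha is log(dispersion).
    """
    return {
        "fold_change": math.exp(beta),            # e.g. 0.66x means downregulated
        "log2_fold_change": beta / math.log(2),   # change of log base: ln -> log2
        "dispersion": math.exp(alpha),            # phi in Var = m + phi * m^2
    }


summary = summarize_effect(beta=-0.421, alpha=-1.832)
print(f"{summary['fold_change']:.2f}x, log2FC = {summary['log2_fold_change']:.3f}")
# prints: 0.66x, log2FC = -0.607
```

Dividing by ln 2 is all that is needed to move between natural-log and log2 fold changes, so estimates can be compared directly with log2-based tools.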
Complete Statistical Analysis
```python
import numpy as np
from nb_transformer import load_pretrained_model
from nb_transformer.inference import compute_nb_glm_inference

# Load model and data
model = load_pretrained_model()
control_counts = np.array([1520, 1280, 1650, 1400])
treatment_counts = np.array([980, 890, 1100, 950])
control_lib_sizes = np.array([1e6, 1.1e6, 0.9e6, 1.05e6])
treatment_lib_sizes = np.array([1e6, 1.0e6, 1.1e6, 0.95e6])

# Transform to log10(normalized count + 1). Note: this example scales counts
# per 10,000 (factor 1e4), not per million as in CPM (factor 1e6).
control_transformed = np.log10(1e4 * control_counts / control_lib_sizes + 1)
treatment_transformed = np.log10(1e4 * treatment_counts / treatment_lib_sizes + 1)

# Get parameters
params = model.predict_parameters(control_transformed, treatment_transformed)

# Complete statistical inference
results = compute_nb_glm_inference(
    params['mu'], params['beta'], params['alpha'],
    control_counts, treatment_counts,
    control_lib_sizes, treatment_lib_sizes
)

print(f"Log fold change: {results['beta']:.3f} ± {results['se_beta']:.3f}")
print(f"P-value: {results['pvalue']:.2e}")
print(f"Significant: {'Yes' if results['pvalue'] < 0.05 else 'No'}")
```
Quick Demo
```python
from nb_transformer import quick_inference_example

# Run a complete example with sample data
params = quick_inference_example()
```
🔬 Validation & Reproducibility
This package includes three comprehensive validation scripts that reproduce all key results:
1. Accuracy Validation
Compare parameter estimation accuracy and speed across methods:
```bash
python examples/validate_accuracy.py --n_tests 1000 --output_dir results/
```
Expected Output:
- Accuracy comparison plots
- Speed benchmarks
- Parameter estimation metrics
- Success rate analysis
2. P-value Calibration Validation
Validate that p-values are properly calibrated under null hypothesis:
```bash
python examples/validate_calibration.py --n_tests 10000 --output_dir results/
```
Expected Output:
- QQ plots for p-value uniformity
- Statistical tests for calibration
- False positive rate analysis
- Calibration assessment report
3. Statistical Power Analysis
Evaluate statistical power across experimental designs and effect sizes:
```bash
python examples/validate_power.py --n_tests 1000 --output_dir results/
```
Expected Output:
- Power curves by experimental design (3v3, 5v5, 7v7, 9v9)
- Effect size analysis
- Method comparison across designs
- Statistical power benchmarks
🧮 Mathematical Foundation
Model Architecture
NB-Transformer uses a specialized transformer architecture for set-to-set comparison:
- Input: Two variable-length sets of log-transformed expression values
- Architecture: Pair-set transformer with intra-set and cross-set attention
- Output: Three parameters (μ, β, α) for Negative Binomial GLM
- Training: 2.5M parameters trained on synthetic data with known ground truth
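To make "intra-set and cross-set attention" concrete, here is a toy, dependency-free sketch of the data flow. It is not the trained NB-Transformer: the embedding `embed` is a hypothetical fixed function, and there are no learned projections, residual connections, multiple heads, or output head mapping the pooled features to (μ, β, α). It only illustrates why such a design handles variable-length, unordered sample sets.

```python
import math


def attention(queries, keys, values):
    """Single-head scaled dot-product attention over lists of vectors."""
    d = len(queries[0])
    outputs = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]  # numerically stable softmax
        total = sum(exps)
        weights = [e / total for e in exps]
        # Each output is the attention-weighted average of the value vectors
        outputs.append([sum(w * v[j] for w, v in zip(weights, values))
                        for j in range(len(values[0]))])
    return outputs


def embed(x):
    """Hypothetical fixed embedding of one log-expression value (no learned weights)."""
    return [x, x * x, 1.0]


def mean_pool(rows):
    """Average over set elements, so the result ignores set order and size."""
    return [sum(col) / len(rows) for col in zip(*rows)]


def pair_set_summary(control, treatment):
    c = [embed(x) for x in control]
    t = [embed(x) for x in treatment]
    c = attention(c, c, c)        # intra-set attention within the control set
    t = attention(t, t, t)        # intra-set attention within the treatment set
    c_cross = attention(c, t, t)  # control elements attend to treatment elements
    t_cross = attention(t, c, c)  # treatment elements attend to control elements
    return mean_pool(c_cross) + mean_pool(t_cross)


features = pair_set_summary([2.1, 1.8, 2.3, 2.0], [1.5, 1.2, 1.7])
```

Because attention is permutation-equivariant and mean pooling is permutation-invariant, the six pooled features do not depend on sample order, and the two sets may have different lengths.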
Statistical Inference
The model enables complete statistical inference through Fisher information:
- Parameter Estimation: Direct neural network prediction (μ̂, β̂, α̂)
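- Fisher Weights: $W_i = m_i / (1 + \phi m_i)$, where $m_i = \ell_i \exp(\hat\mu + x_i \hat\beta)$, $\ell_i$ is the library size of sample $i$, $x_i \in \{0, 1\}$ indicates treatment, and $\phi = \exp(\hat\alpha)$ is the dispersion
- Standard Errors: $\mathrm{SE}(\hat\beta) = \sqrt{[(X^\top W X)^{-1}]_{\beta\beta}}$, with $X$ the design matrix (intercept and treatment indicator) and $W = \mathrm{diag}(W_1, \dots, W_n)$
- Wald Statistics: $\hat\beta^2 / \mathrm{SE}(\hat\beta)^2 \sim \chi^2_1$ (asymptotically) under $H_0: \beta = 0$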
- P-values: Proper Type I error control validated via calibration analysis
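The inference steps listed above can be sketched in a few lines of plain Python for the two-group design. This is an illustration under stated assumptions, not the package's `compute_nb_glm_inference`: it assumes μ̂, β̂, α̂ are natural-log-scale, uses the expected Fisher information (so the observed counts are not needed, whereas the package function also takes counts), and `nb_wald_test` is a hypothetical name.

```python
import math


def nb_wald_test(mu, beta, alpha, control_lib_sizes, treatment_lib_sizes):
    """Wald test of H0: beta = 0 for a two-group NB GLM with log link (sketch).

    mu, beta and alpha are natural-log-scale estimates; phi = exp(alpha).
    Uses expected Fisher information, so the observed counts are not needed.
    """
    phi = math.exp(alpha)
    design = ([(lib, 0.0) for lib in control_lib_sizes]
              + [(lib, 1.0) for lib in treatment_lib_sizes])
    # Accumulate X'WX = sum_i W_i * [[1, x_i], [x_i, x_i^2]] = [[a, b], [b, d]]
    a = b = d = 0.0
    for lib, x in design:
        m = lib * math.exp(mu + x * beta)  # fitted mean m_i
        w = m / (1.0 + phi * m)            # Fisher weight W_i
        a += w
        b += w * x
        d += w * x * x
    det = a * d - b * b
    var_beta = a / det                     # [(X'WX)^{-1}]_{beta,beta} of the 2x2 matrix
    wald = beta ** 2 / var_beta
    # Upper tail of chi^2 with 1 df: P(Z^2 > w) = erfc(sqrt(w / 2))
    pvalue = math.erfc(math.sqrt(wald / 2.0))
    return {"se_beta": math.sqrt(var_beta), "wald": wald, "pvalue": pvalue}


result = nb_wald_test(mu=math.log(100), beta=math.log(2), alpha=math.log(0.1),
                      control_lib_sizes=[1.0] * 3, treatment_lib_sizes=[1.0] * 3)
```

For two groups the ββ entry reduces to $1/\sum_{\text{control}} W_i + 1/\sum_{\text{treatment}} W_i$, which is a handy sanity check. If SciPy is available, `scipy.stats.chi2.sf(wald, df=1)` should return the same p-value as the `erfc` expression.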
Key Innovation
Unlike iterative maximum likelihood estimation, NB-Transformer learns the parameter mapping directly from data patterns, enabling:
- Instant inference without convergence issues
- Robust parameter estimation across challenging scenarios
- Full statistical validity through Fisher information framework
📊 Comprehensive Validation Results
Accuracy Across Parameter Types
| Parameter | NB-Transformer MAE | Classical GLM MAE | Improvement (MAE reduction) |
|---|---|---|---|
| μ (base mean) | 0.202 | 0.212 | 5% lower |
| β (log fold change) | 0.152 | 0.284 | 47% lower |
| α (dispersion) | 0.477 | 0.854 | 44% lower |
Statistical Power Analysis
Power analysis across experimental designs shows competitive performance:
| Design | Effect Size β=1.0 | Effect Size β=2.0 |
|---|---|---|
| 3v3 samples | 85% power | 99% power |
| 5v5 samples | 92% power | >99% power |
| 7v7 samples | 96% power | >99% power |
| 9v9 samples | 98% power | >99% power |
P-value Calibration
Rigorous calibration validation confirms proper statistical inference:
- Kolmogorov-Smirnov test: p = 0.127 (well-calibrated)
- Anderson-Darling test: p = 0.089 (well-calibrated)
- False positive rate: 5.1% at a nominal significance level of 0.05 (properly controlled)
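The two calibration checks above can be sketched without dependencies: under the null, p-values should be Uniform(0, 1), so the fraction below the significance level should match that level and the Kolmogorov-Smirnov distance to the uniform CDF should be small. `calibration_summary` is a hypothetical helper; it computes only the KS statistic D, not its p-value, and is not the package's validation script.

```python
def calibration_summary(pvalues, level=0.05):
    """Empirical false positive rate at `level` and KS distance D to Uniform(0, 1)."""
    if not pvalues:
        raise ValueError("need at least one p-value")
    p = sorted(pvalues)
    n = len(p)
    fpr = sum(1 for x in p if x < level) / n
    # One-sample KS statistic against F(x) = x:
    # D = max over order statistics of max(i/n - p_(i), p_(i) - (i-1)/n)
    d = max(max((i + 1) / n - x, x - i / n) for i, x in enumerate(p))
    return fpr, d


# Perfectly spread-out p-values give FPR = 0.05 and D = 0.5 / n
fpr, d = calibration_summary([(i + 0.5) / 100 for i in range(100)])
```

With SciPy installed, `scipy.stats.kstest(pvalues, "uniform")` returns the same D together with a p-value for the uniformity test.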
🏗️ Architecture Details
Model Specifications
- Model Type: Pair-set transformer for NB GLM parameter estimation
- Parameters: 2.5M trainable parameters
- Architecture:
  - Model (embedding) dimension `d_model`: 128
  - Attention heads: 8
  - Self-attention layers: 3
  - Cross-attention layers: 3
  - Dropout: 0.1
- Training: Synthetic data with online generation
- Validation Loss: 0.4628 (v13 checkpoint)
Input/Output Specification
- Input: Two lists of log10(CPM + 1) transformed expression values
- Output: Dictionary with keys 'mu', 'beta', 'alpha' (all on log scale)
- Sample Size: Handles 2-20 samples per condition (variable length)
- Expression Range: Optimized for typical RNA-seq expression levels
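A minimal sketch of preparing the inputs from raw counts, assuming the standard definition CPM = 10^6 × count / library size. The `scale` parameter is exposed because the Complete Statistical Analysis example above scales by 1e4 instead; check which normalization the pre-trained model expects. `log_normalize` is a hypothetical helper, not a package function.

```python
import math


def log_normalize(counts, lib_sizes, scale=1e6):
    """Return log10(scale * count / library_size + 1) for each sample.

    scale=1e6 gives log10(CPM + 1); scale=1e4 matches the per-10,000 example.
    """
    if len(counts) != len(lib_sizes):
        raise ValueError("counts and lib_sizes must have the same length")
    return [math.log10(scale * c / n + 1.0) for c, n in zip(counts, lib_sizes)]


control = log_normalize([1520, 1280, 1650, 1400], [1e6, 1.1e6, 0.9e6, 1.05e6])
treatment = log_normalize([980, 890, 1100, 950], [1e6, 1.0e6, 1.1e6, 0.95e6])
```

The resulting lists can be passed as the two sets to `model.predict_parameters`; they may have different lengths.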
🔧 Advanced Usage
Custom Model Loading
```python
from nb_transformer import load_pretrained_model

# Load model on specific device
model = load_pretrained_model(device='cuda')  # or 'cpu', 'mps'

# Load custom checkpoint
model = load_pretrained_model(checkpoint_path='path/to/custom.ckpt')
```
Batch Processing
```python
# Process multiple gene comparisons efficiently with the vectorized
# method-of-moments estimator (a fast classical baseline, not the transformer)
from nb_transformer.method_of_moments import estimate_batch_parameters_vectorized

control_sets = [[2.1, 1.8, 2.3], [1.9, 2.2, 1.7]]  # Multiple genes
treatment_sets = [[1.5, 1.2, 1.7], [2.1, 2.4, 1.9]]

# Fast batch estimation
results = estimate_batch_parameters_vectorized(control_sets, treatment_sets)
```
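For intuition about what a method-of-moments baseline does, here is a naive NB moment estimator on raw counts. It is a generic textbook sketch, not the package's `estimate_batch_parameters_vectorized` (which takes log-transformed values): it assumes equal library sizes (no normalization), at least two samples per group, and positive group means, and `nb_method_of_moments` is a hypothetical name.

```python
import math
from statistics import mean, variance


def nb_method_of_moments(control_counts, treatment_counts, min_dispersion=1e-8):
    """Naive NB moment estimates of (mu, beta, alpha) on the natural-log scale.

    Assumes equal library sizes and at least two samples per group.
    """
    m_c, m_t = mean(control_counts), mean(treatment_counts)
    if m_c <= 0 or m_t <= 0:
        raise ValueError("group means must be positive")
    # NB variance: var = m + phi * m^2  =>  phi = (var - m) / m^2 for each group
    phis = [max((variance(counts) - m) / m ** 2, min_dispersion)
            for counts, m in ((control_counts, m_c), (treatment_counts, m_t))]
    phi = mean(phis)  # simple average of the two groups' dispersion estimates
    return {"mu": math.log(m_c), "beta": math.log(m_t / m_c), "alpha": math.log(phi)}


estimates = nb_method_of_moments([10, 20, 30], [20, 40, 60])
```

Clipping at `min_dispersion` handles under-dispersed samples where the sample variance falls below the mean; moment estimators like this are fast but noisy at small sample sizes.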
Training Custom Models
```python
from nb_transformer import train_dispersion_transformer, ParameterDistributions

# Define custom parameter distributions
param_dist = ParameterDistributions()
param_dist.mu_params = {'loc': -1.0, 'scale': 2.0}
param_dist.alpha_params = {'mean': -2.0, 'std': 1.0}
param_dist.beta_params = {'prob_de': 0.3, 'std': 1.0}

# Training configuration
config = {
    'model_config': {
        'd_model': 128,
        'n_heads': 8,
        'num_self_layers': 3,
        'num_cross_layers': 3,
        'dropout': 0.1
    },
    'batch_size': 512,
    'max_epochs': 20,
    'examples_per_epoch': 100000,
    'parameter_distributions': param_dist
}

# Train model
results = train_dispersion_transformer(config)
```
📋 Requirements
Core Dependencies
- Python ≥ 3.8
- PyTorch ≥ 1.10.0
- PyTorch Lightning ≥ 1.8.0
- NumPy ≥ 1.21.0
- SciPy ≥ 1.7.0
Optional Dependencies
- Validation: statsmodels, pandas, matplotlib, scikit-learn
- Visualization: plotnine, theme-nxn (custom plotting theme)
- Development: pytest, flake8, black, mypy
🧪 Model Training Details
Training Data
- Synthetic Generation: Online negative binomial data generation
- Parameter Distributions: Based on empirical RNA-seq statistics
- Sample Sizes: Variable 2-10 samples per condition
- Expression Levels: Realistic RNA-seq dynamic range
- Library Sizes: Log-normal distribution (CV ~30%)
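Online generation of this kind can be sketched with the standard library alone. The sketch assumes the NB2 parameterization used in the Fisher weights (variance m + φm²), draws NB counts as a gamma-Poisson mixture, and draws library sizes from a log-normal whose coefficient of variation is 30%. All function names and defaults are hypothetical; the package's actual generator and parameter distributions may differ.

```python
import math
import random


def sample_poisson(lam, rng):
    """Exact Poisson draw: count unit-rate exponential arrivals before time lam."""
    count, t = 0, rng.expovariate(1.0)
    while t < lam:
        count += 1
        t += rng.expovariate(1.0)
    return count


def sample_nb(mean, phi, rng):
    """NB draw with variance mean + phi * mean^2 via a gamma-Poisson mixture."""
    lam = rng.gammavariate(1.0 / phi, mean * phi)  # shape 1/phi, scale mean*phi
    return sample_poisson(lam, rng)


def simulate_gene(mu, beta, alpha, n_control, n_treatment, mean_lib=1.0, cv=0.3, seed=0):
    """Simulate one gene's counts for a two-group design (hypothetical helper)."""
    rng = random.Random(seed)
    sigma = math.sqrt(math.log(1.0 + cv ** 2))        # log-normal sigma for the given CV
    log_mean = math.log(mean_lib) - sigma ** 2 / 2.0  # keeps E[library size] = mean_lib
    phi = math.exp(alpha)

    def draw(x):
        lib = rng.lognormvariate(log_mean, sigma)
        return sample_nb(lib * math.exp(mu + x * beta), phi, rng)

    control = [draw(0) for _ in range(n_control)]
    treatment = [draw(1) for _ in range(n_treatment)]
    return control, treatment


control, treatment = simulate_gene(math.log(50), 1.0, math.log(0.1), 3, 4, seed=42)
```

Because the true μ, β, α are known for every simulated gene, pairs of (counts, parameters) like these can serve as supervised training examples without any real data. The exponential-arrival Poisson sampler is exact but O(λ) per draw, which is fine for a sketch but slow for very large means.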
Training Process
- Epochs: 100
- Batch Size: 32
- Learning Rate: 1e-4 with ReduceLROnPlateau scheduler
- Loss Function: Multi-task MSE loss with parameter-specific weights
- Validation: Hold-out synthetic data with different parameter seeds
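A multi-task MSE loss with parameter-specific weights is a weighted sum of one mean squared error per predicted parameter. The dependency-free sketch below shows the arithmetic; the weight values are hypothetical, since the card does not state them, and in a PyTorch training loop each term would typically be `torch.nn.functional.mse_loss` on the corresponding output column.

```python
def multitask_mse(pred, target, weights):
    """Weighted sum of per-parameter MSEs.

    pred and target map 'mu', 'beta', 'alpha' to equal-length lists of values.
    """
    total = 0.0
    for name, w in weights.items():
        p, t = pred[name], target[name]
        total += w * sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)
    return total


# Hypothetical weights: emphasize the log fold change
weights = {"mu": 1.0, "beta": 2.0, "alpha": 0.5}
loss = multitask_mse(
    pred={"mu": [0.0, 1.0], "beta": [0.5], "alpha": [-1.0]},
    target={"mu": [0.0, 0.0], "beta": [0.0], "alpha": [-2.0]},
    weights=weights,
)
```

Weighting lets training trade off parameters whose errors live on different scales, such as the noisier dispersion versus the fold change.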
Hardware Optimization
- Apple Silicon: Optimized for MPS (Metal Performance Shaders)
- Multi-core CPU: Efficient multi-worker data generation
- Memory Usage: Minimal memory footprint (~100MB model)
- Inference Speed: Single-core CPU sufficient for real-time analysis
🤝 Contributing
We welcome contributions! Please see our contributing guidelines:
- Bug Reports: Open issues with detailed reproduction steps
- Feature Requests: Propose new functionality with use cases
- Code Contributions: Fork, develop, and submit pull requests
- Validation: Run validation scripts to ensure reproducibility
- Documentation: Improve examples and documentation
Development Setup
```bash
git clone https://huggingface.co/valsv/nb-transformer
cd nb-transformer
pip install -e ".[dev,analysis]"

# Run tests
pytest tests/

# Run validation
python examples/validate_accuracy.py --n_tests 100
```
📖 Citation
If you use NB-Transformer in your research, please cite:
```bibtex
@software{svensson2025nbtransformer,
  title={NB-Transformer: Fast Negative Binomial GLM Parameter Estimation using Transformers},
  author={Svensson, Valentine},
  year={2025},
  url={https://huggingface.co/valsv/nb-transformer},
  version={1.0.0}
}
```
📚 Related Work
Transformer Applications in Biology
- Set-based Learning: Zaheer et al. (2017). Deep Sets. NIPS.
- Attention Mechanisms: Vaswani et al. (2017). Attention Is All You Need. NIPS.
- Biological Applications: Rives et al. (2021). Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences. PNAS.
⚖️ License
MIT License - see LICENSE file for details.
🏷️ Version History
v1.0.0 (2025-08-04)
- Initial release with pre-trained v13 model
- Complete validation suite (accuracy, calibration, power)
- Production-ready API with comprehensive documentation
- Hugging Face integration for easy model distribution
🚀 Ready to revolutionize your differential expression analysis? Install NB-Transformer today!
```bash
pip install nb-transformer
```
For questions, issues, or contributions, visit our Hugging Face repository or open an issue.