|
|
|
|
|
""" |
|
|
NB-Transformer Accuracy Validation Script |
|
|
|
|
|
This script compares the accuracy and speed of three methods for negative binomial (NB) GLM parameter estimation:
|
|
1. NB-Transformer: fast neural network approach (~14.8x faster than classical GLM in the benchmark below)
|
|
2. Classical NB GLM: Maximum likelihood estimation via statsmodels |
|
|
3. Method of Moments: Fastest but least accurate approach |
|
|
|
|
|
Usage: |
|
|
python validate_accuracy.py --n_tests 1000 --output_dir results/ |
|
|
|
|
|
Expected Performance (based on v13 model): |
|
|
- NB-Transformer: 100% success, 0.076ms, μ MAE=0.202, β MAE=0.152, α MAE=0.477 |
|
|
- Classical GLM: 98.7% success, 1.128ms, μ MAE=0.212, β MAE=0.284, α MAE=0.854 |
|
|
- Method of Moments: 100% success, 0.021ms, μ MAE=0.213, β MAE=0.289, α MAE=0.852 |
|
|
""" |
|
|
|
|
|
import os
import time
import argparse
import warnings
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
try: |
|
|
from nb_transformer import load_pretrained_model, estimate_batch_parameters_vectorized |
|
|
TRANSFORMER_AVAILABLE = True |
|
|
except ImportError: |
|
|
TRANSFORMER_AVAILABLE = False |
|
|
print("Warning: nb-transformer not available. Install with: pip install nb-transformer") |
|
|
|
|
|
|
|
|
try: |
|
|
import statsmodels.api as sm |
|
|
from statsmodels.discrete.discrete_model import NegativeBinomial |
|
|
STATSMODELS_AVAILABLE = True |
|
|
except ImportError: |
|
|
STATSMODELS_AVAILABLE = False |
|
|
print("Warning: statsmodels not available. Install with: pip install statsmodels") |
|
|
|
|
|
|
|
|
try: |
|
|
from theme_nxn import theme_nxn, get_nxn_palette |
|
|
THEME_AVAILABLE = True |
|
|
except ImportError: |
|
|
THEME_AVAILABLE = False |
|
|
print("Warning: theme_nxn not available, using default matplotlib styling") |
|
|
|
|
|
|
|
|
def generate_test_data(n_tests: int = 1000, seed: int = 42) -> List[Dict]: |
|
|
""" |
|
|
Generate synthetic test cases with known ground truth parameters. |
|
|
|
|
|
Returns: |
|
|
List of test cases with known parameters and generated data |
|
|
""" |
|
|
print(f"Generating {n_tests} synthetic test cases...") |
|
|
|
|
|
np.random.seed(seed) |
|
|
test_cases = [] |
|
|
|
|
|
for i in range(n_tests): |
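        # Sample ground-truth parameters on the log scale:
        # mu_true is log mean expression, alpha_true is log dispersion.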
|
|
|
|
|
mu_true = np.random.normal(-1.0, 2.0) |
|
|
alpha_true = np.random.normal(-2.0, 1.0) |
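
        # About 30% of cases get a nonzero log fold change; the rest are null.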
|
|
|
|
|
|
|
|
if np.random.random() < 0.3: |
|
|
beta_true = np.random.normal(0, 1.0) |
|
|
else: |
|
|
beta_true = 0.0 |
|
|
|
|
|
|
|
|
        n1, n2 = 3, 3  # three replicates per group
|
|
|
|
|
|
|
|
        # Library sizes are lognormal with mean exactly 10,000: E[X] =
        # exp(m + s^2/2) for X ~ LogNormal(m, s^2), so the -0.5*log(1.09)
        # shift cancels the variance term (s^2 = log(1.09), i.e. CV^2 = 0.09).
        lib_sizes_1 = np.random.lognormal(np.log(10000) - 0.5*np.log(1.09),
                                          np.sqrt(np.log(1.09)), n1)
        lib_sizes_2 = np.random.lognormal(np.log(10000) - 0.5*np.log(1.09),
                                          np.sqrt(np.log(1.09)), n2)
|
|
|
|
|
|
|
|
        # Convert the log-scale parameters back to the natural scale.
        mean_expr = np.exp(mu_true)      # mean expression per unit library size
        dispersion = np.exp(alpha_true)  # NB dispersion: Var = m + dispersion*m^2
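
        # numpy parameterizes the negative binomial as (n, p) with
        # n = 1/dispersion and p = n / (n + mean); group 1 uses the baseline mean.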
|
|
|
|
|
|
|
|
counts_1 = [] |
|
|
for lib_size in lib_sizes_1: |
|
|
mean_count = lib_size * mean_expr |
|
|
r = 1.0 / dispersion |
|
|
p = r / (r + mean_count) |
|
|
count = np.random.negative_binomial(r, p) |
|
|
counts_1.append(count) |
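
        # Group 2: same gene, with the mean scaled by exp(beta_true).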
|
|
|
|
|
|
|
|
counts_2 = [] |
|
|
for lib_size in lib_sizes_2: |
|
|
mean_count = lib_size * mean_expr * np.exp(beta_true) |
|
|
r = 1.0 / dispersion |
|
|
p = r / (r + mean_count) |
|
|
count = np.random.negative_binomial(r, p) |
|
|
counts_2.append(count) |
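
        # Scanpy-style log10(CP10K + 1) transform; this is the input that the
        # transformer and the Method of Moments estimator consume.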
|
|
|
|
|
|
|
|
transformed_1 = [np.log10(1e4 * c / l + 1) for c, l in zip(counts_1, lib_sizes_1)] |
|
|
transformed_2 = [np.log10(1e4 * c / l + 1) for c, l in zip(counts_2, lib_sizes_2)] |
|
|
|
|
|
test_cases.append({ |
|
|
'mu_true': mu_true, |
|
|
'beta_true': beta_true, |
|
|
'alpha_true': alpha_true, |
|
|
'counts_1': np.array(counts_1), |
|
|
'counts_2': np.array(counts_2), |
|
|
'lib_sizes_1': np.array(lib_sizes_1), |
|
|
'lib_sizes_2': np.array(lib_sizes_2), |
|
|
'transformed_1': np.array(transformed_1), |
|
|
'transformed_2': np.array(transformed_2) |
|
|
}) |
|
|
|
|
|
return test_cases |
|
|
|
|
|
|
|
|
def fit_transformer(model, test_cases: List[Dict]) -> Tuple[List[Dict], float]: |
|
|
"""Fit NB-Transformer to all test cases.""" |
|
|
print("Fitting NB-Transformer...") |
|
|
|
|
|
results = [] |
|
|
start_time = time.perf_counter() |
|
|
|
|
|
for case in test_cases: |
|
|
try: |
|
|
params = model.predict_parameters(case['transformed_1'], case['transformed_2']) |
|
|
results.append({ |
|
|
'mu_pred': params['mu'], |
|
|
'beta_pred': params['beta'], |
|
|
'alpha_pred': params['alpha'], |
|
|
'success': True |
|
|
}) |
|
|
        except Exception:
|
|
results.append({ |
|
|
'mu_pred': np.nan, |
|
|
'beta_pred': np.nan, |
|
|
'alpha_pred': np.nan, |
|
|
'success': False |
|
|
}) |
|
|
|
|
|
total_time = time.perf_counter() - start_time |
|
|
avg_time_ms = (total_time / len(test_cases)) * 1000 |
|
|
|
|
|
return results, avg_time_ms |
|
|
|
|
|
|
|
|
def fit_statsmodels(test_cases: List[Dict]) -> Tuple[List[Dict], float]: |
|
|
"""Fit classical NB GLM via statsmodels.""" |
|
|
if not STATSMODELS_AVAILABLE: |
|
|
return [], 0.0 |
|
|
|
|
|
print("Fitting classical NB GLM...") |
|
|
|
|
|
results = [] |
|
|
start_time = time.perf_counter() |
|
|
|
|
|
for case in test_cases: |
|
|
try: |
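            # Stack both groups: the design is an intercept plus a group
            # indicator, and library sizes enter as exposures.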
|
|
|
|
|
counts = np.concatenate([case['counts_1'], case['counts_2']]) |
|
|
exposures = np.concatenate([case['lib_sizes_1'], case['lib_sizes_2']]) |
|
|
X = np.concatenate([np.zeros(len(case['counts_1'])), |
|
|
np.ones(len(case['counts_2']))]) |
|
|
X_design = sm.add_constant(X) |
|
|
|
|
|
|
|
|
with warnings.catch_warnings(): |
|
|
warnings.simplefilter("ignore") |
|
|
model = NegativeBinomial(counts, X_design, exposure=exposures) |
|
|
fitted = model.fit(disp=0, maxiter=1000) |
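
            # statsmodels appends the NB dispersion (alpha) as the last entry
            # of params; take its log to match the log-scale ground truth.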
|
|
|
|
|
|
|
|
mu_pred = fitted.params[0] |
|
|
beta_pred = fitted.params[1] |
|
|
alpha_pred = np.log(fitted.params[2]) |
|
|
|
|
|
results.append({ |
|
|
'mu_pred': mu_pred, |
|
|
'beta_pred': beta_pred, |
|
|
'alpha_pred': alpha_pred, |
|
|
'success': True |
|
|
}) |
|
|
|
|
|
        except Exception:
|
|
results.append({ |
|
|
'mu_pred': np.nan, |
|
|
'beta_pred': np.nan, |
|
|
'alpha_pred': np.nan, |
|
|
'success': False |
|
|
}) |
|
|
|
|
|
total_time = time.perf_counter() - start_time |
|
|
avg_time_ms = (total_time / len(test_cases)) * 1000 |
|
|
|
|
|
return results, avg_time_ms |
|
|
|
|
|
|
|
|
def fit_method_of_moments(test_cases: List[Dict]) -> Tuple[List[Dict], float]: |
|
|
"""Fit Method of Moments estimator.""" |
|
|
print("Fitting Method of Moments...") |
|
|
|
|
|
results = [] |
|
|
start_time = time.perf_counter() |
|
|
|
|
|
for case in test_cases: |
|
|
try: |
|
|
params = estimate_batch_parameters_vectorized( |
|
|
[case['transformed_1']], |
|
|
[case['transformed_2']] |
|
|
)[0] |
|
|
|
|
|
results.append({ |
|
|
'mu_pred': params['mu'], |
|
|
'beta_pred': params['beta'], |
|
|
'alpha_pred': params['alpha'], |
|
|
'success': True |
|
|
}) |
|
|
|
|
|
        except Exception:
|
|
results.append({ |
|
|
'mu_pred': np.nan, |
|
|
'beta_pred': np.nan, |
|
|
'alpha_pred': np.nan, |
|
|
'success': False |
|
|
}) |
|
|
|
|
|
total_time = time.perf_counter() - start_time |
|
|
avg_time_ms = (total_time / len(test_cases)) * 1000 |
|
|
|
|
|
return results, avg_time_ms |
|
|
|
|
|
|
|
|
def compute_metrics(results: List[Dict], test_cases: List[Dict]) -> Dict: |
|
|
"""Compute accuracy metrics for a method.""" |
|
|
successes = [r for r in results if r['success']] |
|
|
n_success = len(successes) |
|
|
n_total = len(results) |
|
|
|
|
|
if n_success == 0: |
|
|
return { |
|
|
'success_rate': 0.0, |
|
|
'mu_mae': np.nan, |
|
|
'beta_mae': np.nan, |
|
|
'alpha_mae': np.nan, |
|
|
'mu_rmse': np.nan, |
|
|
'beta_rmse': np.nan, |
|
|
'alpha_rmse': np.nan |
|
|
} |
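
    # Align predictions with their ground-truth values, keeping successful fits only.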
|
|
|
|
|
|
|
|
mu_pred = np.array([r['mu_pred'] for r in successes]) |
|
|
beta_pred = np.array([r['beta_pred'] for r in successes]) |
|
|
alpha_pred = np.array([r['alpha_pred'] for r in successes]) |
|
|
|
|
|
mu_true = np.array([test_cases[i]['mu_true'] for i, r in enumerate(results) if r['success']]) |
|
|
beta_true = np.array([test_cases[i]['beta_true'] for i, r in enumerate(results) if r['success']]) |
|
|
alpha_true = np.array([test_cases[i]['alpha_true'] for i, r in enumerate(results) if r['success']]) |
|
|
|
|
|
return { |
|
|
'success_rate': n_success / n_total, |
|
|
'mu_mae': np.mean(np.abs(mu_pred - mu_true)), |
|
|
'beta_mae': np.mean(np.abs(beta_pred - beta_true)), |
|
|
'alpha_mae': np.mean(np.abs(alpha_pred - alpha_true)), |
|
|
'mu_rmse': np.sqrt(np.mean((mu_pred - mu_true)**2)), |
|
|
'beta_rmse': np.sqrt(np.mean((beta_pred - beta_true)**2)), |
|
|
'alpha_rmse': np.sqrt(np.mean((alpha_pred - alpha_true)**2)) |
|
|
} |
|
|
|
|
|
|
|
|
def create_comparison_plot(transformer_metrics: Dict, |
|
|
statsmodels_metrics: Dict, |
|
|
mom_metrics: Dict, |
|
|
transformer_time: float, |
|
|
statsmodels_time: float, |
|
|
mom_time: float, |
|
|
output_dir: str): |
|
|
"""Create comparison visualization.""" |
|
|
|
|
|
if THEME_AVAILABLE: |
|
|
palette = get_nxn_palette() |
|
|
else: |
|
|
palette = ['#1f77b4', '#ff7f0e', '#2ca02c'] |
|
|
|
|
|
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 10)) |
|
|
|
|
|
methods = ['NB-Transformer', 'Classical GLM', 'Method of Moments'] |
|
|
colors = palette[:3] |
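
    # Panel 1: convergence success rate.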
|
|
|
|
|
|
|
|
success_rates = [ |
|
|
transformer_metrics['success_rate'] * 100, |
|
|
statsmodels_metrics['success_rate'] * 100 if STATSMODELS_AVAILABLE else 0, |
|
|
mom_metrics['success_rate'] * 100 |
|
|
] |
|
|
ax1.bar(methods, success_rates, color=colors, alpha=0.7) |
|
|
ax1.set_ylabel('Success Rate (%)') |
|
|
ax1.set_title('Convergence Success Rate') |
|
|
    ax1.set_ylim(95, 101)  # zoom near 100%; all methods are expected above 95%
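
    # Panel 2: average per-fit wall time (log scale).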
|
|
|
|
|
|
|
|
times = [transformer_time, statsmodels_time if STATSMODELS_AVAILABLE else 0, mom_time] |
|
|
ax2.bar(methods, times, color=colors, alpha=0.7) |
|
|
ax2.set_ylabel('Average Time (ms)') |
|
|
ax2.set_title('Inference Speed') |
|
|
ax2.set_yscale('log') |
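
    # Panel 3: per-parameter mean absolute error.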
|
|
|
|
|
|
|
|
parameters = ['μ', 'β', 'α'] |
|
|
transformer_mae = [transformer_metrics['mu_mae'], transformer_metrics['beta_mae'], transformer_metrics['alpha_mae']] |
|
|
statsmodels_mae = [statsmodels_metrics['mu_mae'], statsmodels_metrics['beta_mae'], statsmodels_metrics['alpha_mae']] if STATSMODELS_AVAILABLE else [0, 0, 0] |
|
|
mom_mae = [mom_metrics['mu_mae'], mom_metrics['beta_mae'], mom_metrics['alpha_mae']] |
|
|
|
|
|
x = np.arange(len(parameters)) |
|
|
width = 0.25 |
|
|
|
|
|
ax3.bar(x - width, transformer_mae, width, label='NB-Transformer', color=colors[0], alpha=0.7) |
|
|
if STATSMODELS_AVAILABLE: |
|
|
ax3.bar(x, statsmodels_mae, width, label='Classical GLM', color=colors[1], alpha=0.7) |
|
|
ax3.bar(x + width, mom_mae, width, label='Method of Moments', color=colors[2], alpha=0.7) |
|
|
|
|
|
ax3.set_ylabel('Mean Absolute Error') |
|
|
ax3.set_title('Parameter Estimation Accuracy') |
|
|
ax3.set_xticks(x) |
|
|
ax3.set_xticklabels(parameters) |
|
|
ax3.legend() |
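
    # Panel 4: compact summary table.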
|
|
|
|
|
|
|
|
ax4.axis('tight') |
|
|
ax4.axis('off') |
|
|
|
|
|
table_data = [ |
|
|
['Method', 'Success %', 'Time (ms)', 'β MAE'], |
|
|
['NB-Transformer', f"{success_rates[0]:.1f}%", f"{transformer_time:.3f}", f"{transformer_metrics['beta_mae']:.3f}"], |
|
|
        ['Classical GLM',
         f"{success_rates[1]:.1f}%" if STATSMODELS_AVAILABLE else "N/A",
         f"{statsmodels_time:.3f}" if STATSMODELS_AVAILABLE else "N/A",
         f"{statsmodels_metrics['beta_mae']:.3f}" if STATSMODELS_AVAILABLE else "N/A"],
|
|
['Method of Moments', f"{success_rates[2]:.1f}%", f"{mom_time:.3f}", f"{mom_metrics['beta_mae']:.3f}"] |
|
|
] |
|
|
|
|
|
table = ax4.table(cellText=table_data, cellLoc='center', loc='center') |
|
|
table.auto_set_font_size(False) |
|
|
table.set_fontsize(10) |
|
|
table.scale(1.2, 1.5) |
|
|
|
|
|
|
|
|
for i in range(4): |
|
|
table[(0, i)].set_facecolor('#40466e') |
|
|
table[(0, i)].set_text_props(weight='bold', color='white') |
|
|
|
|
|
|
|
|
|
|
plt.tight_layout() |
|
|
plt.savefig(os.path.join(output_dir, 'accuracy_comparison.png'), dpi=300, bbox_inches='tight') |
|
|
plt.show() |
|
|
|
|
|
|
|
|
def print_summary(transformer_metrics: Dict, |
|
|
statsmodels_metrics: Dict, |
|
|
mom_metrics: Dict, |
|
|
transformer_time: float, |
|
|
statsmodels_time: float, |
|
|
mom_time: float): |
|
|
"""Print summary of results.""" |
|
|
|
|
|
print("\n" + "="*80) |
|
|
print("NB-TRANSFORMER ACCURACY VALIDATION RESULTS") |
|
|
print("="*80) |
|
|
|
|
|
print(f"\n📊 METHOD COMPARISON") |
|
|
print(f"{'Method':<20} {'Success %':<12} {'Time (ms)':<12} {'μ MAE':<10} {'β MAE':<10} {'α MAE':<10}") |
|
|
print("-" * 80) |
|
|
|
|
|
print(f"{'NB-Transformer':<20} {transformer_metrics['success_rate']*100:>8.1f}% {transformer_time:>8.3f} {transformer_metrics['mu_mae']:>6.3f} {transformer_metrics['beta_mae']:>6.3f} {transformer_metrics['alpha_mae']:>6.3f}") |
|
|
|
|
|
if STATSMODELS_AVAILABLE: |
|
|
print(f"{'Classical GLM':<20} {statsmodels_metrics['success_rate']*100:>8.1f}% {statsmodels_time:>8.3f} {statsmodels_metrics['mu_mae']:>6.3f} {statsmodels_metrics['beta_mae']:>6.3f} {statsmodels_metrics['alpha_mae']:>6.3f}") |
|
|
|
|
|
print(f"{'Method of Moments':<20} {mom_metrics['success_rate']*100:>8.1f}% {mom_time:>8.3f} {mom_metrics['mu_mae']:>6.3f} {mom_metrics['beta_mae']:>6.3f} {mom_metrics['alpha_mae']:>6.3f}") |
|
|
|
|
|
if STATSMODELS_AVAILABLE and statsmodels_time > 0: |
|
|
speedup = statsmodels_time / transformer_time |
|
|
accuracy_improvement = (statsmodels_metrics['beta_mae'] - transformer_metrics['beta_mae']) / statsmodels_metrics['beta_mae'] * 100 |
|
|
|
|
|
print(f"\n🚀 KEY ACHIEVEMENTS:") |
|
|
print(f" • {speedup:.1f}x faster than classical GLM") |
|
|
print(f" • {accuracy_improvement:.0f}% better accuracy on β (log fold change)") |
|
|
print(f" • {transformer_metrics['success_rate']*100:.1f}% success rate vs {statsmodels_metrics['success_rate']*100:.1f}% for classical GLM") |
|
|
|
|
|
print(f"\n✅ VALIDATION COMPLETE: NB-Transformer maintains superior speed and accuracy") |
|
|
|
|
|
|
|
|
def main(): |
|
|
parser = argparse.ArgumentParser(description='Validate NB-Transformer accuracy') |
|
|
parser.add_argument('--n_tests', type=int, default=1000, help='Number of test cases') |
|
|
parser.add_argument('--output_dir', type=str, default='validation_results', help='Output directory') |
|
|
parser.add_argument('--seed', type=int, default=42, help='Random seed') |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
os.makedirs(args.output_dir, exist_ok=True) |
|
|
|
|
|
|
|
|
if not TRANSFORMER_AVAILABLE: |
|
|
print("❌ nb-transformer not available. Please install: pip install nb-transformer") |
|
|
return |
|
|
|
|
|
|
|
|
print("Loading pre-trained NB-Transformer...") |
|
|
model = load_pretrained_model() |
|
|
|
|
|
|
|
|
test_cases = generate_test_data(args.n_tests, args.seed) |
|
|
|
|
|
|
|
|
transformer_results, transformer_time = fit_transformer(model, test_cases) |
|
|
statsmodels_results, statsmodels_time = fit_statsmodels(test_cases) |
|
|
mom_results, mom_time = fit_method_of_moments(test_cases) |
|
|
|
|
|
|
|
|
transformer_metrics = compute_metrics(transformer_results, test_cases) |
|
|
statsmodels_metrics = compute_metrics(statsmodels_results, test_cases) |
|
|
mom_metrics = compute_metrics(mom_results, test_cases) |
|
|
|
|
|
|
|
|
create_comparison_plot( |
|
|
transformer_metrics, statsmodels_metrics, mom_metrics, |
|
|
transformer_time, statsmodels_time, mom_time, |
|
|
args.output_dir |
|
|
) |
|
|
|
|
|
|
|
|
print_summary( |
|
|
transformer_metrics, statsmodels_metrics, mom_metrics, |
|
|
transformer_time, statsmodels_time, mom_time |
|
|
) |
|
|
|
|
|
|
|
|
results_df = pd.DataFrame({ |
|
|
'method': ['NB-Transformer', 'Classical GLM', 'Method of Moments'], |
|
|
'success_rate': [transformer_metrics['success_rate'], |
|
|
statsmodels_metrics['success_rate'] if STATSMODELS_AVAILABLE else np.nan, |
|
|
mom_metrics['success_rate']], |
|
|
'avg_time_ms': [transformer_time, |
|
|
statsmodels_time if STATSMODELS_AVAILABLE else np.nan, |
|
|
mom_time], |
|
|
'mu_mae': [transformer_metrics['mu_mae'], |
|
|
statsmodels_metrics['mu_mae'] if STATSMODELS_AVAILABLE else np.nan, |
|
|
mom_metrics['mu_mae']], |
|
|
'beta_mae': [transformer_metrics['beta_mae'], |
|
|
statsmodels_metrics['beta_mae'] if STATSMODELS_AVAILABLE else np.nan, |
|
|
mom_metrics['beta_mae']], |
|
|
'alpha_mae': [transformer_metrics['alpha_mae'], |
|
|
statsmodels_metrics['alpha_mae'] if STATSMODELS_AVAILABLE else np.nan, |
|
|
mom_metrics['alpha_mae']] |
|
|
}) |
|
|
|
|
|
results_df.to_csv(os.path.join(args.output_dir, 'accuracy_results.csv'), index=False) |
|
|
print(f"\n💾 Results saved to {args.output_dir}/") |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
main() |