import torch
import numpy as np
from torch.utils.data import IterableDataset
from typing import List, Tuple, Optional, Dict
from .utils import pad_sequences, create_padding_mask
class CollateWrapper:
"""Wrapper class for collate function to avoid pickling issues with multiprocessing."""
def __init__(self, padding_value):
self.padding_value = padding_value
def __call__(self, batch):
return collate_nb_glm_batch(batch, padding_value=self.padding_value)
def collate_nb_glm_batch(batch: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
padding_value: float = -1e9) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Collate function for variable-length NB GLM sequences.
Args:
batch: List of (set_1, set_2, targets) tuples
padding_value: Value to use for padding
Returns:
Tuple of (set_1_batch, set_2_batch, set_1_mask, set_2_mask, targets_batch)
"""
set_1_list, set_2_list, targets_list = zip(*batch)
# Pad sequences to same length within batch
set_1_padded = pad_sequences(list(set_1_list), padding_value=padding_value)
set_2_padded = pad_sequences(list(set_2_list), padding_value=padding_value)
# Create padding masks
set_1_mask = create_padding_mask(list(set_1_list))
set_2_mask = create_padding_mask(list(set_2_list))
# Stack targets
targets_batch = torch.stack(targets_list)
return set_1_padded, set_2_padded, set_1_mask, set_2_mask, targets_batch
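# Illustrative usage sketch (an assumption added for exposition, not part of the original
# module); exact mask semantics come from utils.create_padding_mask:
#   batch = [(torch.randn(3, 1), torch.randn(4, 1), torch.zeros(3)),
#            (torch.randn(5, 1), torch.randn(2, 1), torch.zeros(3))]
#   set_1, set_2, mask_1, mask_2, targets = collate_nb_glm_batch(batch)
#   # set_1: (2, 5, 1) and set_2: (2, 4, 1), padded with padding_value;
#   # mask_1/mask_2 flag the padded positions; targets: (2, 3) for (mu, beta, alpha).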
class SyntheticNBGLMDataset(IterableDataset):
"""
Online synthetic data generator for Negative Binomial GLM parameter estimation.
Generates training examples on-the-fly with known ground truth parameters:
- mu: Base mean parameter (log scale)
- beta: Log fold change between conditions
- alpha: Dispersion parameter (log scale)
Each example consists of two sets of samples drawn from:
- Condition 1: x ~ NB(l * exp(mu), exp(alpha))
- Condition 2: x ~ NB(l * exp(mu + beta), exp(alpha))
where l is the sample's library size and NB(m, a) has mean m and variance m + a * m^2.
Counts are transformed to: y = log10(1e4 * x / l + 1)
"""
TARGET_COLUMNS = ['mu', 'beta', 'alpha']
def __init__(self,
num_examples_per_epoch: int = 100000,
min_samples_per_condition: int = 2,
max_samples_per_condition: int = 10,
mu_loc: float = -1.0,
mu_scale: float = 2.0,
alpha_mean: float = -2.0,
alpha_std: float = 1.0,
beta_prob_de: float = 0.3,
beta_std: float = 1.0,
library_size_mean: float = 10000,
library_size_cv: float = 0.3,
seed: Optional[int] = None):
"""
Initialize synthetic NB GLM dataset.
Args:
num_examples_per_epoch: Number of examples to generate per epoch
min_samples_per_condition: Minimum samples per condition
max_samples_per_condition: Maximum samples per condition
mu_loc: Mean of the normal distribution for mu (log-scale mean, so exp(mu) is log-normal)
mu_scale: Std of the normal distribution for mu
alpha_mean: Mean of alpha normal distribution
alpha_std: Std of alpha normal distribution
beta_prob_de: Probability of differential expression (non-zero beta)
beta_std: Standard deviation of beta when DE
library_size_mean: Mean library size
library_size_cv: Coefficient of variation for library size
seed: Random seed for reproducibility
"""
self.num_examples_per_epoch = num_examples_per_epoch
self.min_samples = min_samples_per_condition
self.max_samples = max_samples_per_condition
# Parameter distribution parameters
self.mu_loc = mu_loc
self.mu_scale = mu_scale
self.alpha_mean = alpha_mean
self.alpha_std = alpha_std
self.beta_prob_de = beta_prob_de
self.beta_std = beta_std
# Library size parameters
self.library_size_mean = library_size_mean
self.library_size_cv = library_size_cv
self.library_size_std = library_size_mean * library_size_cv
# Target normalization parameters for unit-normal targets
self.target_stats = {
'mu': {'mean': mu_loc, 'std': mu_scale},
'alpha': {'mean': alpha_mean, 'std': alpha_std},
# Beta mixture: mean=0, std=sqrt(prob_de * std^2)
'beta': {'mean': 0.0, 'std': (beta_prob_de * beta_std**2)**0.5}
}
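# Example with the defaults (beta_prob_de=0.3, beta_std=1.0): the beta target std is
# sqrt(0.3 * 1.0**2) ≈ 0.548, so a DE gene with beta = 1.0 becomes a normalized target of ≈ 1.83.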
# Random number generator
self.seed = seed
self.rng = np.random.RandomState(seed)
def __len__(self):
"""Return the number of examples per epoch for progress tracking."""
return self.num_examples_per_epoch
def __iter__(self):
"""Infinite iterator that generates examples on-the-fly."""
worker_info = torch.utils.data.get_worker_info()
# Handle multi-worker data loading
if worker_info is None:
# Single-process data loading
examples_per_worker = self.num_examples_per_epoch
worker_id = 0
else:
# Multi-process data loading: split this epoch's examples evenly across workers (any remainder is dropped)
examples_per_worker = self.num_examples_per_epoch // worker_info.num_workers
worker_id = worker_info.id
# Set different seed for each worker
if self.seed is not None:
self.rng = np.random.RandomState(self.seed + worker_id)
# Generate examples
for _ in range(examples_per_worker):
yield self._generate_example()
def _generate_example(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate a single training example."""
# Sample parameters
mu = self._sample_mu()
alpha = self._sample_alpha(mu)
beta = self._sample_beta()
# Sample experimental design
n1 = self.rng.randint(self.min_samples, self.max_samples + 1)
n2 = self.rng.randint(self.min_samples, self.max_samples + 1)
# Generate counts for condition 1
set_1 = self._generate_set(mu, alpha, n1)
# Generate counts for condition 2 (with beta offset)
set_2 = self._generate_set(mu + beta, alpha, n2)
# Create normalized target tensor for better regression performance
targets_raw = {'mu': mu, 'beta': beta, 'alpha': alpha}
targets_normalized = self._normalize_targets(targets_raw)
targets = torch.tensor([targets_normalized['mu'], targets_normalized['beta'], targets_normalized['alpha']], dtype=torch.float32)
return set_1, set_2, targets
def _normalize_targets(self, targets: Dict[str, float]) -> Dict[str, float]:
"""Normalize targets to unit normal for better regression performance."""
normalized = {}
for param in ['mu', 'beta', 'alpha']:
mean = self.target_stats[param]['mean']
std = self.target_stats[param]['std']
# Avoid division by zero
std = max(std, 1e-8)
normalized[param] = (targets[param] - mean) / std
return normalized
def denormalize_targets(self, normalized_targets: Dict[str, float]) -> Dict[str, float]:
"""Denormalize targets back to original scale."""
denormalized = {}
for param in ['mu', 'beta', 'alpha']:
mean = self.target_stats[param]['mean']
std = self.target_stats[param]['std']
denormalized[param] = normalized_targets[param] * std + mean
return denormalized
def _sample_mu(self) -> float:
"""Sample base mean parameter from log-normal distribution."""
return self.rng.normal(self.mu_loc, self.mu_scale)
def _sample_alpha(self, mu: float) -> float:
"""
Sample dispersion parameter.
For now, alpha is drawn from a normal distribution independent of mu
(the mu argument is accepted but currently unused).
In the future, this could model the mean-dispersion relationship.
"""
# Simple independent sampling for now
return self.rng.normal(self.alpha_mean, self.alpha_std)
def _sample_beta(self) -> float:
"""Sample log fold change with mixture distribution."""
if self.rng.random() < self.beta_prob_de:
# Differential expression - sample from normal
return self.rng.normal(0, self.beta_std)
else:
# No differential expression
return 0.0
def _sample_library_sizes(self, n_samples: int) -> np.ndarray:
"""Sample library sizes from log-normal distribution."""
# Use log-normal to ensure positive values with realistic variation
log_mean = np.log(self.library_size_mean) - 0.5 * np.log(1 + self.library_size_cv**2)
log_std = np.sqrt(np.log(1 + self.library_size_cv**2))
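# (Standard log-normal moment matching: with log_std**2 = log(1 + cv**2) and
#  log_mean = log(mean) - log_std**2 / 2, the draws have mean library_size_mean
#  and coefficient of variation library_size_cv.)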
return self.rng.lognormal(log_mean, log_std, size=n_samples)
def _generate_set(self, mu: float, alpha: float, n_samples: int) -> torch.Tensor:
"""
Generate a set of transformed counts from NB distribution.
Args:
mu: Log mean parameter
alpha: Log dispersion parameter
n_samples: Number of samples to generate
Returns:
Tensor of shape (n_samples, 1) with transformed counts
"""
# Sample library sizes
library_sizes = self._sample_library_sizes(n_samples)
# Convert parameters from log scale
mean_expr = np.exp(mu)
dispersion = np.exp(alpha)
# Generate counts from NB distribution
counts = []
for lib_size in library_sizes:
# Mean count for this sample
mean_count = lib_size * mean_expr
# NumPy's negative_binomial(r, p) has mean = r * (1 - p) / p and
# variance = mean + mean^2 / r, where r is the size (inverse-dispersion) parameter.
# Setting r = 1 / dispersion and p = r / (r + mean_count) therefore gives
# mean = mean_count and variance = mean_count + dispersion * mean_count^2.
r = 1.0 / dispersion
p = r / (r + mean_count)
# Sample from negative binomial
count = self.rng.negative_binomial(r, p)
counts.append(count)
counts = np.array(counts)
# Transform counts: y = log10(1e4 * x / l + 1)
transformed = np.log10(1e4 * counts / library_sizes + 1)
# Convert to tensor with shape (n_samples, 1)
return torch.tensor(transformed, dtype=torch.float32).unsqueeze(-1)
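# Illustrative moment check for the parameterization used in _generate_set above. This helper
# is a sketch added for exposition (not part of the original API) and is never called by the
# dataset; it confirms empirically that r = 1/dispersion and p = r / (r + mean) reproduce
# mean = l * exp(mu) and variance = mean + exp(alpha) * mean^2.
def _nb_parameterization_check(mu: float = -1.0, alpha: float = -2.0,
                               lib_size: float = 1e4, n: int = 200_000,
                               seed: int = 0) -> Dict[str, float]:
    rng = np.random.RandomState(seed)
    mean_count = lib_size * np.exp(mu)
    dispersion = np.exp(alpha)
    r = 1.0 / dispersion
    p = r / (r + mean_count)
    counts = rng.negative_binomial(r, p, size=n)
    return {
        'target_mean': mean_count,
        'empirical_mean': float(counts.mean()),
        'target_var': mean_count + dispersion * mean_count ** 2,
        'empirical_var': float(counts.var()),
    }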
class ParameterDistributions:
"""
Container for parameter distributions learned from empirical data.
This class loads and stores the distributions needed for realistic
synthetic data generation.
"""
def __init__(self, empirical_stats_file: Optional[str] = None):
"""
Initialize parameter distributions.
Args:
empirical_stats_file: Path to empirical statistics file
If None, uses default distributions
"""
if empirical_stats_file is not None:
self._load_empirical_distributions(empirical_stats_file)
else:
self._set_default_distributions()
def _load_empirical_distributions(self, filepath: str):
"""Load parameter distributions from empirical data analysis."""
# This would load pre-computed distribution parameters
# from the analysis script (to be implemented)
raise NotImplementedError("Empirical distribution loading not yet implemented")
def _set_default_distributions(self):
"""Set reasonable default distributions for synthetic data."""
# Default mu distribution (normal on the log scale, i.e. exp(mu) is log-normal)
self.mu_params = {
'loc': -1.0, # Moderate expression
'scale': 2.0 # Wide range of expression levels
}
# Default alpha distribution
self.alpha_params = {
'mean': -2.0, # Moderate dispersion
'std': 1.0 # Some variation
}
# Default beta distribution
self.beta_params = {
'prob_de': 0.3, # 30% of genes are DE
'std': 1.0 # Moderate fold changes
}
# Default library size distribution
self.library_params = {
'mean': 10000, # 10K reads per sample
'cv': 0.3 # 30% coefficient of variation
}
# Target normalization parameters (computed from distributions above)
self.target_stats = {
'mu': {'mean': self.mu_params['loc'], 'std': self.mu_params['scale']},
'alpha': {'mean': self.alpha_params['mean'], 'std': self.alpha_params['std']},
# Beta is mixture: E[β] = prob_de * 0 + (1-prob_de) * 0 = 0
# Var[β] = prob_de * std^2 + (1-prob_de) * 0 = prob_de * std^2
'beta': {'mean': 0.0, 'std': (self.beta_params['prob_de'] * self.beta_params['std']**2)**0.5}
}
def create_dataloaders(batch_size: int = 32,
num_workers: int = 4,
num_examples_per_epoch: int = 100000,
parameter_distributions: Optional[ParameterDistributions] = None,
padding_value: float = -1e9,
seed: Optional[int] = None,
persistent_workers: bool = False) -> torch.utils.data.DataLoader:
"""
Create a DataLoader for synthetic NB GLM training.
Args:
batch_size: Batch size for training
num_workers: Number of worker processes for data generation
num_examples_per_epoch: Examples to generate per epoch
parameter_distributions: Parameter distributions for generation
padding_value: Padding value for variable-length sequences
seed: Random seed for reproducibility
persistent_workers: Whether to keep workers alive between epochs
Returns:
DataLoader for training
"""
# Use default distributions if none provided
if parameter_distributions is None:
parameter_distributions = ParameterDistributions()
# Create dataset with distribution parameters
dataset = SyntheticNBGLMDataset(
num_examples_per_epoch=num_examples_per_epoch,
mu_loc=parameter_distributions.mu_params['loc'],
mu_scale=parameter_distributions.mu_params['scale'],
alpha_mean=parameter_distributions.alpha_params['mean'],
alpha_std=parameter_distributions.alpha_params['std'],
beta_prob_de=parameter_distributions.beta_params['prob_de'],
beta_std=parameter_distributions.beta_params['std'],
library_size_mean=parameter_distributions.library_params['mean'],
library_size_cv=parameter_distributions.library_params['cv'],
seed=seed
)
# Create collate function instance
collate_fn = CollateWrapper(padding_value)
# Create dataloader; persistent workers (when enabled and num_workers > 0) stay alive
# between epochs, avoiding repeated worker startup and file descriptor leaks
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=True,
persistent_workers=persistent_workers and num_workers > 0
)
return dataloader
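# Minimal smoke-test sketch (an assumption added for illustration, not part of the original
# file). Run as a module, e.g. `python -m <package>.<this_module>`, so the relative import of
# .utils resolves; num_workers=0 keeps generation in the main process.
if __name__ == "__main__":
    loader = create_dataloaders(batch_size=4, num_workers=0,
                                num_examples_per_epoch=64, seed=0)
    set_1, set_2, mask_1, mask_2, targets = next(iter(loader))
    print("set_1:", tuple(set_1.shape), "set_2:", tuple(set_2.shape))
    print("mask_1:", tuple(mask_1.shape), "mask_2:", tuple(mask_2.shape))
    print("targets:", tuple(targets.shape))  # expected: (4, 3) for (mu, beta, alpha)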