import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, IterableDataset
from typing import List, Tuple, Optional, Dict, Union
from scipy import stats

from .utils import pad_sequences, create_padding_mask


class CollateWrapper:
    """Wrapper class for collate function to avoid pickling issues with multiprocessing."""

    def __init__(self, padding_value):
        self.padding_value = padding_value

    def __call__(self, batch):
        return collate_nb_glm_batch(batch, padding_value=self.padding_value)


def collate_nb_glm_batch(batch: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
                         padding_value: float = -1e9
                         ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                                    torch.Tensor, torch.Tensor]:
    """
    Collate function for variable-length NB GLM sequences.

    Args:
        batch: List of (set_1, set_2, targets) tuples
        padding_value: Value to use for padding

    Returns:
        Tuple of (set_1_batch, set_2_batch, set_1_mask, set_2_mask, targets_batch)
    """
    set_1_list, set_2_list, targets_list = zip(*batch)

    # Pad sequences to same length within batch
    set_1_padded = pad_sequences(list(set_1_list), padding_value=padding_value)
    set_2_padded = pad_sequences(list(set_2_list), padding_value=padding_value)

    # Create padding masks
    set_1_mask = create_padding_mask(list(set_1_list))
    set_2_mask = create_padding_mask(list(set_2_list))

    # Stack targets
    targets_batch = torch.stack(targets_list)

    return set_1_padded, set_2_padded, set_1_mask, set_2_mask, targets_batch


class SyntheticNBGLMDataset(IterableDataset):
    """
    Online synthetic data generator for Negative Binomial GLM parameter estimation.

    Generates training examples on-the-fly with known ground truth parameters:
        - mu: Base mean parameter (log scale)
        - beta: Log fold change between conditions
        - alpha: Dispersion parameter (log scale)

    Each example consists of two sets of samples drawn from:
        - Condition 1: x ~ NB(l * exp(mu), exp(alpha))
        - Condition 2: x ~ NB(l * exp(mu + beta), exp(alpha))

    Counts are transformed to: y = log10(1e4 * x / l + 1)
    """

    TARGET_COLUMNS = ['mu', 'beta', 'alpha']

    def __init__(self,
                 num_examples_per_epoch: int = 100000,
                 min_samples_per_condition: int = 2,
                 max_samples_per_condition: int = 10,
                 mu_loc: float = -1.0,
                 mu_scale: float = 2.0,
                 alpha_mean: float = -2.0,
                 alpha_std: float = 1.0,
                 beta_prob_de: float = 0.3,
                 beta_std: float = 1.0,
                 library_size_mean: float = 10000,
                 library_size_cv: float = 0.3,
                 seed: Optional[int] = None):
        """
        Initialize synthetic NB GLM dataset.

        Args:
            num_examples_per_epoch: Number of examples to generate per epoch
            min_samples_per_condition: Minimum samples per condition
            max_samples_per_condition: Maximum samples per condition
            mu_loc: Mean of the normal distribution for mu (location of the
                log-normal distribution of exp(mu))
            mu_scale: Std of the normal distribution for mu (scale of the
                log-normal distribution of exp(mu))
            alpha_mean: Mean of alpha normal distribution
            alpha_std: Std of alpha normal distribution
            beta_prob_de: Probability of differential expression (non-zero beta)
            beta_std: Standard deviation of beta when DE
            library_size_mean: Mean library size
            library_size_cv: Coefficient of variation for library size
            seed: Random seed for reproducibility
        """
        self.num_examples_per_epoch = num_examples_per_epoch
        self.min_samples = min_samples_per_condition
        self.max_samples = max_samples_per_condition

        # Parameter distribution parameters
        self.mu_loc = mu_loc
        self.mu_scale = mu_scale
        self.alpha_mean = alpha_mean
        self.alpha_std = alpha_std
        self.beta_prob_de = beta_prob_de
        self.beta_std = beta_std

        # Library size parameters
        self.library_size_mean = library_size_mean
        self.library_size_cv = library_size_cv
        self.library_size_std = library_size_mean * library_size_cv

        # Target normalization parameters for unit-normal targets
        self.target_stats = {
            'mu': {'mean': mu_loc, 'std': mu_scale},
            'alpha': {'mean': alpha_mean, 'std': alpha_std},
            # Beta mixture: mean = 0, std = sqrt(prob_de * std^2)
            'beta': {'mean': 0.0, 'std': (beta_prob_de * beta_std**2)**0.5}
        }

        # Random number generator
        self.seed = seed
        self.rng = np.random.RandomState(seed)

    def __len__(self):
        """Return the number of examples per epoch for progress tracking."""
        return self.num_examples_per_epoch

    def __iter__(self):
        """Iterate over one epoch of examples generated on-the-fly."""
        worker_info = torch.utils.data.get_worker_info()

        # Handle multi-worker data loading
        if worker_info is None:
            # Single-process data loading
            examples_per_worker = self.num_examples_per_epoch
            worker_id = 0
        else:
            # Multi-process data loading: split the epoch across workers
            examples_per_worker = self.num_examples_per_epoch // worker_info.num_workers
            worker_id = worker_info.id

        # Give each worker its own RNG seed so parallel workers do not
        # produce identical example streams
        if self.seed is not None:
            self.rng = np.random.RandomState(self.seed + worker_id)
        elif worker_info is not None:
            self.rng = np.random.RandomState(worker_info.seed % (2**32))

        # Generate examples
        for _ in range(examples_per_worker):
            yield self._generate_example()

    def _generate_example(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a single training example."""
        # Sample parameters
        mu = self._sample_mu()
        alpha = self._sample_alpha(mu)
        beta = self._sample_beta()

        # Sample experimental design
        n1 = self.rng.randint(self.min_samples, self.max_samples + 1)
        n2 = self.rng.randint(self.min_samples, self.max_samples + 1)

        # Generate counts for condition 1
        set_1 = self._generate_set(mu, alpha, n1)

        # Generate counts for condition 2 (with beta offset)
        set_2 = self._generate_set(mu + beta, alpha, n2)

        # Create normalized target tensor for better regression performance
        targets_raw = {'mu': mu, 'beta': beta, 'alpha': alpha}
        targets_normalized = self._normalize_targets(targets_raw)
        targets = torch.tensor([targets_normalized['mu'],
                                targets_normalized['beta'],
                                targets_normalized['alpha']],
                               dtype=torch.float32)

        return set_1, set_2, targets

    def _normalize_targets(self, targets: Dict[str, float]) -> Dict[str, float]:
        """Normalize targets to unit normal for better regression performance."""
        normalized = {}
        for param in ['mu', 'beta', 'alpha']:
            mean = self.target_stats[param]['mean']
            std = self.target_stats[param]['std']
            # Avoid division by zero
            std = max(std, 1e-8)
            normalized[param] = (targets[param] - mean) / std
        return normalized

    def denormalize_targets(self, normalized_targets: Dict[str, float]) -> Dict[str, float]:
        """Denormalize targets back to the original scale."""
        denormalized = {}
        for param in ['mu', 'beta', 'alpha']:
            mean = self.target_stats[param]['mean']
            std = self.target_stats[param]['std']
            denormalized[param] = normalized_targets[param] * std + mean
        return denormalized

    def _sample_mu(self) -> float:
        """Sample the base mean parameter (log scale) from a normal distribution."""
        return self.rng.normal(self.mu_loc, self.mu_scale)

    def _sample_alpha(self, mu: float) -> float:
        """
        Sample the dispersion parameter (log scale).

        For now we sample independently of mu; in the future this could
        model the mean-dispersion relationship.
        """
        return self.rng.normal(self.alpha_mean, self.alpha_std)

    def _sample_beta(self) -> float:
        """Sample the log fold change from a mixture distribution."""
        if self.rng.random() < self.beta_prob_de:
            # Differential expression: sample from a normal distribution
            return self.rng.normal(0, self.beta_std)
        else:
            # No differential expression
            return 0.0

    def _sample_library_sizes(self, n_samples: int) -> np.ndarray:
        """Sample library sizes from a log-normal distribution."""
        # Use log-normal to ensure positive values with realistic variation;
        # parameterize so the mean and CV match library_size_mean and library_size_cv
        log_mean = np.log(self.library_size_mean) - 0.5 * np.log(1 + self.library_size_cv**2)
        log_std = np.sqrt(np.log(1 + self.library_size_cv**2))
        return self.rng.lognormal(log_mean, log_std, size=n_samples)

    def _generate_set(self, mu: float, alpha: float, n_samples: int) -> torch.Tensor:
        """
        Generate a set of transformed counts from the NB distribution.

        Args:
            mu: Log mean parameter
            alpha: Log dispersion parameter
            n_samples: Number of samples to generate

        Returns:
            Tensor of shape (n_samples, 1) with transformed counts
        """
        # Sample library sizes
        library_sizes = self._sample_library_sizes(n_samples)

        # Convert parameters from log scale
        mean_expr = np.exp(mu)
        dispersion = np.exp(alpha)

        # Generate counts from the NB distribution
        counts = []
        for lib_size in library_sizes:
            # Mean count for this sample
            mean_count = lib_size * mean_expr

            # NumPy's NB parameterization: mean = r * (1 - p) / p,
            # variance = mean + mean^2 / r, where r is the number of successes.
            # With variance = mean + dispersion * mean^2 this gives
            # r = 1 / dispersion and p = r / (r + mean_count).
            r = 1.0 / dispersion
            p = r / (r + mean_count)

            # Sample from the negative binomial
            count = self.rng.negative_binomial(r, p)
            counts.append(count)

        counts = np.array(counts)

        # Transform counts: y = log10(1e4 * x / l + 1)
        transformed = np.log10(1e4 * counts / library_sizes + 1)

        # Convert to tensor with shape (n_samples, 1)
        return torch.tensor(transformed, dtype=torch.float32).unsqueeze(-1)


class ParameterDistributions:
    """
    Container for parameter distributions learned from empirical data.

    This class loads and stores the distributions needed for realistic
    synthetic data generation.
    """

    def __init__(self, empirical_stats_file: Optional[str] = None):
        """
        Initialize parameter distributions.

        Args:
            empirical_stats_file: Path to empirical statistics file.
                If None, uses default distributions.
        """
        if empirical_stats_file is not None:
            self._load_empirical_distributions(empirical_stats_file)
        else:
            self._set_default_distributions()

    def _load_empirical_distributions(self, filepath: str):
        """Load parameter distributions from empirical data analysis."""
        # This would load pre-computed distribution parameters
        # from the analysis script (to be implemented)
        raise NotImplementedError("Empirical distribution loading not yet implemented")

    def _set_default_distributions(self):
        """Set reasonable default distributions for synthetic data."""
        # Default mu distribution (log-normal)
        self.mu_params = {
            'loc': -1.0,   # Moderate expression
            'scale': 2.0   # Wide range of expression levels
        }

        # Default alpha distribution
        self.alpha_params = {
            'mean': -2.0,  # Moderate dispersion
            'std': 1.0     # Some variation
        }

        # Default beta distribution
        self.beta_params = {
            'prob_de': 0.3,  # 30% of genes are DE
            'std': 1.0       # Moderate fold changes
        }

        # Default library size distribution
        self.library_params = {
            'mean': 10000,  # 10K reads per sample
            'cv': 0.3       # 30% coefficient of variation
        }

        # Target normalization parameters (computed from the distributions above)
        self.target_stats = {
            'mu': {'mean': self.mu_params['loc'], 'std': self.mu_params['scale']},
            'alpha': {'mean': self.alpha_params['mean'], 'std': self.alpha_params['std']},
            # Beta is a mixture: E[beta] = prob_de * 0 + (1 - prob_de) * 0 = 0
            # Var[beta] = prob_de * std^2 + (1 - prob_de) * 0 = prob_de * std^2
            'beta': {'mean': 0.0,
                     'std': (self.beta_params['prob_de'] * self.beta_params['std']**2)**0.5}
        }


def create_dataloaders(batch_size: int = 32,
                       num_workers: int = 4,
                       num_examples_per_epoch: int = 100000,
                       parameter_distributions: Optional[ParameterDistributions] = None,
                       padding_value: float = -1e9,
                       seed: Optional[int] = None,
                       persistent_workers: bool = False) -> torch.utils.data.DataLoader:
    """
    Create a dataloader for synthetic NB GLM training.

    Args:
        batch_size: Batch size for training
        num_workers: Number of worker processes for data generation
        num_examples_per_epoch: Examples to generate per epoch
        parameter_distributions: Parameter distributions for generation
        padding_value: Padding value for variable-length sequences
        seed: Random seed for reproducibility
        persistent_workers: Whether to keep workers alive between epochs

    Returns:
        DataLoader for training
    """
    # Use default distributions if none provided
    if parameter_distributions is None:
        parameter_distributions = ParameterDistributions()

    # Create dataset with distribution parameters
    dataset = SyntheticNBGLMDataset(
        num_examples_per_epoch=num_examples_per_epoch,
        mu_loc=parameter_distributions.mu_params['loc'],
        mu_scale=parameter_distributions.mu_params['scale'],
        alpha_mean=parameter_distributions.alpha_params['mean'],
        alpha_std=parameter_distributions.alpha_params['std'],
        beta_prob_de=parameter_distributions.beta_params['prob_de'],
        beta_std=parameter_distributions.beta_params['std'],
        library_size_mean=parameter_distributions.library_params['mean'],
        library_size_cv=parameter_distributions.library_params['cv'],
        seed=seed
    )

    # Create collate function instance
    collate_fn = CollateWrapper(padding_value)

    # Create dataloader with persistent workers to avoid file descriptor leaks
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        persistent_workers=persistent_workers and num_workers > 0
    )

    return dataloader
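

# A minimal usage sketch of the pipeline above, intended as a smoke test rather
# than a definitive entry point. It assumes `pad_sequences` and
# `create_padding_mask` are importable from `.utils` (as at the top of this
# module) and that the module is run from within its package
# (e.g. `python -m <package>.<module>`). The expected batch shapes below assume
# that `pad_sequences` pads along the sample dimension.
if __name__ == "__main__":
    # num_workers=0 keeps the sketch single-process and easy to debug
    loader = create_dataloaders(batch_size=4, num_workers=0,
                                num_examples_per_epoch=16, seed=0)

    set_1, set_2, mask_1, mask_2, targets = next(iter(loader))
    print("set_1:", set_1.shape)      # expected (4, max_n1, 1)
    print("set_2:", set_2.shape)      # expected (4, max_n2, 1)
    print("targets:", targets.shape)  # (4, 3): unit-normalized (mu, beta, alpha)

    # Map one normalized target vector back to the original parameter scale;
    # the default dataset shares its target_stats with the default
    # ParameterDistributions used inside create_dataloaders.
    dataset = SyntheticNBGLMDataset(seed=0)
    raw = dataset.denormalize_targets({'mu': targets[0, 0].item(),
                                       'beta': targets[0, 1].item(),
                                       'alpha': targets[0, 2].item()})
    print("denormalized targets for first example:", raw)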