import os
from typing import List, Tuple, Optional, Dict, Union

import numpy as np
import pandas as pd
import torch
from scipy import stats
from torch.utils.data import Dataset, IterableDataset

from .utils import pad_sequences, create_padding_mask


class CollateWrapper:
    """Picklable wrapper around the collate function, avoiding pickling issues with multiprocessing workers."""

    def __init__(self, padding_value):
        self.padding_value = padding_value

    def __call__(self, batch):
        return collate_nb_glm_batch(batch, padding_value=self.padding_value)
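
# Note: DataLoader pickles collate_fn when num_workers > 0 under the "spawn" start
# method, so a module-level callable like CollateWrapper works where an inline
# lambda (e.g. ``lambda b: collate_nb_glm_batch(b, padding_value=-1e9)``) may fail.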


def collate_nb_glm_batch(batch: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
                         padding_value: float = -1e9) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Collate function for variable-length NB GLM sequences.

    Args:
        batch: List of (set_1, set_2, targets) tuples
        padding_value: Value to use for padding

    Returns:
        Tuple of (set_1_batch, set_2_batch, set_1_mask, set_2_mask, targets_batch)
    """
    set_1_list, set_2_list, targets_list = zip(*batch)

    # Pad each group of variable-length sets to the longest set in the batch.
    set_1_padded = pad_sequences(list(set_1_list), padding_value=padding_value)
    set_2_padded = pad_sequences(list(set_2_list), padding_value=padding_value)

    # Build masks that mark which positions are real samples and which are padding.
    set_1_mask = create_padding_mask(list(set_1_list))
    set_2_mask = create_padding_mask(list(set_2_list))

    targets_batch = torch.stack(targets_list)

    return set_1_padded, set_2_padded, set_1_mask, set_2_mask, targets_batch
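
# Usage sketch (illustrative, not part of the pipeline). Shapes assume that
# pad_sequences/create_padding_mask pad along the sample dimension, as used above.
#
#     example_batch = [
#         (torch.randn(3, 1), torch.randn(5, 1), torch.zeros(3)),
#         (torch.randn(4, 1), torch.randn(2, 1), torch.zeros(3)),
#     ]
#     x1, x2, m1, m2, y = collate_nb_glm_batch(example_batch)
#     # x1: (2, 4, 1) padded with -1e9, x2: (2, 5, 1) padded with -1e9, y: (2, 3)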


class SyntheticNBGLMDataset(IterableDataset):
    """
    Online synthetic data generator for Negative Binomial GLM parameter estimation.

    Generates training examples on-the-fly with known ground-truth parameters:
    - mu: Base mean parameter (log scale)
    - beta: Log fold change between conditions
    - alpha: Dispersion parameter (log scale)

    Each example consists of two sets of samples drawn from:
    - Condition 1: x ~ NB(l * exp(mu), exp(alpha))
    - Condition 2: x ~ NB(l * exp(mu + beta), exp(alpha))

    Counts are transformed to: y = log10(1e4 * x / l + 1)
    """

    TARGET_COLUMNS = ['mu', 'beta', 'alpha']

    def __init__(self,
                 num_examples_per_epoch: int = 100000,
                 min_samples_per_condition: int = 2,
                 max_samples_per_condition: int = 10,
                 mu_loc: float = -1.0,
                 mu_scale: float = 2.0,
                 alpha_mean: float = -2.0,
                 alpha_std: float = 1.0,
                 beta_prob_de: float = 0.3,
                 beta_std: float = 1.0,
                 library_size_mean: float = 10000,
                 library_size_cv: float = 0.3,
                 seed: Optional[int] = None):
        """
        Initialize synthetic NB GLM dataset.

        Args:
            num_examples_per_epoch: Number of examples to generate per epoch
            min_samples_per_condition: Minimum samples per condition
            max_samples_per_condition: Maximum samples per condition
            mu_loc: Location of the normal distribution for mu (log scale), i.e. of the log-normal for exp(mu)
            mu_scale: Scale of the normal distribution for mu (log scale)
            alpha_mean: Mean of alpha normal distribution
            alpha_std: Std of alpha normal distribution
            beta_prob_de: Probability of differential expression (non-zero beta)
            beta_std: Standard deviation of beta when DE
            library_size_mean: Mean library size
            library_size_cv: Coefficient of variation for library size
            seed: Random seed for reproducibility
        """
        self.num_examples_per_epoch = num_examples_per_epoch
        self.min_samples = min_samples_per_condition
        self.max_samples = max_samples_per_condition

        # Parameter-generating distributions.
        self.mu_loc = mu_loc
        self.mu_scale = mu_scale
        self.alpha_mean = alpha_mean
        self.alpha_std = alpha_std
        self.beta_prob_de = beta_prob_de
        self.beta_std = beta_std

        # Library-size distribution.
        self.library_size_mean = library_size_mean
        self.library_size_cv = library_size_cv
        self.library_size_std = library_size_mean * library_size_cv

        # Statistics used to normalize regression targets to roughly unit scale.
        # beta is a spike-and-slab mixture (zero with probability 1 - beta_prob_de,
        # N(0, beta_std^2) otherwise), so its marginal std is sqrt(beta_prob_de) * beta_std.
        self.target_stats = {
            'mu': {'mean': mu_loc, 'std': mu_scale},
            'alpha': {'mean': alpha_mean, 'std': alpha_std},
            'beta': {'mean': 0.0, 'std': (beta_prob_de * beta_std**2)**0.5}
        }

        self.seed = seed
        self.rng = np.random.RandomState(seed)

    def __len__(self):
        """Return the number of examples per epoch for progress tracking."""
        return self.num_examples_per_epoch

    def __iter__(self):
        """Yield freshly generated examples, splitting the epoch across DataLoader workers."""
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is None:
            # Single-process loading: this iterator produces the whole epoch.
            examples_per_worker = self.num_examples_per_epoch
            worker_id = 0
        else:
            # Multi-process loading: divide the epoch evenly across workers.
            examples_per_worker = self.num_examples_per_epoch // worker_info.num_workers
            worker_id = worker_info.id

        if self.seed is not None:
            self.rng = np.random.RandomState(self.seed + worker_id)
        elif worker_info is not None:
            # Without an explicit seed, every worker would inherit a copy of the same
            # RandomState; reseed from the per-worker torch seed so workers do not
            # generate identical examples.
            self.rng = np.random.RandomState(worker_info.seed % (2**32))

        for _ in range(examples_per_worker):
            yield self._generate_example()

    def _generate_example(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate a single training example."""
        # Draw ground-truth parameters.
        mu = self._sample_mu()
        alpha = self._sample_alpha(mu)
        beta = self._sample_beta()

        # Draw the number of samples in each condition.
        n1 = self.rng.randint(self.min_samples, self.max_samples + 1)
        n2 = self.rng.randint(self.min_samples, self.max_samples + 1)

        # Condition 1 uses mu; condition 2 shifts the log mean by beta.
        set_1 = self._generate_set(mu, alpha, n1)
        set_2 = self._generate_set(mu + beta, alpha, n2)

        targets_raw = {'mu': mu, 'beta': beta, 'alpha': alpha}
        targets_normalized = self._normalize_targets(targets_raw)
        targets = torch.tensor(
            [targets_normalized['mu'], targets_normalized['beta'], targets_normalized['alpha']],
            dtype=torch.float32
        )

        return set_1, set_2, targets

    def _normalize_targets(self, targets: Dict[str, float]) -> Dict[str, float]:
        """Normalize targets to approximately unit scale for better regression performance."""
        normalized = {}
        for param in ['mu', 'beta', 'alpha']:
            mean = self.target_stats[param]['mean']
            std = self.target_stats[param]['std']
            # Guard against a degenerate (zero) std.
            std = max(std, 1e-8)
            normalized[param] = (targets[param] - mean) / std
        return normalized

    def denormalize_targets(self, normalized_targets: Dict[str, float]) -> Dict[str, float]:
        """Denormalize targets back to the original scale."""
        denormalized = {}
        for param in ['mu', 'beta', 'alpha']:
            mean = self.target_stats[param]['mean']
            std = self.target_stats[param]['std']
            denormalized[param] = normalized_targets[param] * std + mean
        return denormalized
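
    # Worked example (sketch): with the default stats, a raw mu of 1.0 normalizes to
    # (1.0 - (-1.0)) / 2.0 = 1.0, and denormalize_targets inverts it: 1.0 * 2.0 + (-1.0) = 1.0.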

    def _sample_mu(self) -> float:
        """Sample the base mean parameter on the log scale (so exp(mu) is log-normally distributed)."""
        return self.rng.normal(self.mu_loc, self.mu_scale)

    def _sample_alpha(self, mu: float) -> float:
        """
        Sample the dispersion parameter (log scale).

        For now, we use a simple normal distribution independent of mu.
        In the future, this could model the mean-dispersion relationship.
        """
        return self.rng.normal(self.alpha_mean, self.alpha_std)

    def _sample_beta(self) -> float:
        """Sample the log fold change from a spike-and-slab mixture distribution."""
        if self.rng.random() < self.beta_prob_de:
            # Differentially expressed: draw a non-zero log fold change.
            return self.rng.normal(0, self.beta_std)
        else:
            # Not differentially expressed: the log fold change is exactly zero.
            return 0.0

    def _sample_library_sizes(self, n_samples: int) -> np.ndarray:
        """Sample library sizes from a log-normal distribution with the requested mean and CV."""
        # Moment-match the log-normal: for X ~ LogNormal(m, s), E[X] = exp(m + s^2 / 2)
        # and CV^2 = exp(s^2) - 1, which yields the m and s used below.
        log_mean = np.log(self.library_size_mean) - 0.5 * np.log(1 + self.library_size_cv**2)
        log_std = np.sqrt(np.log(1 + self.library_size_cv**2))

        return self.rng.lognormal(log_mean, log_std, size=n_samples)

    def _generate_set(self, mu: float, alpha: float, n_samples: int) -> torch.Tensor:
        """
        Generate a set of transformed counts from the NB distribution.

        Args:
            mu: Log mean parameter
            alpha: Log dispersion parameter
            n_samples: Number of samples to generate

        Returns:
            Tensor of shape (n_samples, 1) with transformed counts
        """
        library_sizes = self._sample_library_sizes(n_samples)

        mean_expr = np.exp(mu)
        dispersion = np.exp(alpha)

        counts = []
        for lib_size in library_sizes:
            # Expected count scales with this sample's library size.
            mean_count = lib_size * mean_expr

            # Convert the (mean, dispersion) parameterization to numpy's (r, p):
            # with r = 1 / dispersion and p = r / (r + mean_count), the NB has
            # E[x] = mean_count and Var[x] = mean_count + dispersion * mean_count**2.
            r = 1.0 / dispersion
            p = r / (r + mean_count)

            count = self.rng.negative_binomial(r, p)
            counts.append(count)

        counts = np.array(counts)

        # Library-size-normalized log transform: y = log10(1e4 * x / l + 1).
        transformed = np.log10(1e4 * counts / library_sizes + 1)

        return torch.tensor(transformed, dtype=torch.float32).unsqueeze(-1)
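
# Usage sketch (illustrative only): drawing a single example from the generator.
#
#     ds = SyntheticNBGLMDataset(num_examples_per_epoch=10, seed=0)
#     set_1, set_2, targets = next(iter(ds))
#     # set_1, set_2: (n_i, 1) transformed counts with n_i in [2, 10];
#     # targets: normalized (mu, beta, alpha); ds.denormalize_targets recovers the raw scale.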


class ParameterDistributions:
    """
    Container for parameter distributions learned from empirical data.

    This class loads and stores the distributions needed for realistic
    synthetic data generation.
    """

    def __init__(self, empirical_stats_file: Optional[str] = None):
        """
        Initialize parameter distributions.

        Args:
            empirical_stats_file: Path to empirical statistics file.
                If None, uses default distributions.
        """
        if empirical_stats_file is not None:
            self._load_empirical_distributions(empirical_stats_file)
        else:
            self._set_default_distributions()

    def _load_empirical_distributions(self, filepath: str):
        """Load parameter distributions from empirical data analysis."""
        raise NotImplementedError("Empirical distribution loading not yet implemented")

    def _set_default_distributions(self):
        """Set reasonable default distributions for synthetic data."""
        # Base mean (log scale): mu ~ Normal(loc, scale).
        self.mu_params = {
            'loc': -1.0,
            'scale': 2.0
        }

        # Dispersion (log scale): alpha ~ Normal(mean, std).
        self.alpha_params = {
            'mean': -2.0,
            'std': 1.0
        }

        # Log fold change: zero with probability 1 - prob_de, Normal(0, std) otherwise.
        self.beta_params = {
            'prob_de': 0.3,
            'std': 1.0
        }

        # Library size: log-normal with this mean and coefficient of variation.
        self.library_params = {
            'mean': 10000,
            'cv': 0.3
        }

        # Normalization statistics mirroring SyntheticNBGLMDataset.target_stats.
        self.target_stats = {
            'mu': {'mean': self.mu_params['loc'], 'std': self.mu_params['scale']},
            'alpha': {'mean': self.alpha_params['mean'], 'std': self.alpha_params['std']},
            'beta': {'mean': 0.0, 'std': (self.beta_params['prob_de'] * self.beta_params['std']**2)**0.5}
        }


def create_dataloaders(batch_size: int = 32,
                       num_workers: int = 4,
                       num_examples_per_epoch: int = 100000,
                       parameter_distributions: Optional[ParameterDistributions] = None,
                       padding_value: float = -1e9,
                       seed: Optional[int] = None,
                       persistent_workers: bool = False) -> torch.utils.data.DataLoader:
    """
    Create a dataloader for synthetic NB GLM training.

    Args:
        batch_size: Batch size for training
        num_workers: Number of worker processes for data generation
        num_examples_per_epoch: Examples to generate per epoch
        parameter_distributions: Parameter distributions for generation
        padding_value: Padding value for variable-length sequences
        seed: Random seed for reproducibility
        persistent_workers: Whether to keep workers alive between epochs

    Returns:
        DataLoader for training
    """
    if parameter_distributions is None:
        parameter_distributions = ParameterDistributions()

    dataset = SyntheticNBGLMDataset(
        num_examples_per_epoch=num_examples_per_epoch,
        mu_loc=parameter_distributions.mu_params['loc'],
        mu_scale=parameter_distributions.mu_params['scale'],
        alpha_mean=parameter_distributions.alpha_params['mean'],
        alpha_std=parameter_distributions.alpha_params['std'],
        beta_prob_de=parameter_distributions.beta_params['prob_de'],
        beta_std=parameter_distributions.beta_params['std'],
        library_size_mean=parameter_distributions.library_params['mean'],
        library_size_cv=parameter_distributions.library_params['cv'],
        seed=seed
    )

    # Use a picklable callable (not a lambda) so worker processes can receive the collate_fn.
    collate_fn = CollateWrapper(padding_value)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        persistent_workers=persistent_workers and num_workers > 0
    )

    return dataloader
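

# Minimal smoke test (a sketch, not part of the training pipeline). It assumes this
# module sits in a package providing the ``.utils`` helpers imported above; run it as
# ``python -m <package>.<this_module>`` (placeholder names).
if __name__ == "__main__":
    loader = create_dataloaders(batch_size=4, num_workers=0,
                                num_examples_per_epoch=8, seed=0)
    x1, x2, m1, m2, y = next(iter(loader))
    # Padded condition sets, their masks, and normalized (mu, beta, alpha) targets.
    print(x1.shape, x2.shape, m1.shape, m2.shape, y.shape)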