import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import masked_mean_pooling


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear transformations and reshape
        Q = self.w_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            # Expand mask for multi-head attention: (B, seq_len) -> (B, 1, 1, seq_len)
            # This broadcasts to (B, n_heads, seq_len, seq_len) for attention scores
            mask = mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(mask == 0, -1e4)

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        attended = torch.matmul(attention_weights, V)

        # Concatenate heads and put through final linear layer
        attended = attended.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        return self.w_o(attended)


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attn_output = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)
        return x


class CrossAttentionBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key_value, mask=None):
        # Cross-attention with residual connection
        attn_output = self.cross_attention(query, key_value, key_value, mask)
        x = self.norm1(query + self.dropout(attn_output))

        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)
        return x


class PairSetTransformer(nn.Module):
    """
    Base Pair-Set Transformer that processes two variable-length sets using
    intra-set and cross-set attention mechanisms.

    This is a general architecture that can be subclassed for specific tasks.
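
    A hypothetical example for scalar-valued set elements (shapes follow the
    forward() contract below; an untrained model returns arbitrary values):

        model = PairSetTransformer(dim_input=1, d_model=128, n_heads=8)
        out = model(x, y, x_mask, y_mask)  # x: (B, n1, 1), y: (B, n2, 1) -> out: (B,)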
""" def __init__(self, dim_input, d_model=128, n_heads=8, num_self_layers=3, num_cross_layers=3, dropout=0.1, num_outputs=1): super().__init__() self.dim_input = dim_input self.d_model = d_model self.n_heads = n_heads self.num_self_layers = num_self_layers self.num_cross_layers = num_cross_layers self.num_outputs = num_outputs # Embedding layers self.embed_x = nn.Linear(dim_input, d_model) self.embed_y = nn.Linear(dim_input, d_model) # Intra-set self-attention layers self.self_layers_x = nn.ModuleList([ TransformerBlock(d_model, n_heads, dropout) for _ in range(num_self_layers) ]) self.self_layers_y = nn.ModuleList([ TransformerBlock(d_model, n_heads, dropout) for _ in range(num_self_layers) ]) # Cross-set attention layers self.cross_layers_x = nn.ModuleList([ CrossAttentionBlock(d_model, n_heads, dropout) for _ in range(num_cross_layers) ]) self.cross_layers_y = nn.ModuleList([ CrossAttentionBlock(d_model, n_heads, dropout) for _ in range(num_cross_layers) ]) # Combined feature size after concatenation: [φ(X), φ(Y), φ(X)−φ(Y), φ(X)⊙φ(Y)] combined_dim = 4 * d_model # Output head - can be overridden by subclasses self.head = self._create_output_head(combined_dim, dropout) self.dropout = nn.Dropout(dropout) def _create_output_head(self, input_dim, dropout): """ Create output head. Can be overridden by subclasses for task-specific heads. Args: input_dim: Dimension of combined features dropout: Dropout rate Returns: Output head module """ return nn.Sequential( nn.Linear(input_dim, 2 * self.d_model), nn.GELU(), nn.Dropout(dropout), nn.Linear(2 * self.d_model, self.d_model), nn.GELU(), nn.Dropout(dropout), nn.Linear(self.d_model, self.num_outputs) ) def forward(self, x, y, x_mask=None, y_mask=None): # x: (B, n1, dim_input) # y: (B, n2, dim_input) # x_mask: (B, n1) boolean mask for x (True = real data, False = padding) # y_mask: (B, n2) boolean mask for y (True = real data, False = padding) # Embedding x_emb = self.dropout(self.embed_x(x)) # (B, n1, d_model) y_emb = self.dropout(self.embed_y(y)) # (B, n2, d_model) # Create attention masks (invert for attention - True = attend, False = ignore) x_attn_mask = x_mask if x_mask is not None else None y_attn_mask = y_mask if y_mask is not None else None # Intra-set self-attention for layer in self.self_layers_x: x_emb = layer(x_emb, x_attn_mask) for layer in self.self_layers_y: y_emb = layer(y_emb, y_attn_mask) # Cross-set attention for cross_x, cross_y in zip(self.cross_layers_x, self.cross_layers_y): x_cross = cross_x(x_emb, y_emb, y_attn_mask) # X attending to Y y_cross = cross_y(y_emb, x_emb, x_attn_mask) # Y attending to X x_emb = x_cross y_emb = y_cross # Masked mean pooling over sets if x_mask is not None: phi_x = masked_mean_pooling(x_emb, x_mask, dim=1) # (B, d_model) else: phi_x = x_emb.mean(dim=1) # (B, d_model) if y_mask is not None: phi_y = masked_mean_pooling(y_emb, y_mask, dim=1) # (B, d_model) else: phi_y = y_emb.mean(dim=1) # (B, d_model) # Combine features: [φ(X), φ(Y), φ(X)−φ(Y), φ(X)⊙φ(Y)] diff = phi_x - phi_y prod = phi_x * phi_y combined = torch.cat([phi_x, phi_y, diff, prod], dim=1) # (B, 4*d_model) # Final regression output output = self.head(combined) # (B, num_outputs) # Return appropriate shape based on number of outputs if self.num_outputs == 1: return output.squeeze(-1) # (B,) for single output else: return output # (B, num_outputs) for multiple outputs def predict(self, set_x, set_y, padding_value=-1e9): """ Simple prediction interface for two sets (e.g., Python lists). 
        Args:
            set_x: First set as Python list or 1D array-like
            set_y: Second set as Python list or 1D array-like
            padding_value: Value to use for padding (default: -1e9)

        Returns:
            Model predictions as tensor
        """
        from .utils import pad_sequences, create_padding_mask

        # Get the device the model is on
        device = next(self.parameters()).device

        # Convert inputs to tensors if needed and move to model's device
        if not isinstance(set_x, torch.Tensor):
            set_x = torch.tensor(set_x, dtype=torch.float32, device=device)
        else:
            set_x = set_x.to(device)

        if not isinstance(set_y, torch.Tensor):
            set_y = torch.tensor(set_y, dtype=torch.float32, device=device)
        else:
            set_y = set_y.to(device)

        # Ensure proper shape: (n,) -> (n, 1)
        if set_x.dim() == 1:
            set_x = set_x.unsqueeze(-1)
        if set_y.dim() == 1:
            set_y = set_y.unsqueeze(-1)

        # Create batch of size 1
        x_batch = [set_x]
        y_batch = [set_y]

        # Pad sequences and create masks
        x_padded = pad_sequences(x_batch, padding_value=padding_value)
        y_padded = pad_sequences(y_batch, padding_value=padding_value)
        x_mask = create_padding_mask(x_batch)
        y_mask = create_padding_mask(y_batch)

        # Set model to evaluation mode
        self.eval()

        # Make prediction
        with torch.no_grad():
            prediction = self.forward(x_padded, y_padded, x_mask, y_mask)

        return prediction

    def save_model(self, filepath):
        """
        Save the trained model to a file.

        Args:
            filepath: Path to save the model
        """
        torch.save({
            'model_state_dict': self.state_dict(),
            'model_config': {
                'dim_input': self.dim_input,
                'd_model': self.d_model,
                'n_heads': self.n_heads,
                'num_self_layers': self.num_self_layers,
                'num_cross_layers': self.num_cross_layers,
                'num_outputs': self.num_outputs
            }
        }, filepath)

    @classmethod
    def load_model(cls, filepath):
        """
        Load a trained model from a file.

        Args:
            filepath: Path to the saved model

        Returns:
            Loaded PairSetTransformer model
        """
        checkpoint = torch.load(filepath, map_location='cpu', weights_only=False)

        # Create model with saved configuration
        model = cls(**checkpoint['model_config'])

        # Load trained weights
        model.load_state_dict(checkpoint['model_state_dict'])

        return model


class DispersionTransformer(PairSetTransformer):
    """
    Negative Binomial GLM parameter estimation transformer.

    This transformer estimates three parameters from two sets of
    log-transformed counts:
    - mu: Base mean parameter (log scale)
    - beta: Log fold change between conditions
    - alpha: Dispersion parameter (log scale)

    The model assumes:
    - Condition 1: x ~ NB(l * exp(mu), exp(alpha))
    - Condition 2: x ~ NB(l * exp(mu + beta), exp(alpha))

    Inputs are log-transformed scaled counts: y = log10(1e4 * x / l + 1)
    """

    TARGET_COLUMNS = ['mu', 'beta', 'alpha']

    def __init__(self, dim_input=1, d_model=128, n_heads=8, num_self_layers=3,
                 num_cross_layers=3, dropout=0.1, target_stats=None):
        """
        Initialize Dispersion transformer with 3 outputs.
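
        The three outputs are produced in the order given by TARGET_COLUMNS
        (mu, beta, alpha). A target_stats dictionary, when provided, is
        expected to mirror the layout of the defaults below, e.g.
        (illustrative values):

            {'mu': {'mean': -1.0, 'std': 2.0},
             'beta': {'mean': 0.0, 'std': 0.55},
             'alpha': {'mean': -2.0, 'std': 1.0}}
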
        Args:
            dim_input: Input dimension (default: 1 for scalar values)
            d_model: Model dimension
            n_heads: Number of attention heads
            num_self_layers: Number of self-attention layers
            num_cross_layers: Number of cross-attention layers
            dropout: Dropout rate
            target_stats: Dictionary with normalization stats for denormalization
        """
        super().__init__(
            dim_input=dim_input,
            d_model=d_model,
            n_heads=n_heads,
            num_self_layers=num_self_layers,
            num_cross_layers=num_cross_layers,
            dropout=dropout,
            num_outputs=3  # Three parameters: mu, beta, alpha
        )

        # Store normalization parameters for denormalization
        if target_stats is None:
            # Default normalization parameters
            self.target_stats = {
                'mu': {'mean': -1.0, 'std': 2.0},
                'alpha': {'mean': -2.0, 'std': 1.0},
                'beta': {'mean': 0.0, 'std': (0.3 * 1.0**2)**0.5}
            }
        else:
            self.target_stats = target_stats

        # Register target_stats as buffers so they are saved with the model state
        for param_name in ['mu', 'beta', 'alpha']:
            mean_tensor = torch.tensor(self.target_stats[param_name]['mean'], dtype=torch.float32)
            std_tensor = torch.tensor(self.target_stats[param_name]['std'], dtype=torch.float32)
            self.register_buffer(f'{param_name}_mean', mean_tensor)
            self.register_buffer(f'{param_name}_std', std_tensor)

    def _create_output_head(self, input_dim, dropout):
        """
        Create output head for NB GLM parameters.

        Uses shared layers for feature processing with separate final projections
        for each parameter to allow parameter-specific specialization.
        """
        # Shared feature processing
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, 2 * self.d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(2 * self.d_model, self.d_model),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Parameter-specific heads (just final projection)
        self.mu_head = nn.Linear(self.d_model, 1)     # Base mean
        self.beta_head = nn.Linear(self.d_model, 1)   # Log fold change
        self.alpha_head = nn.Linear(self.d_model, 1)  # Dispersion

        # Return a module that combines all components
        return nn.ModuleDict({
            'shared': self.shared_layers,
            'mu': self.mu_head,
            'beta': self.beta_head,
            'alpha': self.alpha_head
        })

    def forward(self, x, y, x_mask=None, y_mask=None):
        """
        Forward pass through Dispersion transformer.
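
        Note that the returned values are on the normalized training scale;
        use predict_parameters() or predict_batch_parameters() to obtain
        denormalized estimates via the stored target statistics.
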
        Args:
            x: First set tensor (B, n1, dim_input) - condition 1 samples
            y: Second set tensor (B, n2, dim_input) - condition 2 samples
            x_mask: Mask for first set (B, n1)
            y_mask: Mask for second set (B, n2)

        Returns:
            Tensor of shape (B, 3) with NB GLM parameters in order: [mu, beta, alpha]
        """
        # Embedding
        x_emb = self.dropout(self.embed_x(x))  # (B, n1, d_model)
        y_emb = self.dropout(self.embed_y(y))  # (B, n2, d_model)

        # Create attention masks
        x_attn_mask = x_mask if x_mask is not None else None
        y_attn_mask = y_mask if y_mask is not None else None

        # Intra-set self-attention
        for layer in self.self_layers_x:
            x_emb = layer(x_emb, x_attn_mask)
        for layer in self.self_layers_y:
            y_emb = layer(y_emb, y_attn_mask)

        # Cross-set attention
        for cross_x, cross_y in zip(self.cross_layers_x, self.cross_layers_y):
            x_cross = cross_x(x_emb, y_emb, y_attn_mask)  # X attending to Y
            y_cross = cross_y(y_emb, x_emb, x_attn_mask)  # Y attending to X
            x_emb = x_cross
            y_emb = y_cross

        # Masked mean pooling over sets
        if x_mask is not None:
            phi_x = masked_mean_pooling(x_emb, x_mask, dim=1)  # (B, d_model)
        else:
            phi_x = x_emb.mean(dim=1)  # (B, d_model)

        if y_mask is not None:
            phi_y = masked_mean_pooling(y_emb, y_mask, dim=1)  # (B, d_model)
        else:
            phi_y = y_emb.mean(dim=1)  # (B, d_model)

        # Combine features: [φ(X), φ(Y), φ(X)−φ(Y), φ(X)⊙φ(Y)]
        diff = phi_x - phi_y
        prod = phi_x * phi_y
        combined = torch.cat([phi_x, phi_y, diff, prod], dim=1)  # (B, 4*d_model)

        # Process through shared layers
        shared_features = self.head['shared'](combined)  # (B, d_model)

        # Generate outputs from parameter-specific heads
        mu_output = self.head['mu'](shared_features)        # (B, 1)
        beta_output = self.head['beta'](shared_features)    # (B, 1)
        alpha_output = self.head['alpha'](shared_features)  # (B, 1)

        # Combine outputs in the expected order
        outputs = torch.cat([mu_output, beta_output, alpha_output], dim=1)  # (B, 3)

        return outputs

    def predict_parameters(self, set_1, set_2, padding_value=-1e9):
        """
        Predict NB GLM parameters for two sets.

        Args:
            set_1: First set (condition 1 samples)
            set_2: Second set (condition 2 samples)
            padding_value: Padding value for variable length sequences

        Returns:
            Dictionary with estimated parameters: mu, beta, alpha (denormalized)
        """
        predictions = self.predict(set_1, set_2, padding_value)

        if predictions.dim() == 1:
            predictions = predictions.unsqueeze(0)  # Add batch dimension if needed

        # Get normalized predictions
        normalized_result = {}
        for i, col in enumerate(self.TARGET_COLUMNS):
            normalized_result[col] = predictions[0, i].item()

        # Denormalize to original scale
        result = self._denormalize_targets(normalized_result)

        return result

    def predict_batch_parameters(self, set_1_list, set_2_list, padding_value=-1e9):
        """
        Predict NB GLM parameters for multiple pairs in a single vectorized call.
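
        All pairs are padded to a common length and processed in one forward
        pass, which is typically much faster than calling predict_parameters()
        in a loop. A hypothetical call (an untrained model returns arbitrary
        values):

            results = model.predict_batch_parameters(
                [[1.2, 0.5], [0.9, 1.1, 0.7]],
                [[2.0, 1.8], [1.5, 1.6]],
            )
            # results[0] -> {'mu': ..., 'beta': ..., 'alpha': ...}
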
        Args:
            set_1_list: List of first sets (condition 1 samples)
            set_2_list: List of second sets (condition 2 samples)
            padding_value: Padding value for variable length sequences

        Returns:
            List of dictionaries with estimated parameters: mu, beta, alpha (denormalized)
        """
        from .utils import pad_sequences, create_padding_mask

        # Convert lists to tensors and pad
        set_1_tensors = []
        set_2_tensors = []

        for set_1, set_2 in zip(set_1_list, set_2_list):
            # Convert to tensors if needed
            if not isinstance(set_1, torch.Tensor):
                set_1 = torch.tensor(set_1, dtype=torch.float32).unsqueeze(-1)
            if not isinstance(set_2, torch.Tensor):
                set_2 = torch.tensor(set_2, dtype=torch.float32).unsqueeze(-1)
            set_1_tensors.append(set_1)
            set_2_tensors.append(set_2)

        # Pad sequences to same length within batch
        set_1_padded = pad_sequences(set_1_tensors, padding_value=padding_value)
        set_2_padded = pad_sequences(set_2_tensors, padding_value=padding_value)

        # Create padding masks
        set_1_mask = create_padding_mask(set_1_tensors)
        set_2_mask = create_padding_mask(set_2_tensors)

        # Single forward pass for entire batch
        self.eval()
        with torch.no_grad():
            predictions = self(set_1_padded, set_2_padded, set_1_mask, set_2_mask)

        # Convert to list of results
        results = []
        for i in range(predictions.shape[0]):
            # Get normalized predictions
            normalized_result = {}
            for j, col in enumerate(self.TARGET_COLUMNS):
                normalized_result[col] = predictions[i, j].item()

            # Denormalize to original scale
            result = self._denormalize_targets(normalized_result)
            results.append(result)

        return results

    def _denormalize_targets(self, normalized_targets):
        """Denormalize targets back to original scale using saved buffers."""
        denormalized = {}
        for param in self.TARGET_COLUMNS:
            # Use registered buffers for denormalization (automatically saved/loaded)
            mean = getattr(self, f'{param}_mean').item()
            std = getattr(self, f'{param}_std').item()
            denormalized[param] = normalized_targets[param] * std + mean
        return denormalized

    @staticmethod
    def load_from_checkpoint(checkpoint_path):
        """
        Load DispersionTransformer from a PyTorch Lightning checkpoint.

        Args:
            checkpoint_path: Path to .ckpt file

        Returns:
            DispersionTransformer model with normalization parameters loaded
        """
        from .train import DispersionLightningModule

        lightning_model = DispersionLightningModule.load_from_checkpoint(checkpoint_path)
        return lightning_model.model


class DESeq2Transformer(PairSetTransformer):
    """
    DESeq2-specific transformer that predicts two core DESeq2 statistics:
    - log2FoldChange: Log2 fold change between conditions
    - lfcSE: Log2 fold change standard error (log-transformed during training)

    The standard error target is log-transformed during training for better
    optimization of right-skewed, multi-order-of-magnitude data.

    The test statistic (stat = log2FoldChange / lfcSE) can be computed
    post-prediction using the compute_stat() helper method.
    """

    TARGET_COLUMNS = [
        'log2FoldChange',
        'lfcSE'
    ]

    # Standard error target that is log-transformed during training
    SE_TARGETS = ['lfcSE']
    SE_EPSILON = 1e-8  # Small epsilon for numerical stability in log transformation

    @classmethod
    def _inverse_transform_targets(cls, predictions):
        """
        Apply inverse transformation to targets: SE inverse log transformation.
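
        Only the columns listed in SE_TARGETS (here lfcSE) are transformed,
        via exp(prediction) - SE_EPSILON; log2FoldChange is passed through
        unchanged. For example, a predicted log-scale SE of 0.0 maps back to
        exp(0.0) - 1e-8, i.e. approximately 1.0.
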
        Args:
            predictions: torch.Tensor with shape (batch_size, 2) containing model predictions

        Returns:
            torch.Tensor with targets in original scale
        """
        # Convert to numpy for transformation, then back to tensor
        if isinstance(predictions, torch.Tensor):
            pred_numpy = predictions.detach().cpu().numpy()
            device = predictions.device
            dtype = predictions.dtype
        else:
            pred_numpy = predictions
            device = None
            dtype = None

        # Apply SE inverse log transformation
        for i, col in enumerate(cls.TARGET_COLUMNS):
            if col in cls.SE_TARGETS:
                # Apply inverse transformation: exp(log_SE) - epsilon
                pred_numpy[:, i] = np.exp(pred_numpy[:, i]) - cls.SE_EPSILON

        # Convert back to tensor if input was tensor
        if device is not None:
            return torch.tensor(pred_numpy, dtype=dtype, device=device)
        else:
            return pred_numpy

    @staticmethod
    def compute_stat(log2fc, lfcse):
        """
        Compute the test statistic from log2 fold change and standard error.

        Args:
            log2fc: Log2 fold change value(s)
            lfcse: Standard error value(s) (in original scale, not log-transformed)

        Returns:
            Test statistic (log2fc / lfcse)
        """
        # Avoid division by zero
        lfcse_safe = np.maximum(lfcse, 1e-10)
        return log2fc / lfcse_safe

    def __init__(self, dim_input=1, d_model=128, n_heads=8, num_self_layers=3,
                 num_cross_layers=3, dropout=0.1):
        """
        Initialize DESeq2 transformer with 2 outputs.

        Args:
            dim_input: Input dimension (default: 1 for scalar values)
            d_model: Model dimension
            n_heads: Number of attention heads
            num_self_layers: Number of self-attention layers
            num_cross_layers: Number of cross-attention layers
            dropout: Dropout rate
        """
        super().__init__(
            dim_input=dim_input,
            d_model=d_model,
            n_heads=n_heads,
            num_self_layers=num_self_layers,
            num_cross_layers=num_cross_layers,
            dropout=dropout,
            num_outputs=2  # Two targets: log2FoldChange and lfcSE
        )

    def _create_output_head(self, input_dim, dropout):
        """
        Create DESeq2-specific output head with minimal split architecture.

        Uses shared layers for most computation with separate final projections
        for log2 fold change and standard error to allow slight specialization.
        """
        # Shared feature processing (bulk of the computation)
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, 2 * self.d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(2 * self.d_model, self.d_model),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Minimal separate heads (just final projection)
        self.log2fc_head = nn.Linear(self.d_model, 1)  # log2FoldChange
        self.lfcse_head = nn.Linear(self.d_model, 1)   # lfcSE

        # Return a module that combines all components
        return nn.ModuleDict({
            'shared': self.shared_layers,
            'log2fc': self.log2fc_head,
            'lfcse': self.lfcse_head
        })

    def forward(self, x, y, x_mask=None, y_mask=None):
        """
        Forward pass through DESeq2 transformer.
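
        The lfcSE output is on the log-transformed training scale; apply
        _inverse_transform_targets() (as predict_deseq2() does) to recover the
        standard error on its original scale.
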
        Args:
            x: First set tensor (B, n1, dim_input)
            y: Second set tensor (B, n2, dim_input)
            x_mask: Mask for first set (B, n1)
            y_mask: Mask for second set (B, n2)

        Returns:
            Tensor of shape (B, 2) with DESeq2 statistics in order: [log2FoldChange, lfcSE]
        """
        # x: (B, n1, dim_input)
        # y: (B, n2, dim_input)
        # x_mask: (B, n1) boolean mask for x (True = real data, False = padding)
        # y_mask: (B, n2) boolean mask for y (True = real data, False = padding)

        # Embedding
        x_emb = self.dropout(self.embed_x(x))  # (B, n1, d_model)
        y_emb = self.dropout(self.embed_y(y))  # (B, n2, d_model)

        # Attention masks (already in attention convention: True = attend, False = ignore)
        x_attn_mask = x_mask if x_mask is not None else None
        y_attn_mask = y_mask if y_mask is not None else None

        # Intra-set self-attention
        for layer in self.self_layers_x:
            x_emb = layer(x_emb, x_attn_mask)
        for layer in self.self_layers_y:
            y_emb = layer(y_emb, y_attn_mask)

        # Cross-set attention
        for cross_x, cross_y in zip(self.cross_layers_x, self.cross_layers_y):
            x_cross = cross_x(x_emb, y_emb, y_attn_mask)  # X attending to Y
            y_cross = cross_y(y_emb, x_emb, x_attn_mask)  # Y attending to X
            x_emb = x_cross
            y_emb = y_cross

        # Masked mean pooling over sets
        if x_mask is not None:
            phi_x = masked_mean_pooling(x_emb, x_mask, dim=1)  # (B, d_model)
        else:
            phi_x = x_emb.mean(dim=1)  # (B, d_model)

        if y_mask is not None:
            phi_y = masked_mean_pooling(y_emb, y_mask, dim=1)  # (B, d_model)
        else:
            phi_y = y_emb.mean(dim=1)  # (B, d_model)

        # Combine features: [φ(X), φ(Y), φ(X)−φ(Y), φ(X)⊙φ(Y)]
        diff = phi_x - phi_y
        prod = phi_x * phi_y
        combined = torch.cat([phi_x, phi_y, diff, prod], dim=1)  # (B, 4*d_model)

        # Process through shared layers
        shared_features = self.head['shared'](combined)  # (B, d_model)

        # Generate outputs from minimal separate heads
        log2fc_output = self.head['log2fc'](shared_features)  # (B, 1)
        lfcse_output = self.head['lfcse'](shared_features)    # (B, 1)

        # Combine outputs in the expected order
        outputs = torch.cat([log2fc_output, lfcse_output], dim=1)  # (B, 2)

        return outputs

    def predict_deseq2(self, set_A, set_B, padding_value=-1e9):
        """
        Predict DESeq2 statistics for two sets.

        Args:
            set_A: First set (condition A samples)
            set_B: Second set (condition B samples)
            padding_value: Padding value for variable length sequences

        Returns:
            Dictionary with DESeq2 statistics and computed test statistic
        """
        predictions = self.predict(set_A, set_B, padding_value)

        if predictions.dim() == 1:
            predictions = predictions.unsqueeze(0)  # Add batch dimension if needed

        # Apply inverse transformation to standard error targets
        predictions = self._inverse_transform_targets(predictions)

        result = {}
        for i, col in enumerate(self.TARGET_COLUMNS):
            result[col] = predictions[0, i].item()

        # Compute test statistic from predictions
        result['stat'] = self.compute_stat(result['log2FoldChange'], result['lfcSE'])

        return result
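

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): exercises the three model classes
# on two small sets with untrained weights, so the printed values are
# arbitrary. Assumes this module lives inside its package (relative imports),
# e.g. run with `python -m <package>.<this_module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # Two variable-length sets of scalar observations.
    set_a = [1.2, 0.8, 1.5, 0.9]
    set_b = [2.3, 2.1, 2.8]

    # Generic pair-set regression with a single scalar output.
    base_model = PairSetTransformer(dim_input=1, d_model=32, n_heads=4,
                                    num_self_layers=1, num_cross_layers=1)
    print("PairSetTransformer:", base_model.predict(set_a, set_b))

    # NB GLM parameter estimates (mu, beta, alpha), denormalized.
    disp_model = DispersionTransformer(d_model=32, n_heads=4,
                                       num_self_layers=1, num_cross_layers=1)
    print("DispersionTransformer:", disp_model.predict_parameters(set_a, set_b))

    # DESeq2-style statistics (log2FoldChange, lfcSE, stat).
    deseq_model = DESeq2Transformer(d_model=32, n_heads=4,
                                    num_self_layers=1, num_cross_layers=1)
    print("DESeq2Transformer:", deseq_model.predict_deseq2(set_a, set_b))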