|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
from typing import Tuple, Callable, Optional |
|
|
|
|
|
|
|
|
def normalize_data(data: torch.Tensor, mean: Optional[torch.Tensor] = None,
                   std: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Normalize data to zero mean and unit variance.

    When not supplied, the statistics are global scalars computed over every
    element of ``data`` (no per-feature reduction is performed).

    Args:
        data: Input tensor to normalize
        mean: Optional precomputed mean (if None, computed from data)
        std: Optional precomputed std (if None, computed from data)

    Returns:
        Tuple of (normalized_data, mean, std). The returned std is the value
        actually used as the divisor, i.e. after being floored at 1e-8.
    """
    if mean is None:
        mean = data.mean()

    if std is None:
        if data.numel() > 1:
            std = data.std()
        else:
            # The sample standard deviation of fewer than two elements is NaN,
            # and NaN survives the clamp below. Treat the spread as zero so the
            # epsilon floor applies instead of poisoning the output.
            std = torch.zeros((), dtype=data.dtype, device=data.device)

    # Floor the divisor so constant (zero-variance) data does not divide by zero.
    std = torch.clamp(std, min=1e-8)

    normalized = (data - mean) / std
    return normalized, mean, std
|
|
|
|
|
|
|
|
def denormalize_data(normalized_data: torch.Tensor, mean: torch.Tensor,
                     std: torch.Tensor) -> torch.Tensor:
    """
    Undo a zero-mean / unit-variance normalization.

    Args:
        normalized_data: Tensor that was previously normalized
        mean: Mean that was subtracted during normalization
        std: Standard deviation that was divided by during normalization

    Returns:
        Tensor rescaled by ``std`` and shifted by ``mean``
    """
    rescaled = normalized_data * std
    return rescaled + mean
|
|
|
|
|
|
|
|
def mean_pooling(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Average a tensor over one dimension.

    Args:
        x: Input tensor
        dim: Dimension that is reduced (removed) by the average

    Returns:
        Tensor of means with ``dim`` reduced away
    """
    pooled = torch.mean(x, dim=dim)
    return pooled
|
|
|
|
|
|
|
|
def masked_mean_pooling(x: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Average a tensor over one dimension while ignoring masked-out (padded) positions.

    Args:
        x: Input tensor (B, seq_len, dim)
        mask: Boolean mask tensor (B, seq_len) where True indicates real data
        dim: Dimension to pool over (default: 1, sequence dimension)

    Returns:
        Mean over the unmasked positions; rows with no unmasked position
        evaluate to sum / 1e-8 rather than dividing by zero
    """
    # A (B, seq_len) mask gets a trailing axis so it broadcasts over the features.
    needs_feature_axis = mask.dim() == 2 and x.dim() == 3
    if needs_feature_axis:
        mask = mask.unsqueeze(-1)

    weights = mask.float()

    # Padded positions are multiplied by zero so they drop out of the sum.
    total = (x * weights).sum(dim=dim)

    # Number of real positions, floored to avoid division by zero.
    denom = weights.sum(dim=dim).clamp(min=1e-8)

    return total / denom
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad_sequences(sequences: list, max_length: Optional[int] = None,
                  padding_value: float = -1e9) -> torch.Tensor:
    """
    Pad sequences to the same length with a configurable padding value.

    Sequences longer than ``max_length`` are truncated. Any dimensions after
    the first (time) dimension are preserved, so 1-D sequences yield a
    (batch_size, max_length) tensor and 2-D sequences yield
    (batch_size, max_length, dim). The output dtype, device, and trailing
    shape are taken from the first sequence.

    Args:
        sequences: List of tensors with different lengths along dimension 0
        max_length: Maximum length to pad to (if None, use longest sequence)
        padding_value: Value to use for padding (default: -1e9, avoids conflict with meaningful zeros)

    Returns:
        Padded tensor of shape (batch_size, max_length, *trailing_dims)

    Raises:
        ValueError: If ``sequences`` is empty
    """
    if len(sequences) == 0:
        # Without a reference tensor there is no dtype/device/shape to build from.
        raise ValueError("pad_sequences() requires at least one sequence")

    if max_length is None:
        max_length = max(seq.size(0) for seq in sequences)

    reference = sequences[0]
    # Everything after the time axis is carried through unchanged.
    trailing_shape = tuple(reference.shape[1:])

    padded = torch.full((len(sequences), max_length, *trailing_shape), padding_value,
                        dtype=reference.dtype, device=reference.device)

    for row, seq in enumerate(sequences):
        # Truncate sequences that exceed max_length.
        length = min(seq.size(0), max_length)
        padded[row, :length] = seq[:length]

    return padded
|
|
|
|
|
|
|
|
def create_padding_mask(sequences: list, max_length: Optional[int] = None) -> torch.Tensor:
    """
    Create padding mask for sequences.

    Row ``i`` of the mask is True for the first ``min(len(sequences[i]), max_length)``
    positions and False afterwards. The mask is created on the device of the
    first sequence.

    Args:
        sequences: List of tensors with different lengths along dimension 0
        max_length: Maximum length (if None, use longest sequence)

    Returns:
        Boolean mask tensor of shape (batch_size, max_length) where True
        indicates real data, False indicates padding

    Raises:
        ValueError: If ``sequences`` is empty
    """
    if len(sequences) == 0:
        # Without a reference tensor there is no device or batch to build a mask for.
        raise ValueError("create_padding_mask() requires at least one sequence")

    if max_length is None:
        max_length = max(seq.size(0) for seq in sequences)

    device = sequences[0].device
    lengths = torch.tensor([seq.size(0) for seq in sequences], device=device)
    positions = torch.arange(max_length, device=device)

    # Position j of row i is real data exactly when j < length_i; lengths above
    # max_length simply produce an all-True row (the sequence is truncated).
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_rmse(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Root Mean Square Error between predictions and targets.

    Args:
        predictions: Predicted values
        targets: True target values

    Returns:
        RMSE over all elements, as a Python float
    """
    residual = predictions - targets
    mean_squared = torch.mean(residual ** 2)
    rmse = torch.sqrt(mean_squared)
    return rmse.item()
|
|
|
|
|
|
|
|
def compute_mae(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Mean Absolute Error between predictions and targets.

    Args:
        predictions: Predicted values
        targets: True target values

    Returns:
        MAE over all elements, as a Python float
    """
    absolute_error = torch.abs(predictions - targets)
    return torch.mean(absolute_error).item()
|
|
|
|
|
|
|
|
class EarlyStopping:
    """
    Early stopping utility to stop training when validation loss stops improving.

    Call the instance once per epoch with the current validation loss and the
    model. It tracks the best loss seen so far and signals a stop after
    ``patience`` consecutive calls without sufficient improvement.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.0,
                 restore_best_weights: bool = True):
        """
        Args:
            patience: Number of epochs with no improvement after which training will be stopped
            min_delta: Minimum change in monitored quantity to qualify as improvement
            restore_best_weights: Whether to restore model weights from the best epoch
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights

        # Lowest validation loss observed so far.
        self.best_loss = float('inf')
        # Consecutive calls without an improvement of more than min_delta.
        self.counter = 0
        # Snapshot of the model's state dict at the best loss (None until first improvement).
        self.best_weights = None

    def __call__(self, val_loss: float, model: nn.Module) -> bool:
        """
        Check if training should be stopped.

        Args:
            val_loss: Current validation loss
            model: Model to potentially save weights for; its weights are
                overwritten with the best snapshot when stopping is triggered
                and ``restore_best_weights`` is enabled

        Returns:
            True if training should be stopped, False otherwise
        """
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                # state_dict() returns references to the live parameter storage,
                # so a shallow dict copy would keep changing as training goes on.
                # Clone every tensor to freeze a true snapshot of the best epoch.
                self.best_weights = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_weights is not None:
                model.load_state_dict(self.best_weights)
            return True

        return False