# -*- coding: utf-8 -*-
"""
ResNet50 Image Classification Training Script
Dataset: Animals-10
Model: ResNet50 (Pre-trained on ImageNet)
"""
import kagglehub
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import v2
from torchvision.models import resnet50, ResNet50_Weights
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
import time
import os
import copy

# --- HYPERPARAMETERS & SYSTEM CONFIGURATION ---
BATCH_SIZE = 32
ACCUMULATION_STEPS = 2  # Effective batch size = 64
EPOCHS = 15
LEARNING_RATE = 1e-4
NUM_WORKERS = 2


# --- CUSTOM DATASET CLASS ---
# Defined globally to ensure compatibility with multi-process data loading on Windows.
class TransformedDataset(torch.utils.data.Dataset):
    """Wraps a Subset and applies a transform lazily in __getitem__."""

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        return self.transform(x), y

    def __len__(self):
        return len(self.subset)


# --- UTILITY FUNCTIONS ---
def get_device():
    """Selects the compute device (CUDA or CPU)."""
    if torch.cuda.is_available():
        print(f"Device selected: CUDA ({torch.cuda.get_device_name(0)})")
        return torch.device("cuda")
    print("Device selected: CPU")
    return torch.device("cpu")


def get_data_path():
    """Locates the dataset locally or downloads it via KaggleHub."""
    current_dir = os.getcwd()
    local_path = os.path.join(current_dir, "animals10", "raw-img")
    if os.path.exists(local_path):
        print(f"Dataset found locally at: {local_path}")
        return local_path
    print("Dataset not found locally. Downloading via KaggleHub...")
    path = kagglehub.dataset_download("alessiocorrado99/animals10")
    return os.path.join(path, "raw-img")


# --- MAIN TRAINING LOOP ---
def main():
    device = get_device()
    image_path = get_data_path()

    # --- DATA AUGMENTATION & NORMALIZATION ---
    # Normalization statistics based on ImageNet
    stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    augmentations = {
        'train': v2.Compose([
            v2.Resize((256, 256)),
            v2.RandomResizedCrop(224, scale=(0.6, 1.0)),
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomRotation(15),
            v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            v2.PILToTensor(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(*stats),
            # RandomErasing only supports tensors, so it must run after
            # PILToTensor (it raises a TypeError on PIL images).
            v2.RandomErasing(p=0.1, scale=(0.02, 0.15)),
        ]),
        'val': v2.Compose([
            v2.Resize((256, 256)),
            v2.CenterCrop(224),
            v2.PILToTensor(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(*stats),
        ]),
    }

    # Data Preparation
    print("Initializing dataset and splits...")
    full_dataset = datasets.ImageFolder(image_path)
    total_len = len(full_dataset)
    train_len = int(0.8 * total_len)
    val_len = int(0.1 * total_len)
    test_len = total_len - train_len - val_len

    # Deterministic 80/10/10 split for reproducibility
    train_subset, val_subset, test_subset = random_split(
        full_dataset, [train_len, val_len, test_len],
        generator=torch.Generator().manual_seed(42)
    )

    # Data Loaders (validation transforms are reused for the test set)
    train_loader = DataLoader(TransformedDataset(train_subset, augmentations['train']),
                              batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(TransformedDataset(val_subset, augmentations['val']),
                            batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=True)
    test_loader = DataLoader(TransformedDataset(test_subset, augmentations['val']),
                             batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=NUM_WORKERS, pin_memory=True)

    # --- MODEL INITIALIZATION ---
    print("Loading ResNet50 model with ImageNet weights...")
    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
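    # Note: IMAGENET1K_V2 selects torchvision's updated ImageNet training
    # recipe for ResNet50 (higher top-1 accuracy than the original V1
    # weights). The manual pipeline above is assumed to replicate the
    # expected 224x224 ImageNet-normalized input; alternatively,
    # ResNet50_Weights.IMAGENET1K_V2.transforms() could supply the
    # validation/test preprocessing directly.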

    # Transfer Learning: freeze early layers, unfreeze layer4 and fc
    for name, param in model.named_parameters():
        if "layer4" in name or "fc" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # Replace the final fully connected layer for 10 classes
    # (the new nn.Linear is trainable by default).
    model.fc = nn.Linear(model.fc.in_features, 10)
    model = model.to(device)

    # Optimization Setup
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                      lr=LEARNING_RATE, weight_decay=1e-2)
    # One scheduler step per optimizer step, hence the division by
    # ACCUMULATION_STEPS.
    scheduler = OneCycleLR(optimizer, max_lr=LEARNING_RATE * 10,
                           steps_per_epoch=len(train_loader) // ACCUMULATION_STEPS,
                           epochs=EPOCHS)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # Mixed precision is only meaningful on CUDA; with enabled=False the
    # scaler and autocast contexts below become no-ops on CPU.
    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    # --- TRAINING PROCESS ---
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    print("-" * 60)
    print("Starting Training Loop")
    print(f"Epochs: {EPOCHS} | Batch Size: {BATCH_SIZE} | Accumulation Steps: {ACCUMULATION_STEPS}")
    print("-" * 60)

    for epoch in range(EPOCHS):
        start_time = time.time()

        # -- Training Phase --
        model.train()
        train_loss = 0
        correct = 0
        total = 0
        # Also clears any gradients left over from an incomplete
        # accumulation window at the end of the previous epoch.
        optimizer.zero_grad()

        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            # Mixed Precision Context
            with torch.amp.autocast('cuda', enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss = loss / ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            # Step the optimizer once every ACCUMULATION_STEPS mini-batches
            if (i + 1) % ACCUMULATION_STEPS == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            # Undo the accumulation scaling when logging the loss
            train_loss += loss.item() * ACCUMULATION_STEPS
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_acc = 100. * correct / total
        train_avg_loss = train_loss / len(train_loader)

        # -- Validation Phase --
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        val_acc = 100. * val_correct / val_total
        epoch_time = time.time() - start_time

        # Print Epoch Statistics
        print(f"Epoch [{epoch+1}/{EPOCHS}] | Time: {epoch_time:.0f}s | "
              f"Train Loss: {train_avg_loss:.4f} | Train Acc: {train_acc:.2f}% | "
              f"Val Acc: {val_acc:.2f}%")

        # Save Best Model
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), "best_resnet50_animals.pt")
            print("  -> Validation accuracy improved. Model saved.")

    print("-" * 60)
    print(f"Training Completed. Best Validation Accuracy: {best_acc:.2f}%")
    print("-" * 60)

    # --- FINAL TESTING ---
    print("Starting evaluation on Test Set...")
    model.load_state_dict(best_model_wts)
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()

    print(f"Final Test Set Accuracy: {100. * test_correct / test_total:.2f}%")


if __name__ == '__main__':
    main()
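
# --- ILLUSTRATIVE INFERENCE SKETCH ---
# A minimal, commented-out sketch of how the checkpoint saved above might be
# used for single-image prediction. "some_animal.jpg" is a placeholder; the
# predicted index maps to datasets.ImageFolder's alphabetical class ordering.
#
#   from PIL import Image
#
#   model = resnet50()
#   model.fc = nn.Linear(model.fc.in_features, 10)
#   model.load_state_dict(torch.load("best_resnet50_animals.pt", map_location="cpu"))
#   model.eval()
#
#   preprocess = v2.Compose([
#       v2.Resize((256, 256)),
#       v2.CenterCrop(224),
#       v2.PILToTensor(),
#       v2.ToDtype(torch.float32, scale=True),
#       v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
#   ])
#   img = preprocess(Image.open("some_animal.jpg").convert("RGB")).unsqueeze(0)
#   with torch.no_grad():
#       class_index = model(img).argmax(dim=1).item()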