# atalasdev's picture
# Initial release: ResNet50 model weights, training scripts, and evaluation metrics
# 9c40bf5 verified
# -*- coding: utf-8 -*-
"""
ResNet50 Evaluation & Confusion Matrix Script
Dataset: Animals-10
"""
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import os
import kagglehub
# --- CONFIGURATION ---
# Path of the checkpoint file whose state dict is loaded into the ResNet50 model.
MODEL_PATH = "best_resnet50_animals.pt"
# Number of images per batch used by the test DataLoader.
BATCH_SIZE = 64
# Number of worker processes the test DataLoader uses for loading data.
NUM_WORKERS = 2
# --- UTILITY FUNCTIONS ---
def get_data_path():
    """Return the path of the Animals-10 ``raw-img`` image directory.

    A copy at ``<current working directory>/animals10/raw-img`` is used when
    that path exists. Otherwise the dataset is downloaded with KaggleHub and
    the ``raw-img`` folder inside the download location is returned.
    """
    image_dir_name = "raw-img"
    candidate = os.path.join(os.getcwd(), "animals10", image_dir_name)

    # Guard clause: nothing on disk, so fall back to the KaggleHub download.
    if not os.path.exists(candidate):
        print("Dataset not found locally. Downloading via KaggleHub...")
        download_root = kagglehub.dataset_download("alessiocorrado99/animals10")
        return os.path.join(download_root, image_dir_name)

    print(f"Dataset found locally at: {candidate}")
    return candidate
def _build_test_loader(data_path, device):
    """Build the DataLoader for the held-out test split.

    Args:
        data_path: Root directory passed to ``datasets.ImageFolder``.
        device: The ``torch.device`` used for inference; memory pinning is
            enabled only when it is a CUDA device.

    Returns:
        A ``(test_loader, test_set, classes)`` tuple, where ``classes`` is the
        class-name list reported by ``ImageFolder``.
    """
    # Deterministic preprocessing: resize, center crop, then per-channel
    # mean/std normalization (the values are the usual ImageNet statistics).
    test_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    print("Loading dataset...")
    dataset = datasets.ImageFolder(data_path, transform=test_transform)
    classes = dataset.classes
    print(f"Total samples: {len(dataset)} | Classes: {len(classes)}")

    # Replicate the 80/10/10 split to isolate the test set.
    # NOTE: the seed (42) and the split sizes must match the training script,
    # otherwise the "test" samples may overlap with the training data.
    total_len = len(dataset)
    train_len = int(0.8 * total_len)
    val_len = int(0.1 * total_len)
    test_len = total_len - train_len - val_len
    generator = torch.Generator().manual_seed(42)
    _, _, test_set = random_split(
        dataset, [train_len, val_len, test_len], generator=generator
    )

    test_loader = DataLoader(
        test_set,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=device.type == "cuda",
    )
    return test_loader, test_set, classes


def _load_model(num_classes, device):
    """Create a ResNet50 with a ``num_classes``-way head and load ``MODEL_PATH``.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint does not
            match the architecture (for example a different number of classes).
    """
    model = models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source (consider weights_only=True where the installed PyTorch
    # version supports it).
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model = model.to(device)
    model.eval()
    return model


def _run_inference(model, loader, device):
    """Run ``model`` over ``loader`` and return ``(true_labels, predictions)`` as int lists."""
    all_preds = []
    all_labels = []
    num_batches = len(loader)
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(loader, start=1):
            outputs = model(inputs.to(device))
            # The predicted class is the index of the largest logit.
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.tolist())
            # Progress log every 10 batches.
            if batch_idx % 10 == 0:
                print(f"Processed batch: {batch_idx}/{num_batches}")
    return all_labels, all_preds


def _save_confusion_matrix(cm, classes, accuracy):
    """Render ``cm`` as an annotated heatmap, save it as a PNG and return the file name."""
    fig = plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        xticklabels=classes,
        yticklabels=classes,
        cmap='Blues',
        cbar_kws={'label': 'Count'},
    )
    plt.xlabel('Predicted Class', fontsize=12, fontweight='bold')
    plt.ylabel('True Class', fontsize=12, fontweight='bold')
    plt.title(f'Confusion Matrix - Accuracy: {accuracy:.2f}%', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    output_file = 'model_performance_matrix.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    # plt.show()  # Uncomment if running in a notebook environment
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return output_file


def evaluate_model():
    """Evaluate the saved ResNet50 checkpoint on the held-out test split.

    Steps: verify the checkpoint exists, select the device, build the test
    loader, load the weights, run inference, print the accuracy and the
    classification report, and save the confusion matrix as
    ``model_performance_matrix.png``.

    Returns:
        None. When ``MODEL_PATH`` does not exist an error message is printed
        and the function returns without doing any further work.
    """
    # Fail fast: check the checkpoint before locating (and possibly
    # downloading) the dataset, so a missing file costs nothing.
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file '{MODEL_PATH}' not found in the directory.")
        return

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        print(f"Device: CUDA ({torch.cuda.get_device_name(0)})")
    else:
        print("Device: CPU")

    test_loader, test_set, classes = _build_test_loader(get_data_path(), device)

    print(f"Loading model weights from: {MODEL_PATH}")
    # The head is sized from the dataset, so a checkpoint/dataset class-count
    # mismatch fails at load time instead of producing mislabelled metrics.
    model = _load_model(len(classes), device)
    print("Model loaded successfully.")

    print(f"Starting inference on {len(test_set)} test samples...")
    all_labels, all_preds = _run_inference(model, test_loader, device)

    # Metrics
    correct_preds = sum(pred == label for pred, label in zip(all_preds, all_labels))
    accuracy = 100 * correct_preds / len(all_preds) if all_preds else 0.0
    print(f"\nTest Accuracy: {accuracy:.2f}%")

    # Pass the full label set explicitly: without it scikit-learn infers the
    # labels from the data, so a class missing from the test split would break
    # the target_names pairing and misalign the heatmap tick labels.
    label_ids = list(range(len(classes)))

    print("\nClassification Report:")
    print(classification_report(
        all_labels, all_preds, labels=label_ids, target_names=classes, digits=3
    ))

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    output_file = _save_confusion_matrix(cm, classes, accuracy)
    print(f"\nConfusion Matrix saved as: {output_file}")

    if device.type == "cuda":
        torch.cuda.empty_cache()
# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    evaluate_model()