File size: 5,068 Bytes
9c40bf5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
# -*- coding: utf-8 -*-
"""
ResNet50 Evaluation & Confusion Matrix Script
Dataset: Animals-10
"""
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import os
import kagglehub
# --- CONFIGURATION ---
# Checkpoint file whose state dict evaluate_model() loads; a relative path,
# so it is resolved against the current working directory.
MODEL_PATH = "best_resnet50_animals.pt"
# Passed as batch_size to the test DataLoader.
BATCH_SIZE = 64
# Passed as num_workers to the test DataLoader.
NUM_WORKERS = 2
# --- UTILITY FUNCTIONS ---
def get_data_path():
    """Return the path of the Animals-10 ``raw-img`` directory.

    Looks for ``animals10/raw-img`` under the current working directory
    first. If that path does not exist, the dataset is fetched through
    KaggleHub and the ``raw-img`` folder inside the download location is
    returned instead.
    """
    candidate = os.path.join(os.getcwd(), "animals10", "raw-img")

    # Guard clause: no local copy, so fall back to the KaggleHub download.
    if not os.path.exists(candidate):
        print("Dataset not found locally. Downloading via KaggleHub...")
        download_root = kagglehub.dataset_download("alessiocorrado99/animals10")
        return os.path.join(download_root, "raw-img")

    print(f"Dataset found locally at: {candidate}")
    return candidate
def _build_test_loader(data_path, device):
    """Create the DataLoader for the held-out test split.

    Args:
        data_path: Root directory laid out for ``datasets.ImageFolder``.
        device: Target ``torch.device``; pinned memory is enabled only for CUDA.

    Returns:
        A ``(test_loader, classes)`` tuple, where ``classes`` is the list of
        class names reported by ``ImageFolder``.
    """
    # Test Transformations (Standardization)
    # NOTE(review): resize/crop sizes and normalization constants presumably
    # mirror the training script -- confirm they match.
    test_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    print("Loading dataset...")
    dataset = datasets.ImageFolder(data_path, transform=test_transform)
    classes = dataset.classes
    print(f"Total samples: {len(dataset)} | Classes: {len(classes)}")

    # Replicate the 80/10/10 split logic to isolate the test set.
    # Note: the seed must match the training script, otherwise the "test"
    # samples here may overlap with what the model was trained on.
    total_len = len(dataset)
    train_len = int(0.8 * total_len)
    val_len = int(0.1 * total_len)
    test_len = total_len - train_len - val_len  # remainder absorbs rounding
    generator = torch.Generator().manual_seed(42)
    _, _, test_set = random_split(
        dataset, [train_len, val_len, test_len], generator=generator
    )

    test_loader = DataLoader(
        test_set,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=device.type == "cuda",
    )
    return test_loader, classes


def _load_model(num_classes, device):
    """Build a ResNet50 with a ``num_classes``-way head and load MODEL_PATH into it."""
    model = models.resnet50(weights=None)
    # Size the head from the dataset instead of a hard-coded 10, so a
    # checkpoint/dataset mismatch fails loudly in load_state_dict.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (consider weights_only=True where supported).
    # map_location handles CPU/GPU mapping of the stored tensors.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


def _run_inference(model, test_loader, device):
    """Run the model over ``test_loader`` and return ``(all_labels, all_preds)`` lists."""
    all_preds = []
    all_labels = []
    num_batches = len(test_loader)
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader, start=1):
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.tolist())
            # Progress log
            if batch_idx % 10 == 0:
                print(f"Processed batch: {batch_idx}/{num_batches}")
    return all_labels, all_preds


def _save_confusion_matrix(cm, classes, accuracy, output_file):
    """Render ``cm`` as an annotated heatmap and write it to ``output_file``."""
    fig = plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        xticklabels=classes,
        yticklabels=classes,
        cmap='Blues',
        cbar_kws={'label': 'Count'},
    )
    plt.xlabel('Predicted Class', fontsize=12, fontweight='bold')
    plt.ylabel('True Class', fontsize=12, fontweight='bold')
    plt.title(f'Confusion Matrix - Accuracy: {accuracy:.2f}%', fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    # To display interactively (e.g. in a notebook), call plt.show() here,
    # before the figure is closed.
    plt.close(fig)  # release the figure so repeated calls do not accumulate


def evaluate_model():
    """Evaluate the saved ResNet50 checkpoint on the Animals-10 test split.

    Prints the test accuracy and a per-class classification report, and
    saves a confusion-matrix heatmap to ``model_performance_matrix.png``.

    Returns:
        None. If ``MODEL_PATH`` does not point to a file, an error message is
        printed and the function returns before touching the dataset.
    """
    # Device Configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        print(f"Device: CUDA ({torch.cuda.get_device_name(0)})")
    else:
        print("Device: CPU")

    # Fail fast: without weights there is nothing to evaluate, so check
    # before locating (and possibly downloading) the dataset.
    if not os.path.isfile(MODEL_PATH):
        print(f"Error: Model file '{MODEL_PATH}' not found in the directory.")
        return

    # Data
    data_path = get_data_path()
    test_loader, classes = _build_test_loader(data_path, device)

    # Load Model
    print(f"Loading model weights from: {MODEL_PATH}")
    model = _load_model(len(classes), device)
    print("Model loaded successfully.")

    # Inference Loop
    print(f"Starting inference on {len(test_loader.dataset)} test samples...")
    all_labels, all_preds = _run_inference(model, test_loader, device)

    # Metrics
    correct_preds = sum(pred == label for pred, label in zip(all_preds, all_labels))
    accuracy = 100 * correct_preds / len(all_preds)
    print(f"\nTest Accuracy: {accuracy:.2f}%")

    # Pin the label set explicitly: if a class is absent from the test split,
    # inferred labels would no longer line up with the class names.
    label_ids = list(range(len(classes)))

    # Classification Report
    print("\nClassification Report:")
    print(classification_report(
        all_labels, all_preds, labels=label_ids, target_names=classes, digits=3
    ))

    # Confusion Matrix Plotting
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    output_file = 'model_performance_matrix.png'
    _save_confusion_matrix(cm, classes, accuracy, output_file)
    print(f"\nConfusion Matrix saved as: {output_file}")

    if device.type == "cuda":
        torch.cuda.empty_cache()
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    evaluate_model()