|
|
|
|
|
"""
|
|
|
ResNet50 Evaluation & Confusion Matrix Script
|
|
|
Dataset: Animals-10
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torchvision import datasets, transforms, models
|
|
|
from torch.utils.data import DataLoader, random_split
|
|
|
from sklearn.metrics import confusion_matrix, classification_report
|
|
|
import seaborn as sns
|
|
|
import matplotlib.pyplot as plt
|
|
|
import os
|
|
|
import kagglehub
|
|
|
|
|
|
|
|
|
# Path of the fine-tuned checkpoint; evaluate_model() reads it with
# torch.load and feeds the result to load_state_dict.
MODEL_PATH: str = "best_resnet50_animals.pt"

# DataLoader settings used for the test split.
BATCH_SIZE: int = 64  # images per inference batch
NUM_WORKERS: int = 2  # background worker processes for image loading
|
|
|
|
|
|
|
|
|
|
|
|
def get_data_path():
    """Return the directory holding the Animals-10 ``raw-img`` images.

    The local copy ``<cwd>/animals10/raw-img`` is preferred. When it does not
    exist, the ``alessiocorrado99/animals10`` dataset is fetched through
    KaggleHub and the ``raw-img`` subdirectory of the download is returned.
    """
    candidate = os.path.join(os.getcwd(), "animals10", "raw-img")

    if not os.path.exists(candidate):
        # No local copy: fall back to a KaggleHub download.
        print("Dataset not found locally. Downloading via KaggleHub...")
        download_root = kagglehub.dataset_download("alessiocorrado99/animals10")
        return os.path.join(download_root, "raw-img")

    print(f"Dataset found locally at: {candidate}")
    return candidate
|
|
|
|
|
|
def _build_test_loader(data_path, device):
    """Load the image folder and return ``(test_loader, test_set, classes)``.

    The test subset is the third partition of an 80/10/10 ``random_split``
    seeded with 42. NOTE(review): this is only the true held-out split if the
    training script used the same lengths, seed and dataset ordering — confirm.
    """
    # Deterministic evaluation transform: fixed resize + center crop, then
    # normalization with the ImageNet channel means/stds.
    test_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    print("Loading dataset...")
    dataset = datasets.ImageFolder(data_path, transform=test_transform)
    classes = dataset.classes
    print(f"Total samples: {len(dataset)} | Classes: {len(classes)}")

    total_len = len(dataset)
    train_len = int(0.8 * total_len)
    val_len = int(0.1 * total_len)
    test_len = total_len - train_len - val_len  # remainder absorbs rounding

    generator = torch.Generator().manual_seed(42)
    _, _, test_set = random_split(
        dataset, [train_len, val_len, test_len], generator=generator
    )

    test_loader = DataLoader(
        test_set,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=device.type == "cuda",
    )
    return test_loader, test_set, classes


def _load_model(num_classes, device):
    """Rebuild ResNet50 with a ``num_classes``-way head and load MODEL_PATH."""
    print(f"Loading model weights from: {MODEL_PATH}")
    model = models.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # The checkpoint is a plain state_dict, so weights_only=True is enough and
    # avoids unpickling arbitrary objects (and torch's FutureWarning).
    state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    print("Model loaded successfully.")
    return model


def _run_inference(model, test_loader, device):
    """Run ``model`` over ``test_loader``; return (predictions, labels) lists."""
    all_preds = []
    all_labels = []
    num_batches = len(test_loader)

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader, start=1):
            outputs = model(inputs.to(device))
            # Index of the highest logit per sample is the predicted class.
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.tolist())

            if batch_idx % 10 == 0:
                print(f"Processed batch: {batch_idx}/{num_batches}")

    return all_preds, all_labels


def _save_confusion_matrix(cm, classes, accuracy, output_file):
    """Render ``cm`` as an annotated heatmap and write it to ``output_file``."""
    fig = plt.figure(figsize=(12, 10))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            xticklabels=classes,
            yticklabels=classes,
            cmap='Blues',
            cbar_kws={'label': 'Count'},
        )
        plt.xlabel('Predicted Class', fontsize=12, fontweight='bold')
        plt.ylabel('True Class', fontsize=12, fontweight='bold')
        plt.title(f'Confusion Matrix - Accuracy: {accuracy:.2f}%', fontsize=14, fontweight='bold')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def evaluate_model():
    """Evaluate the saved ResNet50 checkpoint on the Animals-10 test split.

    Prints the device, overall test accuracy and a per-class classification
    report, and saves a confusion-matrix heatmap to
    ``model_performance_matrix.png``.

    Returns:
        None. If ``MODEL_PATH`` does not exist or the test split is empty, an
        error message is printed and the function returns early.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        print(f"Device: CUDA ({torch.cuda.get_device_name(0)})")
    else:
        print("Device: CPU")

    # Fail fast: a missing checkpoint should not first cost a dataset download.
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model file '{MODEL_PATH}' not found in the directory.")
        return

    data_path = get_data_path()
    test_loader, test_set, classes = _build_test_loader(data_path, device)
    if len(test_set) == 0:
        print("Error: Test split is empty; nothing to evaluate.")
        return

    # Size the head from the dataset's class list rather than a hard-coded 10,
    # so the model output always lines up with the labels used for reporting.
    model = _load_model(len(classes), device)

    print(f"Starting inference on {len(test_set)} test samples...")
    all_preds, all_labels = _run_inference(model, test_loader, device)

    correct_preds = sum(p == t for p, t in zip(all_preds, all_labels))
    accuracy = 100 * correct_preds / len(all_preds)
    print(f"\nTest Accuracy: {accuracy:.2f}%")

    # Pass the full label set explicitly: without it, a class absent from this
    # split makes classification_report reject target_names and shrinks the
    # confusion matrix out of alignment with the tick labels.
    label_ids = list(range(len(classes)))

    print("\nClassification Report:")
    print(classification_report(
        all_labels, all_preds, labels=label_ids, target_names=classes, digits=3
    ))

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    output_file = 'model_performance_matrix.png'
    _save_confusion_matrix(cm, classes, accuracy, output_file)
    print(f"\nConfusion Matrix saved as: {output_file}")

    if device.type == "cuda":
        torch.cuda.empty_cache()
|
|
|
|
|
|
if __name__ == "__main__":
    # Evaluate only when run as a script; importing this module has no side effects.
    evaluate_model()