from __future__ import annotations

from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
from PIL import Image
import torch
from torchvision import transforms
import timm
from timm.models.vision_transformer import resize_pos_embed
import joblib

# ----------------------- paths & device -----------------------
ROOT_DIR = Path(__file__).resolve().parent.parent  # AMLGroupSpaceFinal/
BASELINE_DIR = ROOT_DIR / "baseline"
LIST_DIR = ROOT_DIR / "list"

PLANT_CKPT_PATH = BASELINE_DIR / "plant_dinov2_patch14.pth"
LOGREG_PATH = BASELINE_DIR / "logreg_baseline.joblib"
SCALER_PATH = BASELINE_DIR / "scaler_baseline.joblib"
SPECIES_LIST_PATH = LIST_DIR / "species_list.txt"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ----------------------- helpers (trimmed from evaluate.py) -----------------------
def read_species(p: Path):
    """Read species_list.txt and return the list of species names in index order."""
    rows = []
    with open(p, "r", encoding="utf-8") as f:
        for ln in f:
            ln = ln.strip()
            if not ln or ln.startswith("#"):
                continue
            if ";" in ln:
                cid, name = ln.split(";", 1)
            else:
                parts = ln.split()
                cid, name = parts[0], " ".join(parts[1:]) if len(parts) > 1 else ""
            try:
                cid = int(cid)
            except ValueError:
                continue
            rows.append((cid, name))
    df = pd.DataFrame(rows, columns=["class_id", "species_name"])
    # same order as in training: iterrows order
    names = list(df["species_name"])
    return names


def pool_feats(out):
    """Reduce whatever forward_features() returns to a [B, D] feature matrix."""
    feats = out
    if isinstance(out, dict):
        # newer timm / DINOv2 variants return dicts; prefer a pooled / CLS entry
        for key in ("pooled", "x_norm_clstoken", "cls_token", "x"):
            if key in out:
                feats = out[key]
                break
    if isinstance(feats, (list, tuple)):
        feats = feats[0]
    if isinstance(feats, torch.Tensor) and feats.dim() == 3:
        # token sequence [B, N, D]: take the CLS token, or mean-pool if there is only one token
        feats = feats[:, 0] if feats.size(1) > 1 else feats.mean(dim=1)
    if isinstance(feats, torch.Tensor) and feats.dim() > 2:
        feats = feats.flatten(1)
    return feats


def _unwrap_state_dict(obj):
    """Dig the actual state dict out of common checkpoint wrappers."""
    if isinstance(obj, dict):
        for k in ("state_dict", "model", "module", "ema", "shadow", "backbone", "net", "student", "teacher"):
            if k in obj and isinstance(obj[k], dict):
                return obj[k]
    return obj


def _strip_prefixes(sd, prefixes=("module.", "backbone.", "model.", "student.")):
    """Drop wrapper prefixes (DataParallel, Lightning, etc.) from checkpoint keys."""
    out = {}
    for k, v in sd.items():
        for p in prefixes:
            if k.startswith(p):
                k = k[len(p):]
        out[k] = v
    return out


def maybe_load_plant_ckpt(model, ckpt_path: Path):
    """Load the plant-finetuned DINOv2 checkpoint if present; otherwise keep generic weights."""
    if not ckpt_path.is_file():
        print(f"[baseline] plant ckpt not found at {ckpt_path}, using generic DINOv2 weights.")
        return
    try:
        sd = torch.load(ckpt_path, map_location="cpu")
        sd = _unwrap_state_dict(sd)
        sd = _strip_prefixes(sd)
        msd = model.state_dict()
        if "pos_embed" in sd and "pos_embed" in msd and sd["pos_embed"].shape != msd["pos_embed"].shape:
            sd["pos_embed"] = resize_pos_embed(sd["pos_embed"], msd["pos_embed"])
            print(f"[baseline] interpolated pos_embed to {tuple(msd['pos_embed'].shape)}")
        missing, unexpected = model.load_state_dict(sd, strict=False)
        print(f"[baseline] loaded plant ckpt; missing={len(missing)} unexpected={len(unexpected)}")
    except Exception as e:
        print(f"[baseline] failed to load '{ckpt_path}': {e}")


def build_backbone(size: int = 224):
    """Create the frozen DINOv2 ViT-B/14 feature extractor at the given input size."""
    model = timm.create_model(
        "vit_base_patch14_dinov2",
        pretrained=True,   # generic DINOv2 as fallback
        num_classes=0,     # features only
        img_size=size,
        pretrained_cfg_overlay=dict(input_size=(3, size, size)),
    ).to(DEVICE)
    pe = getattr(model, "patch_embed", None)
    if pe is not None:
        if hasattr(pe, "img_size"):
            pe.img_size = (size, size)
        if hasattr(pe, "strict_img_size"):
            pe.strict_img_size = False
    maybe_load_plant_ckpt(model, PLANT_CKPT_PATH)
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    return model


# ----------------------- global objects (loaded once) -----------------------
IMAGE_SIZE = 224

species_names = read_species(SPECIES_LIST_PATH)
num_classes = len(species_names)

backbone = build_backbone(IMAGE_SIZE)

transform = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE * 1.12)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

scaler = joblib.load(SCALER_PATH)
logreg = joblib.load(LOGREG_PATH)


# ----------------------- public API for Gradio -----------------------
def predict_baseline(image: Image.Image, top_k: int = 5) -> Dict[str, float]:
    """
    Run the DINOv2 + Logistic Regression baseline on a single PIL image.

    Returns {class_name: probability} for the top_k classes.
    """
    if image is None:
        return {}
    if image.mode != "RGB":
        # Gradio may hand over RGBA / greyscale images; the transform expects 3 channels
        image = image.convert("RGB")
    x = transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = backbone.forward_features(x)
        feats = pool_feats(out).cpu().numpy()
    feats_scaled = scaler.transform(feats)
    probs = logreg.predict_proba(feats_scaled)[0]  # shape [num_classes]
    top_idx = np.argsort(-probs)[:top_k]
    result = {species_names[i]: float(probs[i]) for i in top_idx}
    return result
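

# ----------------------- local smoke test -----------------------
# Minimal usage sketch, not part of the Gradio app: shows how predict_baseline is
# called on a single image. The default path "sample_leaf.jpg" is a placeholder
# assumption; pass any local image path on the command line instead.
if __name__ == "__main__":
    import sys

    sample_path = sys.argv[1] if len(sys.argv) > 1 else "sample_leaf.jpg"  # hypothetical sample image
    img = Image.open(sample_path)
    for name, prob in predict_baseline(img, top_k=5).items():
        print(f"{prob:.3f}  {name}")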