from __future__ import annotations

from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
from PIL import Image
import torch
from torchvision import transforms
import timm
from timm.models.vision_transformer import resize_pos_embed
import joblib

# ----------------------- paths & device -----------------------
ROOT_DIR = Path(__file__).resolve().parent.parent  # AMLGroupSpaceFinal/
BASELINE_DIR = ROOT_DIR / "baseline"
LIST_DIR = ROOT_DIR / "list"

PLANT_CKPT_PATH = BASELINE_DIR / "plant_dinov2_patch14.pth"
LOGREG_PATH = BASELINE_DIR / "logreg_baseline.joblib"
SCALER_PATH = BASELINE_DIR / "scaler_baseline.joblib"
SPECIES_LIST_PATH = LIST_DIR / "species_list.txt"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------------- helpers (trimmed from evaluate.py) -----------------------
def read_species(p: Path):
    """Read species_list.txt and return the species names in index order.

    Each non-empty, non-comment line is either "class_id;species name" or a
    whitespace-separated "class_id species name"; lines whose id is not an
    integer are skipped.
    """
    rows = []
    with open(p, "r", encoding="utf-8") as f:
        for ln in f:
            ln = ln.strip()
            if not ln or ln.startswith("#"):
                continue
            if ";" in ln:
                cid, name = ln.split(";", 1)
            else:
                parts = ln.split()
                cid, name = parts[0], " ".join(parts[1:]) if len(parts) > 1 else ""
            try:
                cid = int(cid)
            except ValueError:
                continue
            rows.append((cid, name))
    df = pd.DataFrame(rows, columns=["class_id", "species_name"])
    # same order as in training: iterrows order
    names = list(df["species_name"])
    return names


def pool_feats(out):
    """Collapse a backbone forward_features() output to a [B, D] feature tensor."""
    feats = out
    if isinstance(out, dict):
        # prefer an already-pooled / CLS-token entry when the backbone returns a dict
        for key in ("pooled", "x_norm_clstoken", "cls_token", "x"):
            if key in out:
                feats = out[key]
                break
    if isinstance(feats, (list, tuple)):
        feats = feats[0]
    if isinstance(feats, torch.Tensor) and feats.dim() == 3:
        # [B, tokens, D]: take the first (CLS) token, else average the token dimension
        feats = feats[:, 0] if feats.size(1) > 1 else feats.mean(dim=1)
    if isinstance(feats, torch.Tensor) and feats.dim() > 2:
        feats = feats.flatten(1)
    return feats


def _unwrap_state_dict(obj):
    """Return the inner state dict from common checkpoint wrappers."""
    if isinstance(obj, dict):
        for k in ("state_dict", "model", "module", "ema", "shadow",
                  "backbone", "net", "student", "teacher"):
            if k in obj and isinstance(obj[k], dict):
                return obj[k]
    return obj


def _strip_prefixes(sd, prefixes=("module.", "backbone.", "model.", "student.")):
    """Drop common wrapper prefixes from state-dict keys."""
    out = {}
    for k, v in sd.items():
        for p in prefixes:
            if k.startswith(p):
                k = k[len(p):]
        out[k] = v
    return out


def maybe_load_plant_ckpt(model, ckpt_path: Path):
    if not ckpt_path.is_file():
        print(f"[baseline] plant ckpt not found at {ckpt_path}, using generic DINOv2 weights.")
        return
    try:
        sd = torch.load(ckpt_path, map_location="cpu")
        sd = _unwrap_state_dict(sd)
        sd = _strip_prefixes(sd)
        msd = model.state_dict()
        if "pos_embed" in sd and "pos_embed" in msd and sd["pos_embed"].shape != msd["pos_embed"].shape:
            sd["pos_embed"] = resize_pos_embed(sd["pos_embed"], msd["pos_embed"])
            print(f"[baseline] interpolated pos_embed to {tuple(msd['pos_embed'].shape)}")
        missing, unexpected = model.load_state_dict(sd, strict=False)
        print(f"[baseline] loaded plant ckpt; missing={len(missing)} unexpected={len(unexpected)}")
    except Exception as e:
        print(f"[baseline] failed to load '{ckpt_path}': {e}")


def build_backbone(size: int = 224):
    model = timm.create_model(
        "vit_base_patch14_dinov2",
        pretrained=True,   # generic DINOv2 as fallback
        num_classes=0,     # features only
        img_size=size,
        pretrained_cfg_overlay=dict(input_size=(3, size, size)),
    ).to(DEVICE)
    # relax the patch-embed size check so the backbone accepts `size` x `size` inputs
    pe = getattr(model, "patch_embed", None)
    if pe is not None:
        if hasattr(pe, "img_size"):
            pe.img_size = (size, size)
        if hasattr(pe, "strict_img_size"):
            pe.strict_img_size = False
    maybe_load_plant_ckpt(model, PLANT_CKPT_PATH)
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    return model


# ----------------------- global objects (loaded once) -----------------------
IMAGE_SIZE = 224

species_names = read_species(SPECIES_LIST_PATH)
num_classes = len(species_names)

backbone = build_backbone(IMAGE_SIZE)

transform = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE * 1.12)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

scaler = joblib.load(SCALER_PATH)
logreg = joblib.load(LOGREG_PATH)


# ----------------------- public API for Gradio -----------------------
def predict_baseline(image: Image.Image, top_k: int = 5) -> Dict[str, float]:
    """
    Run the DINOv2 + Logistic Regression baseline on a single PIL image.
    Returns {class_name: probability} for the top_k classes.
    """
    if image is None:
        return {}
    if image.mode != "RGB":
        image = image.convert("RGB")  # the normalization below assumes 3 channels
    x = transform(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = backbone.forward_features(x)
        feats = pool_feats(out).cpu().numpy()
    feats_scaled = scaler.transform(feats)
    probs = logreg.predict_proba(feats_scaled)[0]  # shape [num_classes]
    top_idx = np.argsort(-probs)[:top_k]
    result = {species_names[i]: float(probs[i]) for i in top_idx}
    return result
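

# The block below is not part of the original module: it is a minimal sketch of
# how predict_baseline() could be exposed through Gradio, assuming the `gradio`
# package is installed. The actual Space presumably wires the function up in a
# separate app file, so treat this only as an illustration of the expected call
# signature (one PIL image in, a {species_name: probability} dict out).
if __name__ == "__main__":
    import gradio as gr

    demo = gr.Interface(
        fn=predict_baseline,  # called with a single PIL image; top_k keeps its default of 5
        inputs=gr.Image(type="pil", label="Plant photo"),
        outputs=gr.Label(num_top_classes=5, label="Top-5 species"),
        title="DINOv2 + Logistic Regression baseline",
    )
    demo.launch()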