|
|
""" |
|
|
Train DavidBeans: The Dynamic Duo |
|
|
================================== |
|
|
|
|
|
    ┌───────────────────┐
    │       BEANS       │   "I see the patches..."
    │  (ViT Backbone)   │   Cantor-routed sparse attention
    └─────────┬─────────┘
              │
              ▼
    ┌───────────────────┐
    │       DAVID       │   "I know the crystals..."
    │   (Classifier)    │   Multi-scale projection
    └─────────┬─────────┘
              │
              ▼
         [Prediction]
|
|
|
|
|
Cross-contrast learning aligns patch features with crystal anchors. |
|
|
Unified Cayley-Menger loss maintains geometric structure throughout. |
|
|
|
|
|
Features: |
|
|
- HuggingFace Hub integration for model upload |
|
|
- Automatic model card generation |
|
|
- Checkpoint management |
|
|
|
|
|
Author: AbstractPhil |
|
|
Date: November 28, 2025 |
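
Typical usage (illustrative sketch; these helpers are defined in this module):

    # quick CIFAR-10 run with the small preset
    model, best_acc = train_cifar10_small(run_name="demo")

    # or configure training explicitly
    cfg = TrainingConfig(dataset="cifar10", epochs=50, batch_size=128)
    model, best_acc = train_david_beans(train_config=cfg)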
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.optim import AdamW |
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR |
|
|
from tqdm.auto import tqdm |
|
|
import time |
|
|
import math |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional, Tuple, List |
|
|
from dataclasses import dataclass, field |
|
|
import json |
|
|
import os |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
from geofractal.model.david_beans.model import DavidBeans, DavidBeansConfig |
|
|
|
|
|
|
|
|
try: |
|
|
from huggingface_hub import HfApi, create_repo, upload_folder |
|
|
HF_HUB_AVAILABLE = True |
|
|
except ImportError: |
|
|
HF_HUB_AVAILABLE = False |
|
|
print(" [!] huggingface_hub not installed. Run: pip install huggingface_hub") |
|
|
|
|
|
|
|
|
try: |
|
|
from safetensors.torch import save_file as save_safetensors |
|
|
SAFETENSORS_AVAILABLE = True |
|
|
except ImportError: |
|
|
SAFETENSORS_AVAILABLE = False |
|
|
|
|
|
|
|
|
try: |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
TENSORBOARD_AVAILABLE = True |
|
|
except ImportError: |
|
|
TENSORBOARD_AVAILABLE = False |
|
|
print(" [!] tensorboard not installed. Run: pip install tensorboard") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
class TrainingConfig: |
|
|
"""Training hyperparameters.""" |
|
|
|
|
|
|
|
|
run_name: str = "default" |
|
|
run_number: Optional[int] = None |
|
|
|
|
|
|
|
|
dataset: str = "cifar10" |
|
|
image_size: int = 32 |
|
|
batch_size: int = 128 |
|
|
num_workers: int = 4 |
|
|
|
|
|
|
|
|
epochs: int = 100 |
|
|
warmup_epochs: int = 5 |
|
|
|
|
|
|
|
|
learning_rate: float = 1e-3 |
|
|
weight_decay: float = 0.05 |
|
|
betas: Tuple[float, float] = (0.9, 0.999) |
|
|
|
|
|
|
|
|
scheduler: str = "cosine" |
|
|
min_lr: float = 1e-6 |
|
|
|
|
|
|
|
|
ce_weight: float = 1.0 |
|
|
cayley_weight: float = 0.01 |
|
|
contrast_weight: float = 0.5 |
|
|
scale_ce_weight: float = 0.1 |
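# These weights are combined in train_epoch():
#   total = ce_weight * CE + cayley_weight * geometric + contrast_weight * contrast
#           + scale_ce_weight * (sum of per-scale CE losses)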
|
|
|
|
|
|
|
|
gradient_clip: float = 1.0 |
|
|
label_smoothing: float = 0.1 |
|
|
|
|
|
|
|
|
use_augmentation: bool = True |
|
|
mixup_alpha: float = 0.2 |
|
|
cutmix_alpha: float = 1.0 |
|
|
|
|
|
|
|
|
save_interval: int = 10 |
|
|
output_dir: str = "./checkpoints" |
|
|
resume_from: Optional[str] = None |
|
|
|
|
|
|
|
|
use_tensorboard: bool = True |
|
|
log_interval: int = 50 |
|
|
|
|
|
|
|
|
push_to_hub: bool = False |
|
|
hub_repo_id: Optional[str] = None |
|
|
hub_private: bool = False |
|
|
hub_append_run: bool = True |
|
|
|
|
|
|
|
|
device: str = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
def to_dict(self) -> Dict: |
|
|
return {k: v for k, v in self.__dict__.items()} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_model_card( |
|
|
model_config: DavidBeansConfig, |
|
|
train_config: TrainingConfig, |
|
|
best_acc: float, |
|
|
training_history: Optional[Dict] = None |
|
|
) -> str: |
|
|
"""Generate a model card for HuggingFace Hub.""" |
|
|
|
|
|
scales_str = ", ".join([str(s) for s in model_config.scales]) |
|
|
|
|
|
dataset_info = { |
|
|
"cifar10": ("CIFAR-10", 10, "Image classification on 32x32 images"), |
|
|
"cifar100": ("CIFAR-100", 100, "Fine-grained image classification on 32x32 images"), |
|
|
}.get(train_config.dataset, (train_config.dataset, model_config.num_classes, "")) |
|
|
|
|
|
card_content = f"""--- |
|
|
library_name: pytorch |
|
|
license: apache-2.0 |
|
|
tags: |
|
|
- vision |
|
|
- image-classification |
|
|
- geometric-deep-learning |
|
|
- vit |
|
|
- cantor-routing |
|
|
- pentachoron |
|
|
- multi-scale |
|
|
datasets: |
|
|
- {train_config.dataset} |
|
|
metrics: |
|
|
- accuracy |
|
|
model-index: |
|
|
- name: DavidBeans |
|
|
results: |
|
|
- task: |
|
|
type: image-classification |
|
|
name: Image Classification |
|
|
dataset: |
|
|
name: {dataset_info[0]} |
|
|
type: {train_config.dataset} |
|
|
metrics: |
|
|
- type: accuracy |
|
|
value: {best_acc:.2f} |
|
|
name: Top-1 Accuracy |
|
|
--- |
|
|
|
|
|
# DavidBeans: Unified Vision-to-Crystal Architecture
|
|
|
|
|
DavidBeans combines **ViT-Beans** (Cantor-routed sparse attention) with **David** (multi-scale crystal classification) into a unified geometric deep learning architecture. |
|
|
|
|
|
## Model Description |
|
|
|
|
|
This model implements several novel techniques: |
|
|
|
|
|
- **Hybrid Cantor Routing**: Combines fractal Cantor set distances with positional proximity for sparse attention patterns |
|
|
- **Pentachoron Experts**: 5-vertex simplex structure with Cayley-Menger geometric regularization |
|
|
- **Multi-Scale Crystal Projection**: Projects features to multiple representation scales with learned fusion |
|
|
- **Cross-Contrastive Learning**: Aligns patch-level features with crystal anchors |
|
|
|
|
|
## Architecture |
|
|
|
|
|
``` |
|
|
Image [B, 3, {model_config.image_size}, {model_config.image_size}] |
|
|
        │
        ▼
┌─────────────────────────────────────────┐
│              BEANS BACKBONE             │
│  ├─ Patch Embed → [{model_config.num_patches} patches, {model_config.dim}d]
│  ├─ Hybrid Cantor Router (α={model_config.cantor_weight})
│  ├─ {model_config.num_layers} × Attention Blocks ({model_config.num_heads} heads)
│  └─ {model_config.num_layers} × Pentachoron Expert Layers
└─────────────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────────────┐
│                DAVID HEAD               │
│  ├─ Multi-scale projection: [{scales_str}]
│  ├─ Per-scale Crystal Heads
│  └─ Geometric Fusion (learned weights)
└─────────────────────────────────────────┘
        │
        ▼
|
|
[{model_config.num_classes} classes] |
|
|
``` |
|
|
|
|
|
## Training Details |
|
|
|
|
|
| Parameter | Value | |
|
|
|-----------|-------| |
|
|
| Dataset | {dataset_info[0]} | |
|
|
| Classes | {model_config.num_classes} | |
|
|
| Image Size | {model_config.image_size}×{model_config.image_size} |
|
|
| Patch Size | {model_config.patch_size}×{model_config.patch_size} |
|
|
| Embedding Dim | {model_config.dim} | |
|
|
| Layers | {model_config.num_layers} | |
|
|
| Attention Heads | {model_config.num_heads} | |
|
|
| Experts | {model_config.num_experts} (pentachoron) | |
|
|
| Sparse Neighbors | k={model_config.k_neighbors} | |
|
|
| Scales | [{scales_str}] | |
|
|
| Epochs | {train_config.epochs} | |
|
|
| Batch Size | {train_config.batch_size} | |
|
|
| Learning Rate | {train_config.learning_rate} | |
|
|
| Weight Decay | {train_config.weight_decay} | |
|
|
| Mixup α | {train_config.mixup_alpha} |
|
|
| CutMix α | {train_config.cutmix_alpha} |
|
|
| Label Smoothing | {train_config.label_smoothing} | |
|
|
|
|
|
## Results |
|
|
|
|
|
| Metric | Value | |
|
|
|--------|-------| |
|
|
| **Top-1 Accuracy** | **{best_acc:.2f}%** | |
|
|
|
|
|
## TensorBoard Logs |
|
|
|
|
|
Training logs are included in the `tensorboard/` directory. To view: |
|
|
|
|
|
```bash |
|
|
tensorboard --logdir tensorboard/ |
|
|
``` |
|
|
|
|
|
## Usage |
|
|
|
|
|
```python |
|
|
import torch |
|
|
from safetensors.torch import load_file |
|
|
from david_beans import DavidBeans, DavidBeansConfig |
|
|
|
|
|
# Load config |
|
|
config = DavidBeansConfig( |
|
|
image_size={model_config.image_size}, |
|
|
patch_size={model_config.patch_size}, |
|
|
dim={model_config.dim}, |
|
|
num_layers={model_config.num_layers}, |
|
|
num_heads={model_config.num_heads}, |
|
|
num_experts={model_config.num_experts}, |
|
|
k_neighbors={model_config.k_neighbors}, |
|
|
cantor_weight={model_config.cantor_weight}, |
|
|
scales={model_config.scales}, |
|
|
num_classes={model_config.num_classes} |
|
|
) |
|
|
|
|
|
# Create model and load weights |
|
|
model = DavidBeans(config) |
|
|
state_dict = load_file("model.safetensors") |
|
|
model.load_state_dict(state_dict) |
|
|
|
|
|
# Inference |
|
|
model.eval() |
|
|
with torch.no_grad(): |
|
|
output = model(images) |
|
|
predictions = output['logits'].argmax(dim=-1) |
|
|
``` |
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex |
|
|
@misc{{davidbeans2025, |
|
|
author = {{AbstractPhil}}, |
|
|
title = {{DavidBeans: Unified Vision-to-Crystal Architecture}}, |
|
|
year = {{2025}}, |
|
|
publisher = {{HuggingFace}}, |
|
|
url = {{https://huggingface.co/{train_config.hub_repo_id or 'AbstractPhil/david-beans'}}} |
|
|
}} |
|
|
``` |
|
|
|
|
|
## License |
|
|
|
|
|
Apache 2.0 |
|
|
""" |
|
|
|
|
|
return card_content |
|
|
|
|
|
|
|
|
def save_for_hub( |
|
|
model: DavidBeans, |
|
|
model_config: DavidBeansConfig, |
|
|
train_config: TrainingConfig, |
|
|
best_acc: float, |
|
|
output_dir: Path, |
|
|
training_history: Optional[Dict] = None |
|
|
) -> Path: |
|
|
"""Save model in HuggingFace Hub format.""" |
|
|
|
|
|
hub_dir = output_dir / "hub" |
|
|
hub_dir.mkdir(parents=True, exist_ok=True) |
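# Writes into <output_dir>/hub/: model.safetensors (when safetensors is available),
# pytorch_model.bin, config.json, training_config.json, README.md (the model card),
# training_history.json (when a history dict is given), and a copy of tensorboard/ logs.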
|
|
|
|
|
|
|
|
state_dict = {k: v.clone() for k, v in model.state_dict().items()} |
|
|
|
|
|
if SAFETENSORS_AVAILABLE: |
|
|
try: |
|
|
save_safetensors(state_dict, hub_dir / "model.safetensors") |
|
|
print(f"  ✓ Saved model.safetensors")
|
|
except Exception as e: |
|
|
print(f" [!] Safetensors failed ({e}), using pytorch format only") |
|
|
|
|
|
|
|
|
torch.save(state_dict, hub_dir / "pytorch_model.bin") |
|
|
print(f"  ✓ Saved pytorch_model.bin")
|
|
|
|
|
|
|
|
config_dict = { |
|
|
"architecture": "DavidBeans", |
|
|
"model_type": "david_beans", |
|
|
**model_config.__dict__ |
|
|
} |
|
|
with open(hub_dir / "config.json", "w") as f: |
|
|
json.dump(config_dict, f, indent=2, default=str) |
|
|
print(f"  ✓ Saved config.json")
|
|
|
|
|
|
|
|
with open(hub_dir / "training_config.json", "w") as f: |
|
|
json.dump(train_config.to_dict(), f, indent=2, default=str) |
|
|
|
|
|
|
|
|
model_card = generate_model_card(model_config, train_config, best_acc, training_history) |
|
|
with open(hub_dir / "README.md", "w") as f: |
|
|
f.write(model_card) |
|
|
print(f"  ✓ Generated README.md (model card)")
|
|
|
|
|
|
|
|
if training_history: |
|
|
with open(hub_dir / "training_history.json", "w") as f: |
|
|
json.dump(training_history, f, indent=2) |
|
|
|
|
|
|
|
|
tb_dir = output_dir / "tensorboard" |
|
|
if tb_dir.exists(): |
|
|
import shutil |
|
|
hub_tb_dir = hub_dir / "tensorboard" |
|
|
if hub_tb_dir.exists(): |
|
|
shutil.rmtree(hub_tb_dir) |
|
|
shutil.copytree(tb_dir, hub_tb_dir) |
|
|
print(f"  ✓ Copied TensorBoard logs")
|
|
|
|
|
return hub_dir |
|
|
|
|
|
|
|
|
def push_to_hub( |
|
|
hub_dir: Path, |
|
|
repo_id: str, |
|
|
private: bool = False, |
|
|
commit_message: Optional[str] = None |
|
|
) -> str: |
|
|
"""Push model to HuggingFace Hub.""" |
|
|
|
|
|
if not HF_HUB_AVAILABLE: |
|
|
raise RuntimeError("huggingface_hub not installed. Run: pip install huggingface_hub") |
|
|
|
|
|
api = HfApi() |
|
|
|
|
|
|
|
|
try: |
|
|
create_repo(repo_id, private=private, exist_ok=True) |
|
|
print(f"  ✓ Repository ready: {repo_id}")
|
|
except Exception as e: |
|
|
print(f" [!] Repo creation note: {e}") |
|
|
|
|
|
|
|
|
if commit_message is None: |
|
|
commit_message = f"Upload DavidBeans model - {datetime.now().strftime('%Y-%m-%d %H:%M')}" |
|
|
|
|
|
url = upload_folder( |
|
|
folder_path=str(hub_dir), |
|
|
repo_id=repo_id, |
|
|
commit_message=commit_message |
|
|
) |
|
|
|
|
|
print(f"  ✓ Uploaded to: https://huggingface.co/{repo_id}")
|
|
|
|
|
return url |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_dataloaders(config: TrainingConfig) -> Tuple[DataLoader, DataLoader, int]: |
|
|
"""Get train and test dataloaders.""" |
|
|
|
|
|
try: |
|
|
import torchvision |
|
|
import torchvision.transforms as T |
|
|
|
|
|
if config.dataset == "cifar10": |
|
|
if config.use_augmentation: |
|
|
train_transform = T.Compose([ |
|
|
T.RandomCrop(32, padding=4), |
|
|
T.RandomHorizontalFlip(), |
|
|
T.AutoAugment(T.AutoAugmentPolicy.CIFAR10), |
|
|
T.ToTensor(), |
|
|
T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) |
|
|
]) |
|
|
else: |
|
|
train_transform = T.Compose([ |
|
|
T.ToTensor(), |
|
|
T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) |
|
|
]) |
|
|
|
|
|
test_transform = T.Compose([ |
|
|
T.ToTensor(), |
|
|
T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)) |
|
|
]) |
|
|
|
|
|
train_dataset = torchvision.datasets.CIFAR10( |
|
|
root='./data', train=True, download=True, transform=train_transform |
|
|
) |
|
|
test_dataset = torchvision.datasets.CIFAR10( |
|
|
root='./data', train=False, download=True, transform=test_transform |
|
|
) |
|
|
num_classes = 10 |
|
|
|
|
|
elif config.dataset == "cifar100": |
|
|
if config.use_augmentation: |
|
|
train_transform = T.Compose([ |
|
|
T.RandomCrop(32, padding=4), |
|
|
T.RandomHorizontalFlip(), |
|
|
T.AutoAugment(T.AutoAugmentPolicy.CIFAR10), |
|
|
T.ToTensor(), |
|
|
T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) |
|
|
]) |
|
|
else: |
|
|
train_transform = T.Compose([ |
|
|
T.ToTensor(), |
|
|
T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) |
|
|
]) |
|
|
|
|
|
test_transform = T.Compose([ |
|
|
T.ToTensor(), |
|
|
T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) |
|
|
]) |
|
|
|
|
|
train_dataset = torchvision.datasets.CIFAR100( |
|
|
root='./data', train=True, download=True, transform=train_transform |
|
|
) |
|
|
test_dataset = torchvision.datasets.CIFAR100( |
|
|
root='./data', train=False, download=True, transform=test_transform |
|
|
) |
|
|
num_classes = 100 |
|
|
else: |
|
|
raise ValueError(f"Unknown dataset: {config.dataset}") |
|
|
|
|
|
train_loader = DataLoader( |
|
|
train_dataset, |
|
|
batch_size=config.batch_size, |
|
|
shuffle=True, |
|
|
num_workers=config.num_workers, |
|
|
pin_memory=True, |
|
|
persistent_workers=config.num_workers > 0, |
|
|
drop_last=True |
|
|
) |
|
|
test_loader = DataLoader( |
|
|
test_dataset, |
|
|
batch_size=config.batch_size, |
|
|
shuffle=False, |
|
|
num_workers=config.num_workers, |
|
|
pin_memory=True, |
|
|
persistent_workers=config.num_workers > 0 |
|
|
) |
|
|
|
|
|
return train_loader, test_loader, num_classes |
|
|
|
|
|
except ImportError: |
|
|
print(" [!] torchvision not available, using synthetic data") |
|
|
return get_synthetic_dataloaders(config) |
|
|
|
|
|
|
|
|
def get_synthetic_dataloaders(config: TrainingConfig) -> Tuple[DataLoader, DataLoader, int]: |
|
|
"""Fallback synthetic data for testing.""" |
|
|
|
|
|
class SyntheticDataset(torch.utils.data.Dataset): |
|
|
def __init__(self, size: int, image_size: int, num_classes: int): |
|
|
self.size = size |
|
|
self.image_size = image_size |
|
|
self.num_classes = num_classes |
|
|
|
|
|
def __len__(self): |
|
|
return self.size |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
x = torch.randn(3, self.image_size, self.image_size) |
|
|
y = idx % self.num_classes |
|
|
return x, y |
|
|
|
|
|
num_classes = 10 |
|
|
train_dataset = SyntheticDataset(5000, config.image_size, num_classes) |
|
|
test_dataset = SyntheticDataset(1000, config.image_size, num_classes) |
|
|
|
|
|
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True) |
|
|
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False) |
|
|
|
|
|
return train_loader, test_loader, num_classes |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2): |
|
|
"""Mixup augmentation.""" |
|
|
if alpha > 0: |
|
|
lam = torch.distributions.Beta(alpha, alpha).sample().item() |
|
|
else: |
|
|
lam = 1.0 |
|
|
|
|
|
batch_size = x.size(0) |
|
|
index = torch.randperm(batch_size, device=x.device) |
|
|
|
|
|
mixed_x = lam * x + (1 - lam) * x[index] |
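# lam is drawn from Beta(alpha, alpha) above (or fixed to 1.0 when alpha <= 0); the
# training loss later applies the same lam-weighted combination to the two label sets
# (see train_epoch).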
|
|
y_a, y_b = y, y[index] |
|
|
|
|
|
return mixed_x, y_a, y_b, lam |
|
|
|
|
|
|
|
|
def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0): |
|
|
"""CutMix augmentation.""" |
|
|
if alpha > 0: |
|
|
lam = torch.distributions.Beta(alpha, alpha).sample().item() |
|
|
else: |
|
|
lam = 1.0 |
|
|
|
|
|
batch_size = x.size(0) |
|
|
index = torch.randperm(batch_size, device=x.device) |
|
|
|
|
|
_, _, H, W = x.shape |
|
|
|
|
|
cut_ratio = math.sqrt(1 - lam) |
|
|
cut_h = int(H * cut_ratio) |
|
|
cut_w = int(W * cut_ratio) |
|
|
|
|
|
cx = torch.randint(0, H, (1,)).item() |
|
|
cy = torch.randint(0, W, (1,)).item() |
|
|
|
|
|
x1 = max(0, cx - cut_h // 2) |
|
|
x2 = min(H, cx + cut_h // 2) |
|
|
y1 = max(0, cy - cut_w // 2) |
|
|
y2 = min(W, cy + cut_w // 2) |
|
|
|
|
|
mixed_x = x.clone() |
|
|
mixed_x[:, :, x1:x2, y1:y2] = x[index, :, x1:x2, y1:y2] |
|
|
|
|
|
lam = 1 - ((x2 - x1) * (y2 - y1)) / (H * W) |
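# lam is recomputed from the exact pasted-box area so the label mix used in the loss
# matches the fraction of pixels actually replaced.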
|
|
|
|
|
y_a, y_b = y, y[index] |
|
|
|
|
|
return mixed_x, y_a, y_b, lam |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MetricsTracker: |
|
|
"""Track training metrics with EMA smoothing.""" |
|
|
|
|
|
def __init__(self, ema_decay: float = 0.9): |
|
|
self.ema_decay = ema_decay |
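# Smoothing per key: ema <- ema_decay * ema + (1 - ema_decay) * value (see update());
# per-epoch means are appended to `history` by end_epoch().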
|
|
self.metrics = {} |
|
|
self.ema_metrics = {} |
|
|
self.history = {} |
|
|
|
|
|
def update(self, **kwargs): |
|
|
for k, v in kwargs.items(): |
|
|
if isinstance(v, torch.Tensor): |
|
|
v = v.item() |
|
|
|
|
|
if k not in self.metrics: |
|
|
self.metrics[k] = [] |
|
|
self.ema_metrics[k] = v |
|
|
self.history[k] = [] |
|
|
|
|
|
self.metrics[k].append(v) |
|
|
self.ema_metrics[k] = self.ema_decay * self.ema_metrics[k] + (1 - self.ema_decay) * v |
|
|
|
|
|
def get_ema(self, key: str) -> float: |
|
|
return self.ema_metrics.get(key, 0.0) |
|
|
|
|
|
def get_epoch_mean(self, key: str) -> float: |
|
|
values = self.metrics.get(key, []) |
|
|
return sum(values) / len(values) if values else 0.0 |
|
|
|
|
|
def end_epoch(self): |
|
|
for k, v in self.metrics.items(): |
|
|
if v: |
|
|
self.history[k].append(sum(v) / len(v)) |
|
|
self.metrics = {k: [] for k in self.metrics} |
|
|
|
|
|
def get_history(self) -> Dict: |
|
|
return self.history |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def find_latest_checkpoint(output_dir: Path) -> Optional[Path]: |
|
|
"""Find the most recent checkpoint in output directory.""" |
|
|
checkpoints = list(output_dir.glob("checkpoint_epoch_*.pt")) |
|
|
|
|
|
if not checkpoints: |
|
|
|
|
|
best_model = output_dir / "best_model.pt" |
|
|
if best_model.exists(): |
|
|
return best_model |
|
|
return None |
|
|
|
|
|
|
|
|
def get_epoch(p): |
|
|
try: |
|
|
return int(p.stem.split("_")[-1]) |
|
|
except (IndexError, ValueError):
|
|
return 0 |
|
|
|
|
|
checkpoints.sort(key=get_epoch, reverse=True) |
|
|
return checkpoints[0] |
|
|
|
|
|
|
|
|
def get_next_run_number(base_dir: Path) -> int: |
|
|
"""Get the next run number by scanning existing run directories.""" |
|
|
if not base_dir.exists(): |
|
|
return 1 |
|
|
|
|
|
max_num = 0 |
|
|
for d in base_dir.iterdir(): |
|
|
if d.is_dir() and d.name.startswith("run_"): |
|
|
try: |
|
|
|
|
|
num = int(d.name.split("_")[1]) |
|
|
max_num = max(max_num, num) |
|
|
except (IndexError, ValueError): |
|
|
continue |
|
|
|
|
|
return max_num + 1 |
|
|
|
|
|
|
|
|
def generate_run_dir_name(run_number: int, run_name: str) -> str: |
|
|
"""Generate a run directory name with number, name, and timestamp.""" |
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
|
|
|
|
safe_name = "".join(c if c.isalnum() or c == "_" else "_" for c in run_name.lower()) |
|
|
safe_name = "_".join(filter(None, safe_name.split("_"))) |
|
|
return f"run_{run_number:03d}_{safe_name}_{timestamp}" |
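# Illustrative: generate_run_dir_name(3, "CIFAR-100 baseline")
#   -> "run_003_cifar_100_baseline_20251128_143000" (timestamp varies)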
|
|
|
|
|
|
|
|
def find_latest_run_dir(base_dir: Path) -> Optional[Path]: |
|
|
"""Find the most recent run directory.""" |
|
|
if not base_dir.exists(): |
|
|
return None |
|
|
|
|
|
run_dirs = [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("run_")] |
|
|
|
|
|
if not run_dirs: |
|
|
return None |
|
|
|
|
|
|
|
|
run_dirs.sort(key=lambda d: d.stat().st_mtime, reverse=True) |
|
|
return run_dirs[0] |
|
|
|
|
|
|
|
|
def find_checkpoint_in_runs(base_dir: Path, resume_from: str) -> Optional[Path]: |
|
|
""" |
|
|
Find a checkpoint to resume from. |
|
|
|
|
|
Args: |
|
|
base_dir: Base checkpoint directory (e.g., ./checkpoints/cifar100) |
|
|
resume_from: Either "latest", a run directory name, or a full path |
|
|
|
|
|
Returns: |
|
|
Path to checkpoint file, or None |
|
|
""" |
|
|
if resume_from == "latest": |
|
|
|
|
|
run_dir = find_latest_run_dir(base_dir) |
|
|
if run_dir: |
|
|
return find_latest_checkpoint(run_dir) |
|
|
|
|
|
return find_latest_checkpoint(base_dir) |
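# Otherwise resume_from may be a full path (checkpoint file or run directory), or a
# run-directory name under base_dir, e.g. "run_003_cifar100_baseline_20251128_143000"
# (illustrative name).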
|
|
|
|
|
|
|
|
full_path = Path(resume_from) |
|
|
if full_path.exists(): |
|
|
if full_path.is_file(): |
|
|
return full_path |
|
|
elif full_path.is_dir(): |
|
|
return find_latest_checkpoint(full_path) |
|
|
|
|
|
|
|
|
run_path = base_dir / resume_from |
|
|
if run_path.exists(): |
|
|
return find_latest_checkpoint(run_path) |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
def load_checkpoint( |
|
|
checkpoint_path: Path, |
|
|
model: DavidBeans, |
|
|
optimizer: Optional[torch.optim.Optimizer] = None, |
|
|
device: str = "cuda" |
|
|
) -> Tuple[int, float]: |
|
|
""" |
|
|
Load checkpoint and return (start_epoch, best_acc). |
|
|
|
|
|
Returns: |
|
|
start_epoch: Epoch to resume from (checkpoint_epoch + 1) |
|
|
best_acc: Best accuracy so far |
|
|
""" |
|
|
print(f"\nLoading checkpoint: {checkpoint_path}")
|
|
checkpoint = torch.load(checkpoint_path, map_location=device) |
|
|
|
|
|
model.load_state_dict(checkpoint['model_state_dict']) |
|
|
print(f"  ✓ Loaded model weights")
|
|
|
|
|
if optimizer is not None and 'optimizer_state_dict' in checkpoint: |
|
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) |
|
|
print(f"  ✓ Loaded optimizer state")
|
|
|
|
|
epoch = checkpoint.get('epoch', 0) |
|
|
best_acc = checkpoint.get('best_acc', 0.0) |
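# Note: scheduler state is not stored in checkpoints; for the cosine schedule the
# caller fast-forwards it by stepping once per completed epoch (see train_david_beans).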
|
|
|
|
|
print(f"  ✓ Loaded checkpoint from epoch {epoch + 1}, best_acc={best_acc:.2f}%")
|
|
print(f"  ✓ Will resume training from epoch {epoch + 2}")
|
|
|
|
|
return epoch + 1, best_acc |
|
|
|
|
|
|
|
|
def get_config_from_checkpoint(checkpoint_path: Path) -> Tuple[DavidBeansConfig, dict]: |
|
|
""" |
|
|
Extract model and training configs from a checkpoint. |
|
|
|
|
|
Returns: |
|
|
(model_config, train_config_dict) |
|
|
""" |
|
|
checkpoint = torch.load(checkpoint_path, map_location='cpu') |
|
|
|
|
|
model_config_dict = checkpoint.get('model_config', {}) |
|
|
train_config_dict = checkpoint.get('train_config', {}) |
|
|
|
|
|
|
|
|
if 'betas' in train_config_dict and isinstance(train_config_dict['betas'], list): |
|
|
train_config_dict['betas'] = tuple(train_config_dict['betas']) |
|
|
|
|
|
model_config = DavidBeansConfig(**model_config_dict) |
|
|
|
|
|
return model_config, train_config_dict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_epoch( |
|
|
model: DavidBeans, |
|
|
train_loader: DataLoader, |
|
|
optimizer: torch.optim.Optimizer, |
|
|
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], |
|
|
config: TrainingConfig, |
|
|
epoch: int, |
|
|
tracker: MetricsTracker, |
|
|
writer: Optional['SummaryWriter'] = None |
|
|
) -> Dict[str, float]: |
|
|
"""Train for one epoch.""" |
|
|
|
|
|
model.train() |
|
|
device = config.device |
|
|
|
|
|
total_loss = 0.0 |
|
|
total_correct = 0 |
|
|
total_samples = 0 |
|
|
global_step = epoch * len(train_loader) |
|
|
|
|
|
pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True) |
|
|
|
|
|
for batch_idx, (images, targets) in enumerate(pbar): |
|
|
images = images.to(device, non_blocking=True) |
|
|
targets = targets.to(device, non_blocking=True) |
|
|
|
|
|
|
|
|
use_mixup = config.use_augmentation and config.mixup_alpha > 0 |
|
|
use_cutmix = config.use_augmentation and config.cutmix_alpha > 0 |
|
|
|
|
|
mixed = False |
|
|
if use_mixup or use_cutmix: |
|
|
r = torch.rand(1).item() |
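# Roughly half the batches are left unmixed; the rest use mixup or cutmix
# (cutmix takes mixup's share when mixup is disabled).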
|
|
if r < 0.5: |
|
|
pass |
|
|
elif r < 0.75 and use_mixup: |
|
|
images, targets_a, targets_b, lam = mixup_data(images, targets, config.mixup_alpha) |
|
|
mixed = True |
|
|
elif use_cutmix: |
|
|
images, targets_a, targets_b, lam = cutmix_data(images, targets, config.cutmix_alpha) |
|
|
mixed = True |
|
|
|
|
|
|
|
|
result = model(images, targets=targets, return_loss=True) |
|
|
losses = result['losses'] |
|
|
|
|
|
if mixed: |
|
|
logits = result['logits'] |
|
|
ce_loss = lam * F.cross_entropy(logits, targets_a, label_smoothing=config.label_smoothing) + \ |
|
|
(1 - lam) * F.cross_entropy(logits, targets_b, label_smoothing=config.label_smoothing) |
|
|
losses['ce'] = ce_loss |
|
|
|
|
|
|
|
|
loss = ( |
|
|
config.ce_weight * losses['ce'] + |
|
|
config.cayley_weight * losses.get('geometric', torch.tensor(0.0, device=device)) + |
|
|
config.contrast_weight * losses.get('contrast', torch.tensor(0.0, device=device)) |
|
|
) |
|
|
|
|
|
for scale in model.config.scales: |
|
|
scale_ce = losses.get(f'ce_{scale}', 0.0) |
|
|
if isinstance(scale_ce, torch.Tensor): |
|
|
loss = loss + config.scale_ce_weight * scale_ce |
|
|
|
|
|
|
|
|
optimizer.zero_grad() |
|
|
loss.backward() |
|
|
|
|
|
if config.gradient_clip > 0: |
|
|
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip) |
|
|
else: |
|
|
grad_norm = 0.0 |
|
|
|
|
|
optimizer.step() |
|
|
|
|
|
if scheduler is not None and config.scheduler == "onecycle": |
|
|
scheduler.step() |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
logits = result['logits'] |
|
|
preds = logits.argmax(dim=-1) |
|
|
|
|
|
if mixed: |
|
|
correct = (lam * (preds == targets_a).float() + |
|
|
(1 - lam) * (preds == targets_b).float()).sum() |
|
|
else: |
|
|
correct = (preds == targets).sum() |
|
|
|
|
|
total_correct += correct.item() |
|
|
total_samples += targets.size(0) |
|
|
total_loss += loss.item() |
|
|
|
|
|
|
|
|
def to_float(v): |
|
|
return v.item() if isinstance(v, torch.Tensor) else float(v) |
|
|
|
|
|
geo_loss = to_float(losses.get('geometric', 0.0)) |
|
|
contrast_loss = to_float(losses.get('contrast', 0.0)) |
|
|
expert_vol = to_float(losses.get('expert_volume', 0.0)) |
|
|
expert_collapse = to_float(losses.get('expert_collapse', 0.0)) |
|
|
expert_edge = to_float(losses.get('expert_edge_dev', 0.0)) |
|
|
current_lr = optimizer.param_groups[0]['lr'] |
|
|
|
|
|
tracker.update( |
|
|
loss=loss.item(), |
|
|
ce=losses['ce'].item(), |
|
|
geo=geo_loss, |
|
|
contrast=contrast_loss, |
|
|
expert_vol=expert_vol, |
|
|
expert_collapse=expert_collapse, |
|
|
expert_edge=expert_edge, |
|
|
lr=current_lr |
|
|
) |
|
|
|
|
|
|
|
|
if writer is not None and (batch_idx + 1) % config.log_interval == 0: |
|
|
step = global_step + batch_idx |
|
|
|
|
|
|
|
|
writer.add_scalar('train/loss_total', loss.item(), step) |
|
|
writer.add_scalar('train/loss_ce', losses['ce'].item(), step) |
|
|
writer.add_scalar('train/loss_geometric', geo_loss, step) |
|
|
writer.add_scalar('train/loss_contrast', contrast_loss, step) |
|
|
|
|
|
|
|
|
writer.add_scalar('train/expert_volume', expert_vol, step) |
|
|
writer.add_scalar('train/expert_collapse', expert_collapse, step) |
|
|
writer.add_scalar('train/expert_edge_dev', expert_edge, step) |
|
|
|
|
|
|
|
|
writer.add_scalar('train/learning_rate', current_lr, step) |
|
|
writer.add_scalar('train/grad_norm', to_float(grad_norm), step) |
|
|
writer.add_scalar('train/batch_acc', 100.0 * correct.item() / targets.size(0), step) |
|
|
|
|
|
pbar.set_postfix({ |
|
|
'loss': f"{tracker.get_ema('loss'):.3f}", |
|
|
'acc': f"{100.0 * total_correct / total_samples:.1f}%", |
|
|
'geo': f"{tracker.get_ema('geo'):.4f}", |
|
|
'vol': f"{tracker.get_ema('expert_vol'):.4f}" |
|
|
}) |
|
|
|
|
|
if scheduler is not None and config.scheduler == "cosine": |
|
|
scheduler.step() |
|
|
|
|
|
return { |
|
|
'loss': total_loss / len(train_loader), |
|
|
'acc': 100.0 * total_correct / total_samples |
|
|
} |
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def evaluate( |
|
|
model: DavidBeans, |
|
|
test_loader: DataLoader, |
|
|
config: TrainingConfig |
|
|
) -> Dict[str, float]: |
|
|
"""Evaluate on test set.""" |
|
|
|
|
|
model.eval() |
|
|
device = config.device |
|
|
|
|
|
total_loss = 0.0 |
|
|
total_correct = 0 |
|
|
total_samples = 0 |
|
|
scale_correct = {s: 0 for s in model.config.scales} |
|
|
|
|
|
for images, targets in test_loader: |
|
|
images = images.to(device, non_blocking=True) |
|
|
targets = targets.to(device, non_blocking=True) |
|
|
|
|
|
result = model(images, targets=targets, return_loss=True) |
|
|
|
|
|
logits = result['logits'] |
|
|
losses = result['losses'] |
|
|
|
|
|
loss = losses['total'] |
|
|
preds = logits.argmax(dim=-1) |
|
|
|
|
|
total_loss += loss.item() * targets.size(0) |
|
|
total_correct += (preds == targets).sum().item() |
|
|
total_samples += targets.size(0) |
|
|
|
|
|
for i, scale in enumerate(model.config.scales): |
|
|
scale_logits = result['scale_logits'][i] |
|
|
scale_preds = scale_logits.argmax(dim=-1) |
|
|
scale_correct[scale] += (scale_preds == targets).sum().item() |
|
|
|
|
|
metrics = { |
|
|
'loss': total_loss / total_samples, |
|
|
'acc': 100.0 * total_correct / total_samples |
|
|
} |
|
|
|
|
|
for scale, correct in scale_correct.items(): |
|
|
metrics[f'acc_{scale}'] = 100.0 * correct / total_samples |
|
|
|
|
|
return metrics |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_david_beans( |
|
|
model_config: Optional[DavidBeansConfig] = None, |
|
|
train_config: Optional[TrainingConfig] = None |
|
|
): |
|
|
"""Main training function.""" |
|
|
|
|
|
print("=" * 70) |
|
|
print(" DAVID-BEANS TRAINING: The Dynamic Duo") |
|
|
print("=" * 70) |
|
|
print() |
|
|
print("  BEANS (ViT)  +  DAVID (Crystal)")
|
|
print(" Sparse Attention Multi-Scale Projection") |
|
|
print() |
|
|
print("=" * 70) |
|
|
|
|
|
if train_config is None: |
|
|
train_config = TrainingConfig() |
|
|
|
|
|
base_output_dir = Path(train_config.output_dir) |
|
|
base_output_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
checkpoint_path = None |
|
|
run_dir = None |
|
|
|
|
|
if train_config.resume_from: |
|
|
|
|
|
checkpoint_path = find_checkpoint_in_runs(base_output_dir, train_config.resume_from) |
|
|
|
|
|
if checkpoint_path and checkpoint_path.exists(): |
|
|
print(f"\nFound checkpoint: {checkpoint_path}")
|
|
|
|
|
run_dir = checkpoint_path.parent |
|
|
print(f"  ✓ Resuming in run directory: {run_dir.name}")
|
|
|
|
|
|
|
|
loaded_model_config, loaded_train_config_dict = get_config_from_checkpoint(checkpoint_path) |
|
|
|
|
|
if model_config is None: |
|
|
model_config = loaded_model_config |
|
|
print(f"  ✓ Using model config from checkpoint")
|
|
else: |
|
|
|
|
|
if model_config.dim != loaded_model_config.dim or model_config.scales != loaded_model_config.scales: |
|
|
print(f"  [!] WARNING: Provided config differs from checkpoint!")
|
|
print(f" Checkpoint: dim={loaded_model_config.dim}, scales={loaded_model_config.scales}") |
|
|
print(f" Provided: dim={model_config.dim}, scales={model_config.scales}") |
|
|
print(f"  ✓ Using checkpoint config to ensure compatibility")
|
|
model_config = loaded_model_config |
|
|
else: |
|
|
print(f" [!] Checkpoint not found: {train_config.resume_from}") |
|
|
checkpoint_path = None |
|
|
|
|
|
|
|
|
if run_dir is None: |
|
|
|
|
|
if train_config.run_number is None: |
|
|
run_number = get_next_run_number(base_output_dir) |
|
|
else: |
|
|
run_number = train_config.run_number |
|
|
|
|
|
|
|
|
run_dir_name = generate_run_dir_name(run_number, train_config.run_name) |
|
|
run_dir = base_output_dir / run_dir_name |
|
|
run_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
print(f"\nNew run: {run_dir_name}")
|
|
print(f" Run #{run_number}: {train_config.run_name}") |
|
|
else: |
|
|
|
|
|
try: |
|
|
run_number = int(run_dir.name.split("_")[1]) |
|
|
except (IndexError, ValueError): |
|
|
run_number = 1 |
|
|
|
|
|
|
|
|
output_dir = run_dir |
|
|
|
|
|
|
|
|
effective_hub_repo_id = train_config.hub_repo_id |
|
|
if train_config.hub_repo_id and train_config.hub_append_run: |
|
|
|
|
|
parts = run_dir.name.split("_") |
|
|
if len(parts) >= 3: |
|
|
run_name_part = parts[2] |
|
|
else: |
|
|
run_name_part = train_config.run_name |
|
|
effective_hub_repo_id = f"{train_config.hub_repo_id}-run{run_number:03d}-{run_name_part}" |
|
|
print(f" Hub repo: {effective_hub_repo_id}") |
|
|
|
|
|
if model_config is None: |
|
|
model_config = DavidBeansConfig( |
|
|
image_size=train_config.image_size, |
|
|
patch_size=4, |
|
|
dim=256, |
|
|
num_layers=6, |
|
|
num_heads=8, |
|
|
num_experts=5, |
|
|
k_neighbors=16, |
|
|
cantor_weight=0.3, |
|
|
scales=[64, 128, 256], |
|
|
num_classes=10, |
|
|
contrast_weight=train_config.contrast_weight, |
|
|
cayley_weight=train_config.cayley_weight, |
|
|
dropout=0.1 |
|
|
) |
|
|
|
|
|
device = train_config.device |
|
|
print(f"\nDevice: {device}") |
|
|
|
|
|
|
|
|
print("\nLoading data...") |
|
|
train_loader, test_loader, num_classes = get_dataloaders(train_config) |
|
|
print(f" Dataset: {train_config.dataset}") |
|
|
print(f" Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}") |
|
|
print(f" Classes: {num_classes}") |
|
|
|
|
|
model_config.num_classes = num_classes |
|
|
|
|
|
|
|
|
print("\nBuilding model...") |
|
|
model = DavidBeans(model_config) |
|
|
model = model.to(device) |
|
|
|
|
|
print(f"\n{model}") |
|
|
|
|
|
num_params = sum(p.numel() for p in model.parameters()) |
|
|
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
print(f"\nParameters: {num_params:,} ({num_trainable:,} trainable)") |
|
|
|
|
|
|
|
|
print("\nSetting up optimizer...") |
|
|
|
|
|
decay_params = [] |
|
|
no_decay_params = [] |
|
|
|
|
|
for name, param in model.named_parameters(): |
|
|
if not param.requires_grad: |
|
|
continue |
|
|
if 'bias' in name or 'norm' in name or 'embedding' in name: |
|
|
no_decay_params.append(param) |
|
|
else: |
|
|
decay_params.append(param) |
|
|
|
|
|
optimizer = AdamW([ |
|
|
{'params': decay_params, 'weight_decay': train_config.weight_decay}, |
|
|
{'params': no_decay_params, 'weight_decay': 0.0} |
|
|
], lr=train_config.learning_rate, betas=train_config.betas) |
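# Biases, normalization parameters, and embeddings are excluded from weight decay,
# following common ViT training practice.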
|
|
|
|
|
if train_config.scheduler == "cosine": |
|
|
scheduler = CosineAnnealingLR( |
|
|
optimizer, |
|
|
T_max=train_config.epochs - train_config.warmup_epochs, |
|
|
eta_min=train_config.min_lr |
|
|
) |
|
|
elif train_config.scheduler == "onecycle": |
|
|
scheduler = OneCycleLR( |
|
|
optimizer, |
|
|
max_lr=train_config.learning_rate, |
|
|
epochs=train_config.epochs, |
|
|
steps_per_epoch=len(train_loader), |
|
|
pct_start=train_config.warmup_epochs / train_config.epochs |
|
|
) |
|
|
else: |
|
|
scheduler = None |
|
|
|
|
|
print(f" Optimizer: AdamW (lr={train_config.learning_rate}, wd={train_config.weight_decay})") |
|
|
print(f" Scheduler: {train_config.scheduler}") |
|
|
print(f" TensorBoard: {output_dir / 'tensorboard'}") |
|
|
|
|
|
tracker = MetricsTracker() |
|
|
best_acc = 0.0 |
|
|
start_epoch = 0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
if checkpoint_path and checkpoint_path.exists(): |
|
|
start_epoch, best_acc = load_checkpoint( |
|
|
checkpoint_path, model, optimizer, device |
|
|
) |
|
|
|
|
|
|
|
|
if scheduler is not None and train_config.scheduler == "cosine": |
|
|
for _ in range(start_epoch): |
|
|
scheduler.step() |
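# Replaying one step per completed epoch restores the cosine LR to where training
# left off (OneCycleLR steps per batch and is not fast-forwarded here).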
|
|
|
|
|
|
|
|
writer = None |
|
|
if train_config.use_tensorboard and TENSORBOARD_AVAILABLE: |
|
|
tb_dir = output_dir / "tensorboard" |
|
|
tb_dir.mkdir(parents=True, exist_ok=True) |
|
|
writer = SummaryWriter(log_dir=str(tb_dir)) |
|
|
print(f" TensorBoard: {tb_dir}") |
|
|
|
|
|
|
|
|
config_text = json.dumps(model_config.__dict__, indent=2, default=str) |
|
|
writer.add_text("config/model", f"```json\n{config_text}\n```", 0) |
|
|
|
|
|
train_text = json.dumps(train_config.to_dict(), indent=2, default=str) |
|
|
writer.add_text("config/training", f"```json\n{train_text}\n```", 0) |
|
|
elif train_config.use_tensorboard: |
|
|
print(" [!] TensorBoard requested but not available") |
|
|
|
|
|
with open(output_dir / "model_config.json", "w") as f: |
|
|
json.dump(model_config.__dict__, f, indent=2, default=str) |
|
|
with open(output_dir / "train_config.json", "w") as f: |
|
|
json.dump(train_config.to_dict(), f, indent=2, default=str) |
|
|
|
|
|
print(f"\nOutput directory: {output_dir}") |
|
|
|
|
|
|
|
|
print("\n" + "=" * 70) |
|
|
print(" TRAINING") |
|
|
print("=" * 70) |
|
|
|
|
|
if start_epoch > 0: |
|
|
print(f" Resuming from epoch {start_epoch + 1}/{train_config.epochs}") |
|
|
|
|
|
for epoch in range(start_epoch, train_config.epochs): |
|
|
epoch_start = time.time() |
|
|
|
|
|
if epoch < train_config.warmup_epochs and train_config.scheduler == "cosine": |
|
|
warmup_lr = train_config.learning_rate * (epoch + 1) / train_config.warmup_epochs |
|
|
for param_group in optimizer.param_groups: |
|
|
param_group['lr'] = warmup_lr |
|
|
|
|
|
train_metrics = train_epoch( |
|
|
model, train_loader, optimizer, scheduler, |
|
|
train_config, epoch, tracker, writer |
|
|
) |
|
|
|
|
|
test_metrics = evaluate(model, test_loader, train_config) |
|
|
|
|
|
epoch_time = time.time() - epoch_start |
|
|
|
|
|
|
|
|
if writer is not None: |
|
|
|
|
|
writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch) |
|
|
writer.add_scalar('epoch/train_acc', train_metrics['acc'], epoch) |
|
|
writer.add_scalar('epoch/test_loss', test_metrics['loss'], epoch) |
|
|
writer.add_scalar('epoch/test_acc', test_metrics['acc'], epoch) |
|
|
writer.add_scalar('epoch/learning_rate', optimizer.param_groups[0]['lr'], epoch) |
|
|
writer.add_scalar('epoch/time_seconds', epoch_time, epoch) |
|
|
|
|
|
|
|
|
for scale in model.config.scales: |
|
|
writer.add_scalar(f'scales/acc_{scale}', test_metrics[f'acc_{scale}'], epoch) |
|
|
|
|
|
|
|
|
writer.add_scalar('epoch/generalization_gap', test_metrics['acc'] - train_metrics['acc'], epoch) |
|
|
|
|
|
|
|
|
if (epoch + 1) % 5 == 0: |
|
|
writer.flush() |
|
|
|
|
|
scale_accs = " | ".join([f"{s}:{test_metrics[f'acc_{s}']:.1f}%" for s in model.config.scales]) |
|
|
star = "⭐" if test_metrics['acc'] > best_acc else ""
|
|
|
|
|
print(f"  → Train: {train_metrics['acc']:.1f}% | Test: {test_metrics['acc']:.1f}% | "
|
|
f"Scales: [{scale_accs}] | {epoch_time:.0f}s {star}") |
|
|
|
|
|
if test_metrics['acc'] > best_acc: |
|
|
best_acc = test_metrics['acc'] |
|
|
|
|
|
torch.save({ |
|
|
'epoch': epoch, |
|
|
'model_state_dict': model.state_dict(), |
|
|
'optimizer_state_dict': optimizer.state_dict(), |
|
|
'best_acc': best_acc, |
|
|
'model_config': model_config.__dict__, |
|
|
'train_config': train_config.to_dict() |
|
|
}, output_dir / "best_model.pt") |
|
|
|
|
|
if (epoch + 1) % train_config.save_interval == 0: |
|
|
torch.save({ |
|
|
'epoch': epoch, |
|
|
'model_state_dict': model.state_dict(), |
|
|
'optimizer_state_dict': optimizer.state_dict(), |
|
|
'best_acc': best_acc |
|
|
}, output_dir / f"checkpoint_epoch_{epoch + 1}.pt") |
|
|
|
|
|
|
|
|
if train_config.push_to_hub and HF_HUB_AVAILABLE and effective_hub_repo_id: |
|
|
try: |
|
|
|
|
|
checkpoint = torch.load(output_dir / "best_model.pt", map_location='cpu') |
|
|
model_cpu = DavidBeans(model_config) |
|
|
model_cpu.load_state_dict(checkpoint['model_state_dict']) |
|
|
|
|
|
hub_dir = save_for_hub( |
|
|
model=model_cpu, |
|
|
model_config=model_config, |
|
|
train_config=train_config, |
|
|
best_acc=best_acc, |
|
|
output_dir=output_dir, |
|
|
training_history=tracker.get_history() |
|
|
) |
|
|
|
|
|
push_to_hub( |
|
|
hub_dir=hub_dir, |
|
|
repo_id=effective_hub_repo_id, |
|
|
private=train_config.hub_private, |
|
|
commit_message=f"Checkpoint epoch {epoch + 1} - {best_acc:.2f}% acc" |
|
|
) |
|
|
print(f"  Uploaded to {effective_hub_repo_id}")
|
|
except Exception as e: |
|
|
print(f" [!] Hub upload failed: {e}") |
|
|
|
|
|
tracker.end_epoch() |
|
|
|
|
|
|
|
|
print("\n" + "=" * 70) |
|
|
print(" TRAINING COMPLETE") |
|
|
print("=" * 70) |
|
|
print(f"\n Best Test Accuracy: {best_acc:.2f}%") |
|
|
print(f" Model saved to: {output_dir / 'best_model.pt'}") |
|
|
|
|
|
|
|
|
history = tracker.get_history() |
|
|
with open(output_dir / "training_history.json", "w") as f: |
|
|
json.dump(history, f, indent=2) |
|
|
|
|
|
|
|
|
if writer is not None: |
|
|
|
|
|
hparams = { |
|
|
'dim': model_config.dim, |
|
|
'num_layers': model_config.num_layers, |
|
|
'num_heads': model_config.num_heads, |
|
|
'num_experts': model_config.num_experts, |
|
|
'k_neighbors': model_config.k_neighbors, |
|
|
'cantor_weight': model_config.cantor_weight, |
|
|
'learning_rate': train_config.learning_rate, |
|
|
'weight_decay': train_config.weight_decay, |
|
|
'batch_size': train_config.batch_size, |
|
|
'mixup_alpha': train_config.mixup_alpha, |
|
|
'cutmix_alpha': train_config.cutmix_alpha, |
|
|
} |
|
|
writer.add_hparams(hparams, {'hparam/best_acc': best_acc}) |
|
|
writer.add_scalar('final/best_acc', best_acc, 0) |
|
|
writer.close() |
|
|
print(f" TensorBoard logs: {output_dir / 'tensorboard'}") |
|
|
|
|
|
|
|
|
if train_config.push_to_hub: |
|
|
print("\n" + "=" * 70) |
|
|
print(" UPLOADING TO HUGGINGFACE HUB") |
|
|
print("=" * 70) |
|
|
|
|
|
if not HF_HUB_AVAILABLE: |
|
|
print(" [!] huggingface_hub not installed. Skipping upload.") |
|
|
elif not effective_hub_repo_id: |
|
|
print(" [!] hub_repo_id not set. Skipping upload.") |
|
|
else: |
|
|
checkpoint = torch.load(output_dir / "best_model.pt", map_location='cpu') |
|
|
model.load_state_dict(checkpoint['model_state_dict']) |
|
|
|
|
|
print(f"\n Preparing model for upload...") |
|
|
hub_dir = save_for_hub( |
|
|
model=model, |
|
|
model_config=model_config, |
|
|
train_config=train_config, |
|
|
best_acc=best_acc, |
|
|
output_dir=output_dir, |
|
|
training_history=history |
|
|
) |
|
|
|
|
|
print(f"\n Uploading to {effective_hub_repo_id}...") |
|
|
push_to_hub( |
|
|
hub_dir=hub_dir, |
|
|
repo_id=effective_hub_repo_id, |
|
|
private=train_config.hub_private |
|
|
) |
|
|
|
|
|
print(f"\n Model uploaded to: https://huggingface.co/{effective_hub_repo_id}")
|
|
|
|
|
return model, best_acc |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_cifar10_small(run_name: str = "cifar10_small"): |
|
|
"""Small model for CIFAR-10.""" |
|
|
model_config = DavidBeansConfig( |
|
|
image_size=32, patch_size=4, dim=256, num_layers=4, |
|
|
num_heads=4, num_experts=5, k_neighbors=16, |
|
|
cantor_weight=0.3, scales=[64, 128, 256, 512], |
|
|
num_classes=10, dropout=0.1 |
|
|
) |
|
|
|
|
|
train_config = TrainingConfig( |
|
|
run_name=run_name, |
|
|
dataset="cifar10", epochs=50, batch_size=128, |
|
|
learning_rate=1e-3, weight_decay=0.05, warmup_epochs=5, |
|
|
cayley_weight=0.01, contrast_weight=0.3, |
|
|
output_dir="./checkpoints/cifar10" |
|
|
) |
|
|
|
|
|
return train_david_beans(model_config, train_config) |
|
|
|
|
|
|
|
|
def train_cifar100( |
|
|
run_name: str = "cifar100_base", |
|
|
push_to_hub: bool = False, |
|
|
hub_repo_id: Optional[str] = None, |
|
|
resume: bool = False |
|
|
): |
|
|
"""Model for CIFAR-100 with optional HF Hub upload and resume.""" |
|
|
model_config = DavidBeansConfig( |
|
|
image_size=32, patch_size=4, dim=512, num_layers=8, |
|
|
num_heads=8, num_experts=5, k_neighbors=32, |
|
|
cantor_weight=0.3, scales=[256, 512, 768], |
|
|
num_classes=100, dropout=0.15 |
|
|
) |
|
|
|
|
|
train_config = TrainingConfig( |
|
|
run_name=run_name, |
|
|
dataset="cifar100", epochs=200, batch_size=128, |
|
|
learning_rate=5e-4, weight_decay=0.1, warmup_epochs=20, |
|
|
cayley_weight=0.01, contrast_weight=0.5, |
|
|
label_smoothing=0.1, mixup_alpha=0.3, cutmix_alpha=1.0, |
|
|
output_dir="./checkpoints/cifar100", |
|
|
resume_from="latest" if resume else None, |
|
|
push_to_hub=push_to_hub, hub_repo_id=hub_repo_id, hub_private=False |
|
|
) |
|
|
|
|
|
return train_david_beans(model_config, train_config) |
|
|
|
|
|
|
|
|
def resume_training( |
|
|
checkpoint_dir: str = "./checkpoints/cifar100", |
|
|
push_to_hub: bool = False, |
|
|
hub_repo_id: Optional[str] = None |
|
|
): |
|
|
""" |
|
|
Resume training from the latest checkpoint in a directory. |
|
|
|
|
|
Usage: |
|
|
resume_training("./checkpoints/cifar100", push_to_hub=True, hub_repo_id="user/repo") |
|
|
""" |
|
|
output_dir = Path(checkpoint_dir) |
|
|
|
|
|
|
|
|
model_config_path = output_dir / "model_config.json" |
|
|
train_config_path = output_dir / "train_config.json" |
|
|
|
|
|
if not model_config_path.exists(): |
|
|
raise FileNotFoundError(f"No model_config.json in {output_dir}") |
|
|
|
|
|
with open(model_config_path) as f: |
|
|
model_config_dict = json.load(f) |
|
|
|
|
|
with open(train_config_path) as f: |
|
|
train_config_dict = json.load(f) |
|
|
|
|
|
|
|
|
if 'betas' in train_config_dict and isinstance(train_config_dict['betas'], list): |
|
|
train_config_dict['betas'] = tuple(train_config_dict['betas']) |
|
|
|
|
|
model_config = DavidBeansConfig(**model_config_dict) |
|
|
train_config = TrainingConfig(**train_config_dict) |
|
|
|
|
|
|
|
|
train_config.resume_from = "latest" |
|
|
train_config.push_to_hub = push_to_hub |
|
|
if hub_repo_id: |
|
|
train_config.hub_repo_id = hub_repo_id |
|
|
|
|
|
return train_david_beans(model_config, train_config) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def upload_checkpoint( |
|
|
checkpoint_path: str, |
|
|
repo_id: str, |
|
|
best_acc: Optional[float] = None, |
|
|
private: bool = False |
|
|
): |
|
|
""" |
|
|
Upload an existing checkpoint to HuggingFace Hub. |
|
|
|
|
|
Usage: |
|
|
upload_checkpoint( |
|
|
checkpoint_path="./checkpoints/cifar100/best_model.pt", |
|
|
repo_id="AbstractPhil/david-beans-cifar100", |
|
|
best_acc=70.0 # Optional, will read from checkpoint if available |
|
|
) |
|
|
""" |
|
|
if not HF_HUB_AVAILABLE: |
|
|
raise RuntimeError("huggingface_hub not installed. Run: pip install huggingface_hub") |
|
|
|
|
|
print(f"\nLoading checkpoint: {checkpoint_path}")
|
|
checkpoint = torch.load(checkpoint_path, map_location='cpu') |
|
|
|
|
|
|
|
|
model_config_dict = checkpoint.get('model_config', {}) |
|
|
train_config_dict = checkpoint.get('train_config', {}) |
|
|
|
|
|
model_config = DavidBeansConfig(**model_config_dict) |
|
|
train_config = TrainingConfig(**train_config_dict) |
|
|
train_config.hub_repo_id = repo_id |
|
|
|
|
|
|
|
|
model = DavidBeans(model_config) |
|
|
model.load_state_dict(checkpoint['model_state_dict']) |
|
|
|
|
|
actual_best_acc = best_acc or checkpoint.get('best_acc', 0.0) |
|
|
|
|
|
|
|
|
output_dir = Path(checkpoint_path).parent |
|
|
|
|
|
print(f"\nPreparing files for upload...")
|
|
hub_dir = save_for_hub( |
|
|
model=model, |
|
|
model_config=model_config, |
|
|
train_config=train_config, |
|
|
best_acc=actual_best_acc, |
|
|
output_dir=output_dir |
|
|
) |
|
|
|
|
|
print(f"\nUploading to {repo_id}...")
|
|
push_to_hub(hub_dir, repo_id, private=private) |
|
|
|
|
|
print(f"\nDone! https://huggingface.co/{repo_id}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PRESET = "cifar100" |
|
|
RESUME = False |
|
|
RUN_NAME = "5expert_3scale" |
|
|
|
|
|
|
|
|
PUSH_TO_HUB = False |
|
|
HUB_REPO_ID = "AbstractPhil/geovit-david-beans" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if PRESET == "test": |
|
|
print("Quick test...")
|
|
model_config = DavidBeansConfig( |
|
|
image_size=32, patch_size=4, dim=128, num_layers=2, |
|
|
num_heads=4, num_experts=5, k_neighbors=8, |
|
|
scales=[32, 64, 128], num_classes=10 |
|
|
) |
|
|
train_config = TrainingConfig( |
|
|
run_name="test", |
|
|
epochs=2, batch_size=32, |
|
|
use_augmentation=False, mixup_alpha=0.0, cutmix_alpha=0.0 |
|
|
) |
|
|
model, acc = train_david_beans(model_config, train_config) |
|
|
|
|
|
elif PRESET == "small": |
|
|
print("Training DavidBeans - Small (CIFAR-10)...")
|
|
model, acc = train_cifar10_small() |
|
|
|
|
|
elif PRESET == "cifar100": |
|
|
print("Training DavidBeans - CIFAR-100...")
|
|
model, acc = train_cifar100( |
|
|
run_name=RUN_NAME, |
|
|
push_to_hub=PUSH_TO_HUB, |
|
|
hub_repo_id=HUB_REPO_ID, |
|
|
resume=RESUME |
|
|
) |
|
|
|
|
|
elif PRESET == "resume": |
|
|
print("Resuming training from latest checkpoint...")
|
|
model, acc = resume_training( |
|
|
checkpoint_dir="./checkpoints/cifar100", |
|
|
push_to_hub=PUSH_TO_HUB, |
|
|
hub_repo_id=HUB_REPO_ID |
|
|
) |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unknown preset: {PRESET}") |
|
|
|
|
|
print(f"\nDone! Best accuracy: {acc:.2f}%")