# geovit-david-beans/trainer_v1.py
"""
Train DavidBeans: The Dynamic Duo
==================================
┌─────────────────┐
│      BEANS      │   "I see the patches..."
│  (ViT Backbone) │
│  🫘 → 🫘 → 🫘   │   Cantor-routed sparse attention
└────────┬────────┘
         │
         ▼
┌─────────────────┐
│      DAVID      │   "I know the crystals..."
│   (Classifier)  │
│  💎 → 💎 → 💎   │   Multi-scale projection
└────────┬────────┘
         │
         ▼
   [Prediction]
Cross-contrast learning aligns patch features with crystal anchors.
Unified Cayley-Menger loss maintains geometric structure throughout.
Features:
- HuggingFace Hub integration for model upload
- Automatic model card generation
- Checkpoint management
Author: AbstractPhil
Date: November 28, 2025
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from tqdm.auto import tqdm
import time
import math
from pathlib import Path
from typing import Dict, Optional, Tuple, List
from dataclasses import dataclass, field
import json
import os
from datetime import datetime
# Import the model
from geofractal.model.david_beans.model import DavidBeans, DavidBeansConfig
# HuggingFace Hub integration
try:
from huggingface_hub import HfApi, create_repo, upload_folder
HF_HUB_AVAILABLE = True
except ImportError:
HF_HUB_AVAILABLE = False
print(" [!] huggingface_hub not installed. Run: pip install huggingface_hub")
# Safetensors support
try:
from safetensors.torch import save_file as save_safetensors
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
# TensorBoard support
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
TENSORBOARD_AVAILABLE = False
print(" [!] tensorboard not installed. Run: pip install tensorboard")
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================
@dataclass
class TrainingConfig:
"""Training hyperparameters."""
# Run identification
run_name: str = "default" # Descriptive name for this run
run_number: Optional[int] = None # Auto-incremented if None
# Data
dataset: str = "cifar10"
image_size: int = 32
batch_size: int = 128
num_workers: int = 4
# Training schedule
epochs: int = 100
warmup_epochs: int = 5
# Optimizer
learning_rate: float = 1e-3
weight_decay: float = 0.05
betas: Tuple[float, float] = (0.9, 0.999)
# Learning rate schedule
scheduler: str = "cosine"
min_lr: float = 1e-6
# Loss weights
ce_weight: float = 1.0
cayley_weight: float = 0.01
contrast_weight: float = 0.5
scale_ce_weight: float = 0.1
# Regularization
gradient_clip: float = 1.0
label_smoothing: float = 0.1
# Augmentation
use_augmentation: bool = True
mixup_alpha: float = 0.2
cutmix_alpha: float = 1.0
# Checkpointing
save_interval: int = 10
output_dir: str = "./checkpoints"
resume_from: Optional[str] = None # Path to checkpoint or "latest"
# TensorBoard
use_tensorboard: bool = True
log_interval: int = 50 # Log every N batches
# HuggingFace Hub
push_to_hub: bool = False
hub_repo_id: Optional[str] = None
hub_private: bool = False
hub_append_run: bool = True # Append run info to repo_id (e.g., repo-run001-baseline)
# Device
device: str = "cuda" if torch.cuda.is_available() else "cpu"
def to_dict(self) -> Dict:
return {k: v for k, v in self.__dict__.items()}
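# A minimal usage sketch (not called anywhere in this file), assuming nothing
# beyond the dataclass above: build a TrainingConfig with a few keyword
# overrides and take its dict form for logging or JSON dumps.
def _example_training_config() -> Dict:
    """Construct a config with overrides and return it as a plain dict."""
    cfg = TrainingConfig(
        run_name="example",
        dataset="cifar10",
        epochs=10,
        batch_size=64,
        scheduler="cosine",
        push_to_hub=False,
    )
    return cfg.to_dict()  # betas stays a tuple; json.dump serializes it as a list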
# ============================================================================
# HUGGINGFACE HUB INTEGRATION
# ============================================================================
def generate_model_card(
model_config: DavidBeansConfig,
train_config: TrainingConfig,
best_acc: float,
training_history: Optional[Dict] = None
) -> str:
"""Generate a model card for HuggingFace Hub."""
scales_str = ", ".join([str(s) for s in model_config.scales])
dataset_info = {
"cifar10": ("CIFAR-10", 10, "Image classification on 32x32 images"),
"cifar100": ("CIFAR-100", 100, "Fine-grained image classification on 32x32 images"),
}.get(train_config.dataset, (train_config.dataset, model_config.num_classes, ""))
card_content = f"""---
library_name: pytorch
license: apache-2.0
tags:
- vision
- image-classification
- geometric-deep-learning
- vit
- cantor-routing
- pentachoron
- multi-scale
datasets:
- {train_config.dataset}
metrics:
- accuracy
model-index:
- name: DavidBeans
results:
- task:
type: image-classification
name: Image Classification
dataset:
name: {dataset_info[0]}
type: {train_config.dataset}
metrics:
- type: accuracy
value: {best_acc:.2f}
name: Top-1 Accuracy
---
# 🫘💎 DavidBeans: Unified Vision-to-Crystal Architecture
DavidBeans combines **ViT-Beans** (Cantor-routed sparse attention) with **David** (multi-scale crystal classification) into a unified geometric deep learning architecture.
## Model Description
This model implements several novel techniques:
- **Hybrid Cantor Routing**: Combines fractal Cantor set distances with positional proximity for sparse attention patterns
- **Pentachoron Experts**: 5-vertex simplex structure with Cayley-Menger geometric regularization
- **Multi-Scale Crystal Projection**: Projects features to multiple representation scales with learned fusion
- **Cross-Contrastive Learning**: Aligns patch-level features with crystal anchors
## Architecture
```
Image [B, 3, {model_config.image_size}, {model_config.image_size}]
                      │
                      ▼
┌─────────────────────────────────────────────┐
│               BEANS BACKBONE                │
│  ├─ Patch Embed → [{model_config.num_patches} patches, {model_config.dim}d]
│  ├─ Hybrid Cantor Router (α={model_config.cantor_weight})
│  ├─ {model_config.num_layers} × Attention Blocks ({model_config.num_heads} heads)
│  └─ {model_config.num_layers} × Pentachoron Expert Layers
└─────────────────────────────────────────────┘
                      │
                      ▼
┌─────────────────────────────────────────────┐
│                 DAVID HEAD                  │
│  ├─ Multi-scale projection: [{scales_str}]
│  ├─ Per-scale Crystal Heads
│  └─ Geometric Fusion (learned weights)
└─────────────────────────────────────────────┘
                      │
                      ▼
       [{model_config.num_classes} classes]
```
## Training Details
| Parameter | Value |
|-----------|-------|
| Dataset | {dataset_info[0]} |
| Classes | {model_config.num_classes} |
| Image Size | {model_config.image_size}×{model_config.image_size} |
| Patch Size | {model_config.patch_size}×{model_config.patch_size} |
| Embedding Dim | {model_config.dim} |
| Layers | {model_config.num_layers} |
| Attention Heads | {model_config.num_heads} |
| Experts | {model_config.num_experts} (pentachoron) |
| Sparse Neighbors | k={model_config.k_neighbors} |
| Scales | [{scales_str}] |
| Epochs | {train_config.epochs} |
| Batch Size | {train_config.batch_size} |
| Learning Rate | {train_config.learning_rate} |
| Weight Decay | {train_config.weight_decay} |
| Mixup α | {train_config.mixup_alpha} |
| CutMix α | {train_config.cutmix_alpha} |
| Label Smoothing | {train_config.label_smoothing} |
## Results
| Metric | Value |
|--------|-------|
| **Top-1 Accuracy** | **{best_acc:.2f}%** |
## TensorBoard Logs
Training logs are included in the `tensorboard/` directory. To view:
```bash
tensorboard --logdir tensorboard/
```
## Usage
```python
import torch
from safetensors.torch import load_file
from david_beans import DavidBeans, DavidBeansConfig
# Load config
config = DavidBeansConfig(
image_size={model_config.image_size},
patch_size={model_config.patch_size},
dim={model_config.dim},
num_layers={model_config.num_layers},
num_heads={model_config.num_heads},
num_experts={model_config.num_experts},
k_neighbors={model_config.k_neighbors},
cantor_weight={model_config.cantor_weight},
scales={model_config.scales},
num_classes={model_config.num_classes}
)
# Create model and load weights
model = DavidBeans(config)
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
# Inference
model.eval()
with torch.no_grad():
output = model(images)
predictions = output['logits'].argmax(dim=-1)
```
## Citation
```bibtex
@misc{{davidbeans2025,
author = {{AbstractPhil}},
title = {{DavidBeans: Unified Vision-to-Crystal Architecture}},
year = {{2025}},
publisher = {{HuggingFace}},
url = {{https://huggingface.co/{train_config.hub_repo_id or 'AbstractPhil/david-beans'}}}
}}
```
## License
Apache 2.0
"""
return card_content
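# Hedged sketch: previewing the generated model card without training. The
# DavidBeansConfig keyword arguments mirror the "test" preset at the bottom of
# this file; the 0.0 accuracy is a placeholder, not a measured result, and the
# output path is illustrative.
def _example_model_card(path: str = "README_preview.md") -> str:
    model_cfg = DavidBeansConfig(
        image_size=32, patch_size=4, dim=128, num_layers=2,
        num_heads=4, num_experts=5, k_neighbors=8,
        scales=[32, 64, 128], num_classes=10
    )
    train_cfg = TrainingConfig(run_name="card_preview", dataset="cifar10")
    card = generate_model_card(model_cfg, train_cfg, best_acc=0.0)
    Path(path).write_text(card)
    return card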
def save_for_hub(
model: DavidBeans,
model_config: DavidBeansConfig,
train_config: TrainingConfig,
best_acc: float,
output_dir: Path,
training_history: Optional[Dict] = None
) -> Path:
"""Save model in HuggingFace Hub format."""
hub_dir = output_dir / "hub"
hub_dir.mkdir(parents=True, exist_ok=True)
# 1. Save model weights - clone to avoid shared memory issues
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
if SAFETENSORS_AVAILABLE:
try:
save_safetensors(state_dict, hub_dir / "model.safetensors")
print(f" βœ“ Saved model.safetensors")
except Exception as e:
print(f" [!] Safetensors failed ({e}), using pytorch format only")
# Also save PyTorch format for compatibility
torch.save(state_dict, hub_dir / "pytorch_model.bin")
print(f" βœ“ Saved pytorch_model.bin")
# 2. Save config
config_dict = {
"architecture": "DavidBeans",
"model_type": "david_beans",
**model_config.__dict__
}
with open(hub_dir / "config.json", "w") as f:
json.dump(config_dict, f, indent=2, default=str)
print(f" βœ“ Saved config.json")
# 3. Save training config
with open(hub_dir / "training_config.json", "w") as f:
json.dump(train_config.to_dict(), f, indent=2, default=str)
# 4. Generate and save model card
model_card = generate_model_card(model_config, train_config, best_acc, training_history)
with open(hub_dir / "README.md", "w") as f:
f.write(model_card)
print(f" βœ“ Generated README.md (model card)")
# 5. Save training history if available
if training_history:
with open(hub_dir / "training_history.json", "w") as f:
json.dump(training_history, f, indent=2)
# 6. Copy TensorBoard logs if they exist
tb_dir = output_dir / "tensorboard"
if tb_dir.exists():
import shutil
hub_tb_dir = hub_dir / "tensorboard"
if hub_tb_dir.exists():
shutil.rmtree(hub_tb_dir)
shutil.copytree(tb_dir, hub_tb_dir)
print(f" βœ“ Copied TensorBoard logs")
return hub_dir
def push_to_hub(
hub_dir: Path,
repo_id: str,
private: bool = False,
commit_message: Optional[str] = None
) -> str:
"""Push model to HuggingFace Hub."""
if not HF_HUB_AVAILABLE:
raise RuntimeError("huggingface_hub not installed. Run: pip install huggingface_hub")
api = HfApi()
# Create repo if it doesn't exist
try:
create_repo(repo_id, private=private, exist_ok=True)
print(f" βœ“ Repository ready: {repo_id}")
except Exception as e:
print(f" [!] Repo creation note: {e}")
# Upload
if commit_message is None:
commit_message = f"Upload DavidBeans model - {datetime.now().strftime('%Y-%m-%d %H:%M')}"
url = upload_folder(
folder_path=str(hub_dir),
repo_id=repo_id,
commit_message=commit_message
)
print(f" βœ“ Uploaded to: https://huggingface.co/{repo_id}")
return url
# ============================================================================
# DATA LOADING
# ============================================================================
def get_dataloaders(config: TrainingConfig) -> Tuple[DataLoader, DataLoader, int]:
"""Get train and test dataloaders."""
try:
import torchvision
import torchvision.transforms as T
if config.dataset == "cifar10":
if config.use_augmentation:
train_transform = T.Compose([
T.RandomCrop(32, padding=4),
T.RandomHorizontalFlip(),
T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
T.ToTensor(),
T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])
else:
train_transform = T.Compose([
T.ToTensor(),
T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])
test_transform = T.Compose([
T.ToTensor(),
T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])
train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=train_transform
)
test_dataset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=test_transform
)
num_classes = 10
elif config.dataset == "cifar100":
if config.use_augmentation:
train_transform = T.Compose([
T.RandomCrop(32, padding=4),
T.RandomHorizontalFlip(),
T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),  # torchvision has no CIFAR-100 policy; CIFAR-10's is a common stand-in
T.ToTensor(),
T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])
else:
train_transform = T.Compose([
T.ToTensor(),
T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])
test_transform = T.Compose([
T.ToTensor(),
T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])
train_dataset = torchvision.datasets.CIFAR100(
root='./data', train=True, download=True, transform=train_transform
)
test_dataset = torchvision.datasets.CIFAR100(
root='./data', train=False, download=True, transform=test_transform
)
num_classes = 100
else:
raise ValueError(f"Unknown dataset: {config.dataset}")
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=config.num_workers,
pin_memory=True,
persistent_workers=config.num_workers > 0,
drop_last=True
)
test_loader = DataLoader(
test_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers,
pin_memory=True,
persistent_workers=config.num_workers > 0
)
return train_loader, test_loader, num_classes
except ImportError:
print(" [!] torchvision not available, using synthetic data")
return get_synthetic_dataloaders(config)
def get_synthetic_dataloaders(config: TrainingConfig) -> Tuple[DataLoader, DataLoader, int]:
"""Fallback synthetic data for testing."""
class SyntheticDataset(torch.utils.data.Dataset):
def __init__(self, size: int, image_size: int, num_classes: int):
self.size = size
self.image_size = image_size
self.num_classes = num_classes
def __len__(self):
return self.size
def __getitem__(self, idx):
x = torch.randn(3, self.image_size, self.image_size)
y = idx % self.num_classes
return x, y
num_classes = 10
train_dataset = SyntheticDataset(5000, config.image_size, num_classes)
test_dataset = SyntheticDataset(1000, config.image_size, num_classes)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
return train_loader, test_loader, num_classes
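# Hedged sanity-check sketch for the data pipeline above: pull a single batch
# and report shapes. Uses TrainingConfig defaults (CIFAR-10, which torchvision
# will download on first use); nothing here runs at import time.
def _example_inspect_batch() -> None:
    cfg = TrainingConfig(batch_size=8, num_workers=0)
    train_loader, test_loader, num_classes = get_dataloaders(cfg)
    images, targets = next(iter(train_loader))
    print(f"train batch: {tuple(images.shape)}, targets: {tuple(targets.shape)}, "
          f"test batches: {len(test_loader)}, classes: {num_classes}")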
# ============================================================================
# MIXUP / CUTMIX AUGMENTATION
# ============================================================================
def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
"""Mixup augmentation."""
if alpha > 0:
lam = torch.distributions.Beta(alpha, alpha).sample().item()
else:
lam = 1.0
batch_size = x.size(0)
index = torch.randperm(batch_size, device=x.device)
mixed_x = lam * x + (1 - lam) * x[index]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
"""CutMix augmentation."""
if alpha > 0:
lam = torch.distributions.Beta(alpha, alpha).sample().item()
else:
lam = 1.0
batch_size = x.size(0)
index = torch.randperm(batch_size, device=x.device)
_, _, H, W = x.shape
cut_ratio = math.sqrt(1 - lam)
cut_h = int(H * cut_ratio)
cut_w = int(W * cut_ratio)
cx = torch.randint(0, H, (1,)).item()
cy = torch.randint(0, W, (1,)).item()
x1 = max(0, cx - cut_h // 2)
x2 = min(H, cx + cut_h // 2)
y1 = max(0, cy - cut_w // 2)
y2 = min(W, cy + cut_w // 2)
mixed_x = x.clone()
mixed_x[:, :, x1:x2, y1:y2] = x[index, :, x1:x2, y1:y2]
lam = 1 - ((x2 - x1) * (y2 - y1)) / (H * W)
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
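# Hedged sketch of how the mixed labels above are consumed: the loss is a
# convex combination of two cross-entropies weighted by lam, exactly as in
# train_epoch below. The tensors are random stand-ins, purely illustrative.
def _example_mixed_loss(alpha: float = 0.2) -> torch.Tensor:
    images = torch.randn(4, 3, 32, 32)
    targets = torch.randint(0, 10, (4,))
    logits = torch.randn(4, 10)            # stand-in for model(images)['logits']
    _, y_a, y_b, lam = mixup_data(images, targets, alpha)
    return lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)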
# ============================================================================
# METRICS TRACKER
# ============================================================================
class MetricsTracker:
"""Track training metrics with EMA smoothing."""
def __init__(self, ema_decay: float = 0.9):
self.ema_decay = ema_decay
self.metrics = {}
self.ema_metrics = {}
self.history = {}
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
if k not in self.metrics:
self.metrics[k] = []
self.ema_metrics[k] = v
self.history[k] = []
self.metrics[k].append(v)
self.ema_metrics[k] = self.ema_decay * self.ema_metrics[k] + (1 - self.ema_decay) * v
def get_ema(self, key: str) -> float:
return self.ema_metrics.get(key, 0.0)
def get_epoch_mean(self, key: str) -> float:
values = self.metrics.get(key, [])
return sum(values) / len(values) if values else 0.0
def end_epoch(self):
for k, v in self.metrics.items():
if v:
self.history[k].append(sum(v) / len(v))
self.metrics = {k: [] for k in self.metrics}
def get_history(self) -> Dict:
return self.history
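# Hedged usage sketch for MetricsTracker: update once per batch, read the EMA
# for progress-bar display, and call end_epoch() to fold the batch values into
# the per-epoch history. The numbers are arbitrary.
def _example_metrics_tracker() -> Dict:
    tracker = MetricsTracker(ema_decay=0.9)
    for step in range(3):
        tracker.update(loss=1.0 / (step + 1), acc=50.0 + step)
    ema_loss = tracker.get_ema('loss')        # smoothed value shown in the tqdm postfix
    mean_loss = tracker.get_epoch_mean('loss')
    tracker.end_epoch()                       # appends epoch means to history
    return {'ema': ema_loss, 'mean': mean_loss, 'history': tracker.get_history()}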
# ============================================================================
# CHECKPOINT UTILITIES
# ============================================================================
def find_latest_checkpoint(output_dir: Path) -> Optional[Path]:
"""Find the most recent checkpoint in output directory."""
checkpoints = list(output_dir.glob("checkpoint_epoch_*.pt"))
if not checkpoints:
# Try best_model.pt as fallback
best_model = output_dir / "best_model.pt"
if best_model.exists():
return best_model
return None
# Sort by epoch number
def get_epoch(p):
try:
return int(p.stem.split("_")[-1])
except (IndexError, ValueError):
return 0
checkpoints.sort(key=get_epoch, reverse=True)
return checkpoints[0]
def get_next_run_number(base_dir: Path) -> int:
"""Get the next run number by scanning existing run directories."""
if not base_dir.exists():
return 1
max_num = 0
for d in base_dir.iterdir():
if d.is_dir() and d.name.startswith("run_"):
try:
# Extract number from "run_XXX_name_timestamp"
num = int(d.name.split("_")[1])
max_num = max(max_num, num)
except (IndexError, ValueError):
continue
return max_num + 1
def generate_run_dir_name(run_number: int, run_name: str) -> str:
"""Generate a run directory name with number, name, and timestamp."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Sanitize run_name: lowercase, replace spaces with underscores, remove special chars
safe_name = "".join(c if c.isalnum() or c == "_" else "_" for c in run_name.lower())
safe_name = "_".join(filter(None, safe_name.split("_"))) # Remove consecutive underscores
return f"run_{run_number:03d}_{safe_name}_{timestamp}"
def find_latest_run_dir(base_dir: Path) -> Optional[Path]:
"""Find the most recent run directory."""
if not base_dir.exists():
return None
run_dirs = [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("run_")]
if not run_dirs:
return None
# Sort by modification time (most recent first)
run_dirs.sort(key=lambda d: d.stat().st_mtime, reverse=True)
return run_dirs[0]
def find_checkpoint_in_runs(base_dir: Path, resume_from: str) -> Optional[Path]:
"""
Find a checkpoint to resume from.
Args:
base_dir: Base checkpoint directory (e.g., ./checkpoints/cifar100)
resume_from: Either "latest", a run directory name, or a full path
Returns:
Path to checkpoint file, or None
"""
if resume_from == "latest":
# Find most recent run directory
run_dir = find_latest_run_dir(base_dir)
if run_dir:
return find_latest_checkpoint(run_dir)
# Fallback: check base_dir itself (for old-style checkpoints)
return find_latest_checkpoint(base_dir)
# Check if it's a full path
full_path = Path(resume_from)
if full_path.exists():
if full_path.is_file():
return full_path
elif full_path.is_dir():
return find_latest_checkpoint(full_path)
# Check if it's a run directory name within base_dir
run_path = base_dir / resume_from
if run_path.exists():
return find_latest_checkpoint(run_path)
return None
def load_checkpoint(
checkpoint_path: Path,
model: DavidBeans,
optimizer: Optional[torch.optim.Optimizer] = None,
device: str = "cuda"
) -> Tuple[int, float]:
"""
Load checkpoint and return (start_epoch, best_acc).
Returns:
start_epoch: Epoch to resume from (checkpoint_epoch + 1)
best_acc: Best accuracy so far
"""
print(f"\nπŸ“‚ Loading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f" βœ“ Loaded model weights")
if optimizer is not None and 'optimizer_state_dict' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(f" βœ“ Loaded optimizer state")
epoch = checkpoint.get('epoch', 0)
best_acc = checkpoint.get('best_acc', 0.0)
print(f" βœ“ Loaded checkpoint from epoch {epoch + 1}, best_acc={best_acc:.2f}%")
print(f" βœ“ Will resume training from epoch {epoch + 2}")
return epoch + 1, best_acc
def get_config_from_checkpoint(checkpoint_path: Path) -> Tuple[DavidBeansConfig, dict]:
"""
Extract model and training configs from a checkpoint.
Returns:
(model_config, train_config_dict)
"""
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model_config_dict = checkpoint.get('model_config', {})
train_config_dict = checkpoint.get('train_config', {})
# Handle tuple conversion for betas
if 'betas' in train_config_dict and isinstance(train_config_dict['betas'], list):
train_config_dict['betas'] = tuple(train_config_dict['betas'])
model_config = DavidBeansConfig(**model_config_dict)
return model_config, train_config_dict
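# Hedged sketch tying the checkpoint utilities together: locate the newest
# checkpoint under a base directory, rebuild the model from its stored config,
# and load weights for CPU inference. The path is illustrative.
def _example_restore_model(base_dir: str = "./checkpoints/cifar10") -> Optional[DavidBeans]:
    ckpt = find_checkpoint_in_runs(Path(base_dir), "latest")
    if ckpt is None:
        return None
    model_config, _ = get_config_from_checkpoint(ckpt)
    model = DavidBeans(model_config)
    load_checkpoint(ckpt, model, optimizer=None, device="cpu")
    return model.eval()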
# ============================================================================
# TRAINING LOOP
# ============================================================================
def train_epoch(
model: DavidBeans,
train_loader: DataLoader,
optimizer: torch.optim.Optimizer,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
config: TrainingConfig,
epoch: int,
tracker: MetricsTracker,
writer: Optional['SummaryWriter'] = None
) -> Dict[str, float]:
"""Train for one epoch."""
model.train()
device = config.device
total_loss = 0.0
total_correct = 0
total_samples = 0
global_step = epoch * len(train_loader)
pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True)
for batch_idx, (images, targets) in enumerate(pbar):
images = images.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
# Apply mixup/cutmix
use_mixup = config.use_augmentation and config.mixup_alpha > 0
use_cutmix = config.use_augmentation and config.cutmix_alpha > 0
mixed = False
if use_mixup or use_cutmix:
# When both are enabled: ~50% of batches unmixed, ~25% mixup, ~25% cutmix
r = torch.rand(1).item()
if r < 0.5:
pass
elif r < 0.75 and use_mixup:
images, targets_a, targets_b, lam = mixup_data(images, targets, config.mixup_alpha)
mixed = True
elif use_cutmix:
images, targets_a, targets_b, lam = cutmix_data(images, targets, config.cutmix_alpha)
mixed = True
# Forward pass
result = model(images, targets=targets, return_loss=True)
losses = result['losses']
if mixed:
logits = result['logits']
ce_loss = lam * F.cross_entropy(logits, targets_a, label_smoothing=config.label_smoothing) + \
(1 - lam) * F.cross_entropy(logits, targets_b, label_smoothing=config.label_smoothing)
losses['ce'] = ce_loss
# Compute total loss
loss = (
config.ce_weight * losses['ce'] +
config.cayley_weight * losses.get('geometric', torch.tensor(0.0, device=device)) +
config.contrast_weight * losses.get('contrast', torch.tensor(0.0, device=device))
)
for scale in model.config.scales:
scale_ce = losses.get(f'ce_{scale}', 0.0)
if isinstance(scale_ce, torch.Tensor):
loss = loss + config.scale_ce_weight * scale_ce
# Backward pass
optimizer.zero_grad()
loss.backward()
if config.gradient_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
else:
grad_norm = 0.0
optimizer.step()
if scheduler is not None and config.scheduler == "onecycle":
scheduler.step()
# Compute accuracy
with torch.no_grad():
logits = result['logits']
preds = logits.argmax(dim=-1)
if mixed:
correct = (lam * (preds == targets_a).float() +
(1 - lam) * (preds == targets_b).float()).sum()
else:
correct = (preds == targets).sum()
total_correct += correct.item()
total_samples += targets.size(0)
total_loss += loss.item()
# Track metrics
def to_float(v):
return v.item() if isinstance(v, torch.Tensor) else float(v)
geo_loss = to_float(losses.get('geometric', 0.0))
contrast_loss = to_float(losses.get('contrast', 0.0))
expert_vol = to_float(losses.get('expert_volume', 0.0))
expert_collapse = to_float(losses.get('expert_collapse', 0.0))
expert_edge = to_float(losses.get('expert_edge_dev', 0.0))
current_lr = optimizer.param_groups[0]['lr']
tracker.update(
loss=loss.item(),
ce=losses['ce'].item(),
geo=geo_loss,
contrast=contrast_loss,
expert_vol=expert_vol,
expert_collapse=expert_collapse,
expert_edge=expert_edge,
lr=current_lr
)
# TensorBoard logging (every log_interval batches)
if writer is not None and (batch_idx + 1) % config.log_interval == 0:
step = global_step + batch_idx
# Loss components
writer.add_scalar('train/loss_total', loss.item(), step)
writer.add_scalar('train/loss_ce', losses['ce'].item(), step)
writer.add_scalar('train/loss_geometric', geo_loss, step)
writer.add_scalar('train/loss_contrast', contrast_loss, step)
# Geometric metrics
writer.add_scalar('train/expert_volume', expert_vol, step)
writer.add_scalar('train/expert_collapse', expert_collapse, step)
writer.add_scalar('train/expert_edge_dev', expert_edge, step)
# Training dynamics
writer.add_scalar('train/learning_rate', current_lr, step)
writer.add_scalar('train/grad_norm', to_float(grad_norm), step)
writer.add_scalar('train/batch_acc', 100.0 * correct.item() / targets.size(0), step)
pbar.set_postfix({
'loss': f"{tracker.get_ema('loss'):.3f}",
'acc': f"{100.0 * total_correct / total_samples:.1f}%",
'geo': f"{tracker.get_ema('geo'):.4f}",
'vol': f"{tracker.get_ema('expert_vol'):.4f}"
})
if scheduler is not None and config.scheduler == "cosine":
scheduler.step()
return {
'loss': total_loss / len(train_loader),
'acc': 100.0 * total_correct / total_samples
}
@torch.no_grad()
def evaluate(
model: DavidBeans,
test_loader: DataLoader,
config: TrainingConfig
) -> Dict[str, float]:
"""Evaluate on test set."""
model.eval()
device = config.device
total_loss = 0.0
total_correct = 0
total_samples = 0
scale_correct = {s: 0 for s in model.config.scales}
for images, targets in test_loader:
images = images.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
result = model(images, targets=targets, return_loss=True)
logits = result['logits']
losses = result['losses']
loss = losses['total']
preds = logits.argmax(dim=-1)
total_loss += loss.item() * targets.size(0)
total_correct += (preds == targets).sum().item()
total_samples += targets.size(0)
for i, scale in enumerate(model.config.scales):
scale_logits = result['scale_logits'][i]
scale_preds = scale_logits.argmax(dim=-1)
scale_correct[scale] += (scale_preds == targets).sum().item()
metrics = {
'loss': total_loss / total_samples,
'acc': 100.0 * total_correct / total_samples
}
for scale, correct in scale_correct.items():
metrics[f'acc_{scale}'] = 100.0 * correct / total_samples
return metrics
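# Hedged sketch of standalone evaluation on CPU, reusing evaluate() and the
# checkpoint/config helpers above. The checkpoint path is illustrative.
def _example_evaluate(checkpoint: str = "./checkpoints/cifar10/best_model.pt") -> Dict[str, float]:
    cfg = TrainingConfig(dataset="cifar10", batch_size=256, num_workers=0, device="cpu")
    _, test_loader, _ = get_dataloaders(cfg)
    model_config, _ = get_config_from_checkpoint(Path(checkpoint))
    model = DavidBeans(model_config).to(cfg.device)
    load_checkpoint(Path(checkpoint), model, optimizer=None, device=cfg.device)
    return evaluate(model, test_loader, cfg)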
# ============================================================================
# MAIN TRAINING FUNCTION
# ============================================================================
def train_david_beans(
model_config: Optional[DavidBeansConfig] = None,
train_config: Optional[TrainingConfig] = None
):
"""Main training function."""
print("=" * 70)
print(" DAVID-BEANS TRAINING: The Dynamic Duo")
print("=" * 70)
print()
print(" 🫘 BEANS (ViT) + πŸ’Ž DAVID (Crystal)")
print(" Sparse Attention Multi-Scale Projection")
print()
print("=" * 70)
if train_config is None:
train_config = TrainingConfig()
base_output_dir = Path(train_config.output_dir)
base_output_dir.mkdir(parents=True, exist_ok=True)
# Check for resume FIRST - load config from checkpoint if resuming
checkpoint_path = None
run_dir = None # Will be set either from resume or new run
if train_config.resume_from:
# Find checkpoint using the new directory structure
checkpoint_path = find_checkpoint_in_runs(base_output_dir, train_config.resume_from)
if checkpoint_path and checkpoint_path.exists():
print(f"\nπŸ“‚ Found checkpoint: {checkpoint_path}")
# The run directory is the parent of the checkpoint
run_dir = checkpoint_path.parent
print(f" βœ“ Resuming in run directory: {run_dir.name}")
# Load config from checkpoint to ensure architecture matches
loaded_model_config, loaded_train_config_dict = get_config_from_checkpoint(checkpoint_path)
if model_config is None:
model_config = loaded_model_config
print(f" βœ“ Using model config from checkpoint")
else:
# Warn if configs differ
if model_config.dim != loaded_model_config.dim or model_config.scales != loaded_model_config.scales:
print(f" ⚠ WARNING: Provided config differs from checkpoint!")
print(f" Checkpoint: dim={loaded_model_config.dim}, scales={loaded_model_config.scales}")
print(f" Provided: dim={model_config.dim}, scales={model_config.scales}")
print(f" βœ“ Using checkpoint config to ensure compatibility")
model_config = loaded_model_config
else:
print(f" [!] Checkpoint not found: {train_config.resume_from}")
checkpoint_path = None
# If not resuming (or resume failed), create new run directory
if run_dir is None:
# Get run number
if train_config.run_number is None:
run_number = get_next_run_number(base_output_dir)
else:
run_number = train_config.run_number
# Generate run directory name
run_dir_name = generate_run_dir_name(run_number, train_config.run_name)
run_dir = base_output_dir / run_dir_name
run_dir.mkdir(parents=True, exist_ok=True)
print(f"\nπŸ“ New run: {run_dir_name}")
print(f" Run #{run_number}: {train_config.run_name}")
else:
# Extract run number from existing directory name for hub repo
try:
run_number = int(run_dir.name.split("_")[1])
except (IndexError, ValueError):
run_number = 1
# Update output_dir to point to the run directory
output_dir = run_dir
# Generate effective hub repo ID with run info
effective_hub_repo_id = train_config.hub_repo_id
if train_config.hub_repo_id and train_config.hub_append_run:
# Extract run name from directory (run_XXX_name_timestamp -> name)
parts = run_dir.name.split("_")
if len(parts) >= 3:
run_name_part = parts[2] # Get the name part
else:
run_name_part = train_config.run_name
effective_hub_repo_id = f"{train_config.hub_repo_id}-run{run_number:03d}-{run_name_part}"
print(f" Hub repo: {effective_hub_repo_id}")
if model_config is None:
model_config = DavidBeansConfig(
image_size=train_config.image_size,
patch_size=4,
dim=256,
num_layers=6,
num_heads=8,
num_experts=5,
k_neighbors=16,
cantor_weight=0.3,
scales=[64, 128, 256],
num_classes=10,
contrast_weight=train_config.contrast_weight,
cayley_weight=train_config.cayley_weight,
dropout=0.1
)
device = train_config.device
print(f"\nDevice: {device}")
# Data
print("\nLoading data...")
train_loader, test_loader, num_classes = get_dataloaders(train_config)
print(f" Dataset: {train_config.dataset}")
print(f" Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
print(f" Classes: {num_classes}")
model_config.num_classes = num_classes
# Model
print("\nBuilding model...")
model = DavidBeans(model_config)
model = model.to(device)
print(f"\n{model}")
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nParameters: {num_params:,} ({num_trainable:,} trainable)")
# Optimizer
print("\nSetting up optimizer...")
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if 'bias' in name or 'norm' in name or 'embedding' in name:
no_decay_params.append(param)
else:
decay_params.append(param)
optimizer = AdamW([
{'params': decay_params, 'weight_decay': train_config.weight_decay},
{'params': no_decay_params, 'weight_decay': 0.0}
], lr=train_config.learning_rate, betas=train_config.betas)
if train_config.scheduler == "cosine":
scheduler = CosineAnnealingLR(
optimizer,
T_max=train_config.epochs - train_config.warmup_epochs,
eta_min=train_config.min_lr
)
elif train_config.scheduler == "onecycle":
scheduler = OneCycleLR(
optimizer,
max_lr=train_config.learning_rate,
epochs=train_config.epochs,
steps_per_epoch=len(train_loader),
pct_start=train_config.warmup_epochs / train_config.epochs
)
else:
scheduler = None
print(f" Optimizer: AdamW (lr={train_config.learning_rate}, wd={train_config.weight_decay})")
print(f" Scheduler: {train_config.scheduler}")
print(f" TensorBoard: {output_dir / 'tensorboard'}")
tracker = MetricsTracker()
best_acc = 0.0
start_epoch = 0
print(f"\nOutput directory: {output_dir}")
# Load weights from checkpoint if we found one earlier
if checkpoint_path and checkpoint_path.exists():
start_epoch, best_acc = load_checkpoint(
checkpoint_path, model, optimizer, device
)
# Adjust scheduler to correct position
if scheduler is not None and train_config.scheduler == "cosine":
for _ in range(start_epoch):
scheduler.step()
# TensorBoard setup
writer = None
if train_config.use_tensorboard and TENSORBOARD_AVAILABLE:
tb_dir = output_dir / "tensorboard"
tb_dir.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(log_dir=str(tb_dir))
print(f" TensorBoard: {tb_dir}")
# Log model config as text
config_text = json.dumps(model_config.__dict__, indent=2, default=str)
writer.add_text("config/model", f"```json\n{config_text}\n```", 0)
train_text = json.dumps(train_config.to_dict(), indent=2, default=str)
writer.add_text("config/training", f"```json\n{train_text}\n```", 0)
elif train_config.use_tensorboard:
print(" [!] TensorBoard requested but not available")
with open(output_dir / "model_config.json", "w") as f:
json.dump(model_config.__dict__, f, indent=2, default=str)
with open(output_dir / "train_config.json", "w") as f:
json.dump(train_config.to_dict(), f, indent=2, default=str)
print(f"\nOutput directory: {output_dir}")
# Training loop
print("\n" + "=" * 70)
print(" TRAINING")
print("=" * 70)
if start_epoch > 0:
print(f" Resuming from epoch {start_epoch + 1}/{train_config.epochs}")
for epoch in range(start_epoch, train_config.epochs):
epoch_start = time.time()
if epoch < train_config.warmup_epochs and train_config.scheduler == "cosine":
warmup_lr = train_config.learning_rate * (epoch + 1) / train_config.warmup_epochs
for param_group in optimizer.param_groups:
param_group['lr'] = warmup_lr
train_metrics = train_epoch(
model, train_loader, optimizer, scheduler,
train_config, epoch, tracker, writer
)
test_metrics = evaluate(model, test_loader, train_config)
epoch_time = time.time() - epoch_start
# TensorBoard epoch logging
if writer is not None:
# Epoch-level metrics
writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
writer.add_scalar('epoch/train_acc', train_metrics['acc'], epoch)
writer.add_scalar('epoch/test_loss', test_metrics['loss'], epoch)
writer.add_scalar('epoch/test_acc', test_metrics['acc'], epoch)
writer.add_scalar('epoch/learning_rate', optimizer.param_groups[0]['lr'], epoch)
writer.add_scalar('epoch/time_seconds', epoch_time, epoch)
# Per-scale accuracies
for scale in model.config.scales:
writer.add_scalar(f'scales/acc_{scale}', test_metrics[f'acc_{scale}'], epoch)
# Generalization gap
writer.add_scalar('epoch/generalization_gap', train_metrics['acc'] - test_metrics['acc'], epoch)
# Flush periodically
if (epoch + 1) % 5 == 0:
writer.flush()
scale_accs = " | ".join([f"{s}:{test_metrics[f'acc_{s}']:.1f}%" for s in model.config.scales])
star = "β˜…" if test_metrics['acc'] > best_acc else ""
print(f" β†’ Train: {train_metrics['acc']:.1f}% | Test: {test_metrics['acc']:.1f}% | "
f"Scales: [{scale_accs}] | {epoch_time:.0f}s {star}")
if test_metrics['acc'] > best_acc:
best_acc = test_metrics['acc']
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_acc': best_acc,
'model_config': model_config.__dict__,
'train_config': train_config.to_dict()
}, output_dir / "best_model.pt")
if (epoch + 1) % train_config.save_interval == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_acc': best_acc
}, output_dir / f"checkpoint_epoch_{epoch + 1}.pt")
# Periodic HuggingFace Hub upload
if train_config.push_to_hub and HF_HUB_AVAILABLE and effective_hub_repo_id:
try:
# Save current best for upload
checkpoint = torch.load(output_dir / "best_model.pt", map_location='cpu')
model_cpu = DavidBeans(model_config)
model_cpu.load_state_dict(checkpoint['model_state_dict'])
hub_dir = save_for_hub(
model=model_cpu,
model_config=model_config,
train_config=train_config,
best_acc=best_acc,
output_dir=output_dir,
training_history=tracker.get_history()
)
push_to_hub(
hub_dir=hub_dir,
repo_id=effective_hub_repo_id,
private=train_config.hub_private,
commit_message=f"Checkpoint epoch {epoch + 1} - {best_acc:.2f}% acc"
)
print(f" πŸ“€ Uploaded to {effective_hub_repo_id}")
except Exception as e:
print(f" [!] Hub upload failed: {e}")
tracker.end_epoch()
# Final summary
print("\n" + "=" * 70)
print(" TRAINING COMPLETE")
print("=" * 70)
print(f"\n Best Test Accuracy: {best_acc:.2f}%")
print(f" Model saved to: {output_dir / 'best_model.pt'}")
# Save training history
history = tracker.get_history()
with open(output_dir / "training_history.json", "w") as f:
json.dump(history, f, indent=2)
# Final TensorBoard logging
if writer is not None:
# Log best accuracy as hparam metric
hparams = {
'dim': model_config.dim,
'num_layers': model_config.num_layers,
'num_heads': model_config.num_heads,
'num_experts': model_config.num_experts,
'k_neighbors': model_config.k_neighbors,
'cantor_weight': model_config.cantor_weight,
'learning_rate': train_config.learning_rate,
'weight_decay': train_config.weight_decay,
'batch_size': train_config.batch_size,
'mixup_alpha': train_config.mixup_alpha,
'cutmix_alpha': train_config.cutmix_alpha,
}
writer.add_hparams(hparams, {'hparam/best_acc': best_acc})
writer.add_scalar('final/best_acc', best_acc, 0)
writer.close()
print(f" TensorBoard logs: {output_dir / 'tensorboard'}")
# HuggingFace Hub upload
if train_config.push_to_hub:
print("\n" + "=" * 70)
print(" UPLOADING TO HUGGINGFACE HUB")
print("=" * 70)
if not HF_HUB_AVAILABLE:
print(" [!] huggingface_hub not installed. Skipping upload.")
elif not effective_hub_repo_id:
print(" [!] hub_repo_id not set. Skipping upload.")
else:
checkpoint = torch.load(output_dir / "best_model.pt", map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
print(f"\n Preparing model for upload...")
hub_dir = save_for_hub(
model=model,
model_config=model_config,
train_config=train_config,
best_acc=best_acc,
output_dir=output_dir,
training_history=history
)
print(f"\n Uploading to {effective_hub_repo_id}...")
push_to_hub(
hub_dir=hub_dir,
repo_id=effective_hub_repo_id,
private=train_config.hub_private
)
print(f"\n πŸŽ‰ Model uploaded to: https://huggingface.co/{effective_hub_repo_id}")
return model, best_acc
# ============================================================================
# PRESETS
# ============================================================================
def train_cifar10_small(run_name: str = "cifar10_small"):
"""Small model for CIFAR-10."""
model_config = DavidBeansConfig(
image_size=32, patch_size=4, dim=256, num_layers=4,
num_heads=4, num_experts=5, k_neighbors=16,
cantor_weight=0.3, scales=[64, 128, 256, 512],
num_classes=10, dropout=0.1
)
train_config = TrainingConfig(
run_name=run_name,
dataset="cifar10", epochs=50, batch_size=128,
learning_rate=1e-3, weight_decay=0.05, warmup_epochs=5,
cayley_weight=0.01, contrast_weight=0.3,
output_dir="./checkpoints/cifar10"
)
return train_david_beans(model_config, train_config)
def train_cifar100(
run_name: str = "cifar100_base",
push_to_hub: bool = False,
hub_repo_id: Optional[str] = None,
resume: bool = False
):
"""Model for CIFAR-100 with optional HF Hub upload and resume."""
model_config = DavidBeansConfig(
image_size=32, patch_size=4, dim=512, num_layers=8,
num_heads=8, num_experts=5, k_neighbors=32,
cantor_weight=0.3, scales=[256, 512, 768],
num_classes=100, dropout=0.15
)
train_config = TrainingConfig(
run_name=run_name,
dataset="cifar100", epochs=200, batch_size=128,
learning_rate=5e-4, weight_decay=0.1, warmup_epochs=20,
cayley_weight=0.01, contrast_weight=0.5,
label_smoothing=0.1, mixup_alpha=0.3, cutmix_alpha=1.0,
output_dir="./checkpoints/cifar100",
resume_from="latest" if resume else None,
push_to_hub=push_to_hub, hub_repo_id=hub_repo_id, hub_private=False
)
return train_david_beans(model_config, train_config)
def resume_training(
checkpoint_dir: str = "./checkpoints/cifar100",
push_to_hub: bool = False,
hub_repo_id: Optional[str] = None
):
"""
Resume training from the latest checkpoint in a directory.
Usage:
resume_training("./checkpoints/cifar100", push_to_hub=True, hub_repo_id="user/repo")
"""
output_dir = Path(checkpoint_dir)
# Load configs from checkpoint directory
model_config_path = output_dir / "model_config.json"
train_config_path = output_dir / "train_config.json"
if not model_config_path.exists():
raise FileNotFoundError(f"No model_config.json in {output_dir}")
with open(model_config_path) as f:
model_config_dict = json.load(f)
with open(train_config_path) as f:
train_config_dict = json.load(f)
# Handle tuple conversion for betas
if 'betas' in train_config_dict and isinstance(train_config_dict['betas'], list):
train_config_dict['betas'] = tuple(train_config_dict['betas'])
model_config = DavidBeansConfig(**model_config_dict)
train_config = TrainingConfig(**train_config_dict)
# Override with resume settings
train_config.resume_from = "latest"
train_config.push_to_hub = push_to_hub
if hub_repo_id:
train_config.hub_repo_id = hub_repo_id
return train_david_beans(model_config, train_config)
# ============================================================================
# STANDALONE UPLOAD FUNCTION
# ============================================================================
def upload_checkpoint(
checkpoint_path: str,
repo_id: str,
best_acc: Optional[float] = None,
private: bool = False
):
"""
Upload an existing checkpoint to HuggingFace Hub.
Usage:
upload_checkpoint(
checkpoint_path="./checkpoints/cifar100/best_model.pt",
repo_id="AbstractPhil/david-beans-cifar100",
best_acc=70.0 # Optional, will read from checkpoint if available
)
"""
if not HF_HUB_AVAILABLE:
raise RuntimeError("huggingface_hub not installed. Run: pip install huggingface_hub")
print(f"\nπŸ“¦ Loading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Reconstruct configs
model_config_dict = checkpoint.get('model_config', {})
train_config_dict = checkpoint.get('train_config', {})
model_config = DavidBeansConfig(**model_config_dict)
train_config = TrainingConfig(**train_config_dict)
train_config.hub_repo_id = repo_id
# Build model and load weights
model = DavidBeans(model_config)
model.load_state_dict(checkpoint['model_state_dict'])
actual_best_acc = best_acc or checkpoint.get('best_acc', 0.0)
# Prepare and upload
output_dir = Path(checkpoint_path).parent
print(f"\nπŸ“ Preparing files for upload...")
hub_dir = save_for_hub(
model=model,
model_config=model_config,
train_config=train_config,
best_acc=actual_best_acc,
output_dir=output_dir
)
print(f"\nπŸš€ Uploading to {repo_id}...")
push_to_hub(hub_dir, repo_id, private=private)
print(f"\nπŸŽ‰ Done! https://huggingface.co/{repo_id}")
# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
# =====================================================
# CONFIGURATION
# =====================================================
PRESET = "cifar100" # "test", "small", "cifar100", "resume"
RESUME = False # Set True to resume from latest checkpoint
RUN_NAME = "5expert_3scale" # Descriptive name for this run
# HuggingFace Hub settings
PUSH_TO_HUB = False
HUB_REPO_ID = "AbstractPhil/geovit-david-beans"
# =====================================================
# RUN
# =====================================================
if PRESET == "test":
print("πŸ§ͺ Quick test...")
model_config = DavidBeansConfig(
image_size=32, patch_size=4, dim=128, num_layers=2,
num_heads=4, num_experts=5, k_neighbors=8,
scales=[32, 64, 128], num_classes=10
)
train_config = TrainingConfig(
run_name="test",
epochs=2, batch_size=32,
use_augmentation=False, mixup_alpha=0.0, cutmix_alpha=0.0
)
model, acc = train_david_beans(model_config, train_config)
elif PRESET == "small":
print("πŸ«˜πŸ’Ž Training DavidBeans - Small (CIFAR-10)...")
model, acc = train_cifar10_small()
elif PRESET == "cifar100":
print("πŸ«˜πŸ’Ž Training DavidBeans - CIFAR-100...")
model, acc = train_cifar100(
run_name=RUN_NAME,
push_to_hub=PUSH_TO_HUB,
hub_repo_id=HUB_REPO_ID,
resume=RESUME
)
elif PRESET == "resume":
print("πŸ”„ Resuming training from latest checkpoint...")
model, acc = resume_training(
checkpoint_dir="./checkpoints/cifar100",
push_to_hub=PUSH_TO_HUB,
hub_repo_id=HUB_REPO_ID
)
else:
raise ValueError(f"Unknown preset: {PRESET}")
print(f"\nπŸŽ‰ Done! Best accuracy: {acc:.2f}%")