AbstractPhil committed on
Commit ff31041 · verified · 1 Parent(s): 5bdefc8

Create trainer_v3_v21.py

Files changed (1)
  1. trainer_v3_v21.py +1790 -0
trainer_v3_v21.py ADDED
@@ -0,0 +1,1790 @@
+ """
+ Train DavidBeans V2: Wormhole Routing Architecture
+ ===================================================
+
+     ┌─────────────────┐
+     │   BEANS V2.1    │    "I learn where to look..."
+     │ (Wormhole ViT)  │
+     │  🌀 → 🌀 → 🌀   │    Learned sparse routing
+     └────────┬────────┘
+              │
+              ▼
+     ┌─────────────────┐
+     │     DAVID       │    "I know the crystals..."
+     │  (Classifier)   │
+     │  💎 → 💎 → 💎   │    Multi-scale projection
+     └────────┬────────┘
+              │
+              ▼
+        [Prediction]
+
+ Key findings from wormhole experiments:
+ 1. When routing IS the task, routing learns structure
+ 2. Auxiliary losses can be gamed - removed in V2
+ 3. Gradient flow through router is critical - verified
+ 4. Cross-contrastive aligns patch↔scale features
+
+ V2.1 additions:
+ - AlphaMix augmentation (localized transparent overlay)
+ - Configurable normalization (standard, none, center_only, unit_var)
+ - Support for redundant scales, conv spine, collective mode
+ - Configurable belly depth
+
+ Author: AbstractPhil
+ Date: November 30, 2025
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
+ from tqdm.auto import tqdm
+ import time
+ import math
+ from pathlib import Path
+ from typing import Dict, Optional, Tuple, List, Union
+ from dataclasses import dataclass, field
+ import json
+ from datetime import datetime
+ import os
+ import shutil
+
+ # HF token from Colab secrets (skipped silently when not running in Colab)
+ HF_TOKEN = None
+ try:
+     from google.colab import userdata
+     HF_TOKEN = userdata.get('HF_TOKEN')
+     if HF_TOKEN:
+         os.environ['HF_TOKEN'] = HF_TOKEN
+ except Exception:
+     pass
+
+ # Import both model versions
+ from geofractal.model.david_beans.model import DavidBeans, DavidBeansConfig
+ from geofractal.model.david_beans.model_v2 import DavidBeansV2, DavidBeansV2Config
+
+ # HuggingFace Hub integration
+ try:
+     from huggingface_hub import HfApi, create_repo, upload_folder
+     HF_HUB_AVAILABLE = True
+ except ImportError:
+     HF_HUB_AVAILABLE = False
+     print(" [!] huggingface_hub not installed. Run: pip install huggingface_hub")
+
+ # Safetensors support
+ try:
+     from safetensors.torch import save_file as save_safetensors
+     SAFETENSORS_AVAILABLE = True
+ except ImportError:
+     SAFETENSORS_AVAILABLE = False
+
+ # TensorBoard support
+ try:
+     from torch.utils.tensorboard import SummaryWriter
+     TENSORBOARD_AVAILABLE = True
+ except ImportError:
+     TENSORBOARD_AVAILABLE = False
+     print(" [!] tensorboard not installed. Run: pip install tensorboard")
+
+ import numpy as np
+
+
+ # ============================================================================
+ # TRAINING CONFIGURATION V2.1
+ # ============================================================================
+
+ @dataclass
+ class TrainingConfigV2:
+     """Training configuration for DavidBeans V2 with wormhole routing."""
+
+     # Run identification
+     run_name: str = "default"
+     run_number: Optional[int] = None
+
+     # Model version
+     model_version: int = 2  # 1 = original, 2 = wormhole
+
+     # Data
+     dataset: str = "cifar100"
+     image_size: int = 32
+     batch_size: int = 128
+     num_workers: int = 4
+
+     # Normalization
+     normalization: str = "standard"  # "standard", "none", "center_only", "unit_var"
+
+     # Training schedule
+     epochs: int = 200
+     warmup_epochs: int = 10
+
+     # Optimizer
+     learning_rate: float = 3e-4
+     weight_decay: float = 0.05
+     betas: Tuple[float, float] = (0.9, 0.999)
+
+     # Learning rate schedule
+     scheduler: str = "cosine"
+     min_lr: float = 1e-6
+
+     # Loss weights (based on experimental findings)
+     ce_weight: float = 1.0
+     contrast_weight: float = 0.5
+     # NOTE: No auxiliary routing loss - routing learns from task pressure
+
+     # Regularization
+     gradient_clip: float = 1.0
+     label_smoothing: float = 0.1
+
+     # Augmentation
+     use_augmentation: bool = True
+     mixup_alpha: float = 0.2
+     cutmix_alpha: float = 1.0
+
+     # AlphaMix augmentation
+     use_alphamix: bool = False
+     alphamix_alpha_range: Tuple[float, float] = (0.3, 0.7)
+     alphamix_spatial_ratio: float = 0.25
+
+     # Checkpointing
+     save_interval: int = 10
+     output_dir: str = "./checkpoints"
+     resume_from: Optional[str] = None
+
+     # TensorBoard
+     use_tensorboard: bool = True
+     log_interval: int = 50
+     log_routing: bool = True  # Log routing patterns
+
+     # HuggingFace Hub
+     push_to_hub: bool = False
+     hub_repo_id: str = "AbstractPhil/geovit-david-beans"
+     hub_private: bool = False
+
+     # Device
+     device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+     def to_dict(self) -> Dict:
+         return {k: v for k, v in self.__dict__.items()}
+
+     def __post_init__(self):
+         assert self.normalization in ["standard", "none", "center_only", "unit_var"], \
+             f"Invalid normalization mode: {self.normalization}"
+
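+ # A minimal, hedged usage sketch for TrainingConfigV2 (illustrative only; the
+ # presets further down construct their own configs). __post_init__ validates the
+ # normalization mode, so an unsupported value fails fast at construction time.
+ def _example_training_config() -> TrainingConfigV2:
+     cfg = TrainingConfigV2(
+         run_name="sketch",
+         normalization="center_only",  # one of: standard, none, center_only, unit_var
+         use_alphamix=True,            # enable the AlphaMix overlay augmentation
+     )
+     # TrainingConfigV2(normalization="zscore") would raise an AssertionError here.
+     return cfg
+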
+ # ============================================================================
+ # ROUTING METRICS
+ # ============================================================================
+
+ class RoutingMetrics:
+     """Track and analyze wormhole routing patterns."""
+
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.route_entropies = []
+         self.route_diversities = []
+         self.grad_norms = {'query': [], 'key': []}
+
+     @torch.no_grad()
+     def compute_route_entropy(self, soft_routes: torch.Tensor) -> float:
+         """Compute average entropy of routing distributions."""
+         eps = 1e-8
+         entropy = -(soft_routes * (soft_routes + eps).log()).sum(dim=-1)
+         return entropy.mean().item()
+
200
+ @torch.no_grad()
201
+ def compute_route_diversity(self, routes: torch.Tensor, num_positions: int) -> float:
202
+ """Compute how many unique destinations are used."""
203
+ unique_per_sample = []
204
+ for b in range(routes.shape[0]):
205
+ unique = routes[b].unique().numel()
206
+ unique_per_sample.append(unique / num_positions)
207
+ return sum(unique_per_sample) / len(unique_per_sample)
208
+
209
+ def update_from_routing_info(self, routing_info: List[Dict], model: nn.Module):
210
+ """Extract metrics from routing info returned by V2 model."""
211
+ if not routing_info:
212
+ return
213
+
214
+ for layer_info in routing_info:
215
+ if layer_info.get('attention'):
216
+ attn = layer_info['attention']
217
+ if attn.get('weights') is not None:
218
+ entropy = self.compute_route_entropy(attn['weights'])
219
+ self.route_entropies.append(entropy)
220
+ if attn.get('routes') is not None:
221
+ P = attn['routes'].shape[1]
222
+ diversity = self.compute_route_diversity(attn['routes'], P)
223
+ self.route_diversities.append(diversity)
224
+
225
+ if layer_info.get('expert'):
226
+ exp = layer_info['expert']
227
+ if exp.get('weights') is not None:
228
+ entropy = self.compute_route_entropy(exp['weights'])
229
+ self.route_entropies.append(entropy)
230
+
231
+ def update_grad_norms(self, model: nn.Module):
232
+ """Track gradient norms through router projections."""
233
+ for name, param in model.named_parameters():
234
+ if param.grad is not None:
235
+ if 'query_proj' in name and 'weight' in name:
236
+ self.grad_norms['query'].append(param.grad.norm().item())
237
+ elif 'key_proj' in name and 'weight' in name:
238
+ self.grad_norms['key'].append(param.grad.norm().item())
239
+
240
+ def get_summary(self) -> Dict[str, float]:
241
+ """Get summary statistics."""
242
+ summary = {}
243
+
244
+ if self.route_entropies:
245
+ summary['route_entropy'] = sum(self.route_entropies) / len(self.route_entropies)
246
+ if self.route_diversities:
247
+ summary['route_diversity'] = sum(self.route_diversities) / len(self.route_diversities)
248
+ if self.grad_norms['query']:
249
+ summary['grad_query'] = sum(self.grad_norms['query']) / len(self.grad_norms['query'])
250
+ if self.grad_norms['key']:
251
+ summary['grad_key'] = sum(self.grad_norms['key']) / len(self.grad_norms['key'])
252
+
253
+ return summary
254
+
255
+
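+ # A hedged illustration of compute_route_entropy above (not called anywhere in
+ # this trainer): uniform routing over D destinations gives entropy ln(D), while
+ # near-one-hot routing gives entropy close to 0. Shapes follow the metric's
+ # convention of reducing over the last (destination) dimension.
+ def _example_route_entropy() -> None:
+     metrics = RoutingMetrics()
+     uniform = torch.full((2, 8, 4), 0.25)                            # ~ln(4) = 1.386
+     peaked = torch.tensor([0.97, 0.01, 0.01, 0.01]).repeat(2, 8, 1)  # ~0.17
+     print(metrics.compute_route_entropy(uniform))
+     print(metrics.compute_route_entropy(peaked))
+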
256
+ # ============================================================================
257
+ # DATA LOADING WITH NORMALIZATION OPTIONS
258
+ # ============================================================================
259
+
+ def get_normalization_transform(config: TrainingConfigV2, dataset: str):
+     """Get normalization transform based on config."""
+     import torchvision.transforms as T
+
+     if dataset == "cifar10":
+         mean = (0.4914, 0.4822, 0.4465)
+         std = (0.2470, 0.2435, 0.2616)
+     elif dataset == "cifar100":
+         mean = (0.5071, 0.4867, 0.4408)
+         std = (0.2675, 0.2565, 0.2761)
+     else:
+         mean = (0.5, 0.5, 0.5)
+         std = (0.5, 0.5, 0.5)
+
+     if config.normalization == "standard":
+         return T.Normalize(mean, std)
+     elif config.normalization == "none":
+         # No normalization - raw [0, 1] from ToTensor
+         return None
+     elif config.normalization == "center_only":
+         # Center at 0 but don't scale variance
+         return T.Normalize(mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
+     elif config.normalization == "unit_var":
+         # Scale variance but don't center
+         return T.Normalize(mean=(0.0, 0.0, 0.0), std=std)
+     else:
+         return T.Normalize(mean, std)
+
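+ # A hedged sketch of what the four modes above return for CIFAR-100 (values are
+ # the dataset statistics hard-coded in get_normalization_transform):
+ #   "standard"    -> Normalize(mean, std)            per-channel z-score
+ #   "none"        -> None                            raw [0, 1] from ToTensor()
+ #   "center_only" -> Normalize((0.5,)*3, (1.0,)*3)   shift toward zero, keep variance
+ #   "unit_var"    -> Normalize((0.0,)*3, std)        rescale variance, keep offset
+ def _example_normalization_modes() -> None:
+     for mode in ["standard", "none", "center_only", "unit_var"]:
+         cfg = TrainingConfigV2(normalization=mode)
+         print(mode, get_normalization_transform(cfg, "cifar100"))
+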
289
+ def get_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]:
290
+ """Get train and test dataloaders with configurable normalization."""
291
+
292
+ try:
293
+ import torchvision
294
+ import torchvision.transforms as T
295
+
296
+ if config.dataset == "cifar10":
297
+ num_classes = 10
298
+ DatasetClass = torchvision.datasets.CIFAR10
299
+ elif config.dataset == "cifar100":
300
+ num_classes = 100
301
+ DatasetClass = torchvision.datasets.CIFAR100
302
+ else:
303
+ raise ValueError(f"Unknown dataset: {config.dataset}")
304
+
305
+ # Get normalization transform
306
+ norm_transform = get_normalization_transform(config, config.dataset)
307
+
308
+ # Build train transforms
309
+ train_transforms = [
310
+ T.RandomCrop(32, padding=4),
311
+ T.RandomHorizontalFlip(),
312
+ ]
313
+
314
+ if config.use_augmentation:
315
+ train_transforms.append(T.AutoAugment(T.AutoAugmentPolicy.CIFAR10))
316
+
317
+ train_transforms.append(T.ToTensor())
318
+
319
+ if norm_transform is not None:
320
+ train_transforms.append(norm_transform)
321
+
322
+ train_transform = T.Compose(train_transforms)
323
+
324
+ # Build test transforms
325
+ test_transforms = [T.ToTensor()]
326
+ if norm_transform is not None:
327
+ test_transforms.append(norm_transform)
328
+ test_transform = T.Compose(test_transforms)
329
+
330
+ print(f" Normalization: {config.normalization}")
331
+
332
+ train_dataset = DatasetClass(
333
+ root='./data', train=True, download=True, transform=train_transform
334
+ )
335
+ test_dataset = DatasetClass(
336
+ root='./data', train=False, download=True, transform=test_transform
337
+ )
338
+
339
+ train_loader = DataLoader(
340
+ train_dataset,
341
+ batch_size=config.batch_size,
342
+ shuffle=True,
343
+ num_workers=config.num_workers,
344
+ pin_memory=True,
345
+ persistent_workers=config.num_workers > 0,
346
+ drop_last=True
347
+ )
348
+ test_loader = DataLoader(
349
+ test_dataset,
350
+ batch_size=config.batch_size,
351
+ shuffle=False,
352
+ num_workers=config.num_workers,
353
+ pin_memory=True,
354
+ persistent_workers=config.num_workers > 0
355
+ )
356
+
357
+ return train_loader, test_loader, num_classes
358
+
359
+ except ImportError:
360
+ print(" [!] torchvision not available, using synthetic data")
361
+ return get_synthetic_dataloaders(config)
362
+
363
+
364
+ def get_synthetic_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]:
365
+ """Fallback synthetic data for testing."""
366
+
367
+ class SyntheticDataset(torch.utils.data.Dataset):
368
+ def __init__(self, size: int, image_size: int, num_classes: int):
369
+ self.size = size
370
+ self.image_size = image_size
371
+ self.num_classes = num_classes
372
+
373
+ def __len__(self):
374
+ return self.size
375
+
376
+ def __getitem__(self, idx):
377
+ x = torch.randn(3, self.image_size, self.image_size)
378
+ y = idx % self.num_classes
379
+ return x, y
380
+
381
+ num_classes = 100 if config.dataset == "cifar100" else 10
382
+ train_dataset = SyntheticDataset(5000, config.image_size, num_classes)
383
+ test_dataset = SyntheticDataset(1000, config.image_size, num_classes)
384
+
385
+ train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
386
+ test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
387
+
388
+ return train_loader, test_loader, num_classes
389
+
390
+
+ # ============================================================================
+ # MIXING AUGMENTATIONS
+ # ============================================================================
+
+ def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
+     """Mixup augmentation."""
+     if alpha > 0:
+         lam = torch.distributions.Beta(alpha, alpha).sample().item()
+     else:
+         lam = 1.0
+
+     batch_size = x.size(0)
+     index = torch.randperm(batch_size, device=x.device)
+
+     mixed_x = lam * x + (1 - lam) * x[index]
+     y_a, y_b = y, y[index]
+
+     return mixed_x, y_a, y_b, lam
+
+
+ def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
+     """CutMix augmentation."""
+     if alpha > 0:
+         lam = torch.distributions.Beta(alpha, alpha).sample().item()
+     else:
+         lam = 1.0
+
+     batch_size = x.size(0)
+     index = torch.randperm(batch_size, device=x.device)
+
+     _, _, H, W = x.shape
+
+     cut_ratio = math.sqrt(1 - lam)
+     cut_h = int(H * cut_ratio)
+     cut_w = int(W * cut_ratio)
+
+     cx = torch.randint(0, H, (1,)).item()
+     cy = torch.randint(0, W, (1,)).item()
+
+     x1 = max(0, cx - cut_h // 2)
+     x2 = min(H, cx + cut_h // 2)
+     y1 = max(0, cy - cut_w // 2)
+     y2 = min(W, cy + cut_w // 2)
+
+     mixed_x = x.clone()
+     mixed_x[:, :, x1:x2, y1:y2] = x[index, :, x1:x2, y1:y2]
+
+     lam = 1 - ((x2 - x1) * (y2 - y1)) / (H * W)
+
+     y_a, y_b = y, y[index]
+
+     return mixed_x, y_a, y_b, lam
+
+
+ def alphamix_data(
+     x: torch.Tensor,
+     y: torch.Tensor,
+     alpha_range: Tuple[float, float] = (0.3, 0.7),
+     spatial_ratio: float = 0.25
+ ):
+     """
+     AlphaMix: Spatially localized transparent overlay.
+
+     Unlike CutMix (full replacement) or Mixup (global blend),
+     AlphaMix creates a localized alpha-blended region.
+
+     Args:
+         x: [B, C, H, W] input images
+         y: [B] labels
+         alpha_range: (min, max) for alpha blending in overlay region
+         spatial_ratio: Fraction of image area for overlay
+
+     Returns:
+         mixed_x, y_a, y_b, lam (effective lambda for loss weighting)
+     """
+     batch_size = x.size(0)
+     index = torch.randperm(batch_size, device=x.device)
+
+     y_a, y_b = y, y[index]
+
+     # Sample alpha from beta distribution within range
+     alpha_min, alpha_max = alpha_range
+     beta_sample = np.random.beta(2, 2)
+     alpha = alpha_min + (alpha_max - alpha_min) * beta_sample
+
+     _, _, H, W = x.shape
+
+     # Compute overlay region size
+     overlay_ratio = np.sqrt(spatial_ratio)
+     overlay_h = max(4, int(H * overlay_ratio))
+     overlay_w = max(4, int(W * overlay_ratio))
+
+     # Random position for overlay
+     top = np.random.randint(0, max(1, H - overlay_h + 1))
+     left = np.random.randint(0, max(1, W - overlay_w + 1))
+
+     # Create composited image
+     composited_x = x.clone()
+
+     # Alpha blend in the overlay region
+     overlay_region = alpha * x[:, :, top:top + overlay_h, left:left + overlay_w]
+     background_region = (1 - alpha) * x[index, :, top:top + overlay_h, left:left + overlay_w]
+     composited_x[:, :, top:top + overlay_h, left:left + overlay_w] = overlay_region + background_region
+
+     # Compute effective lambda based on blended area
+     blended_area = (overlay_h * overlay_w) / (H * W)
+     # lam represents contribution of original sample
+     # In non-blended region: 100% original
+     # In blended region: alpha% original
+     lam = 1.0 - blended_area * (1 - alpha)
+
+     return composited_x, y_a, y_b, lam
+
+
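+ # A minimal, hedged usage sketch for alphamix_data (illustrative only): the
+ # returned lam weights the two label sets the same way the mixed-loss branch in
+ # train_epoch_v2 combines them.
+ def _example_alphamix_step() -> None:
+     images = torch.rand(8, 3, 32, 32)        # fake CIFAR-sized batch in [0, 1]
+     targets = torch.randint(0, 100, (8,))
+     mixed, y_a, y_b, lam = alphamix_data(
+         images, targets, alpha_range=(0.3, 0.7), spatial_ratio=0.25
+     )
+     logits = torch.randn(8, 100)             # stand-in for model(mixed)['logits']
+     loss = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)
+     print(f"lam={lam:.3f}  loss={loss.item():.3f}")
+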
505
+ # ============================================================================
506
+ # METRICS TRACKER
507
+ # ============================================================================
508
+
509
+ class MetricsTracker:
510
+ """Track training metrics with EMA smoothing."""
511
+
512
+ def __init__(self, ema_decay: float = 0.9):
513
+ self.ema_decay = ema_decay
514
+ self.metrics = {}
515
+ self.ema_metrics = {}
516
+ self.history = {}
517
+
518
+ def update(self, **kwargs):
519
+ for k, v in kwargs.items():
520
+ if isinstance(v, torch.Tensor):
521
+ v = v.item()
522
+
523
+ if k not in self.metrics:
524
+ self.metrics[k] = []
525
+ self.ema_metrics[k] = v
526
+ self.history[k] = []
527
+
528
+ self.metrics[k].append(v)
529
+ self.ema_metrics[k] = self.ema_decay * self.ema_metrics[k] + (1 - self.ema_decay) * v
530
+
531
+ def get_ema(self, key: str) -> float:
532
+ return self.ema_metrics.get(key, 0.0)
533
+
534
+ def get_epoch_mean(self, key: str) -> float:
535
+ values = self.metrics.get(key, [])
536
+ return sum(values) / len(values) if values else 0.0
537
+
538
+ def end_epoch(self):
539
+ for k, v in self.metrics.items():
540
+ if v:
541
+ self.history[k].append(sum(v) / len(v))
542
+ self.metrics = {k: [] for k in self.metrics}
543
+
544
+ def get_history(self) -> Dict:
545
+ return self.history
546
+
547
+
548
+ # ============================================================================
549
+ # CHECKPOINT UTILITIES
550
+ # ============================================================================
551
+
552
+ def find_latest_checkpoint(output_dir: Path) -> Optional[Path]:
553
+ """Find the most recent checkpoint in output directory."""
554
+ checkpoints = list(output_dir.glob("checkpoint_epoch_*.pt"))
555
+
556
+ if not checkpoints:
557
+ best_model = output_dir / "best_model.pt"
558
+ if best_model.exists():
559
+ return best_model
560
+ return None
561
+
562
+ def get_epoch(p):
563
+ try:
564
+ return int(p.stem.split("_")[-1])
565
+ except (IndexError, ValueError):
566
+ return 0
567
+
568
+ checkpoints.sort(key=get_epoch, reverse=True)
569
+ return checkpoints[0]
570
+
571
+
572
+ def get_next_run_number(base_dir: Path) -> int:
573
+ """Get the next run number by scanning existing run directories."""
574
+ if not base_dir.exists():
575
+ return 1
576
+
577
+ max_num = 0
578
+ for d in base_dir.iterdir():
579
+ if d.is_dir() and d.name.startswith("run_"):
580
+ try:
581
+ num = int(d.name.split("_")[1])
582
+ max_num = max(max_num, num)
583
+ except (IndexError, ValueError):
584
+ continue
585
+
586
+ return max_num + 1
587
+
588
+
589
+ def generate_run_dir_name(run_number: int, run_name: str, version: int = 2) -> str:
590
+ """Generate a run directory name."""
591
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
592
+ safe_name = "".join(c if c.isalnum() or c == "_" else "_" for c in run_name.lower())
593
+ safe_name = "_".join(filter(None, safe_name.split("_")))
594
+ return f"run_{run_number:03d}_v{version}_{safe_name}_{timestamp}"
595
+
596
+
597
+ def find_latest_run_dir(base_dir: Path) -> Optional[Path]:
598
+ """Find the most recent run directory."""
599
+ if not base_dir.exists():
600
+ return None
601
+
602
+ run_dirs = [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("run_")]
603
+
604
+ if not run_dirs:
605
+ return None
606
+
607
+ run_dirs.sort(key=lambda d: d.stat().st_mtime, reverse=True)
608
+ return run_dirs[0]
609
+
610
+
+ def load_checkpoint(
+     checkpoint_path: Path,
+     model: nn.Module,
+     optimizer: Optional[torch.optim.Optimizer] = None,
+     device: str = "cuda"
+ ) -> Tuple[int, float]:
+     """Load checkpoint and return (start_epoch, best_acc)."""
+     print(f"\n📂 Loading checkpoint: {checkpoint_path}")
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     model.load_state_dict(checkpoint['model_state_dict'])
+     print(" ✓ Loaded model weights")
+
+     if optimizer is not None and 'optimizer_state_dict' in checkpoint:
+         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+         print(" ✓ Loaded optimizer state")
+
+     epoch = checkpoint.get('epoch', 0)
+     best_acc = checkpoint.get('best_acc', 0.0)
+
+     print(f" ✓ Resuming from epoch {epoch + 1}, best_acc={best_acc:.2f}%")
+
+     return epoch + 1, best_acc
+
635
+
636
+ # ============================================================================
637
+ # HUGGINGFACE HUB INTEGRATION
638
+ # ============================================================================
639
+
640
+ def generate_run_readme(
641
+ model_config: Union[DavidBeansConfig, DavidBeansV2Config],
642
+ train_config: TrainingConfigV2,
643
+ best_acc: float,
644
+ run_dir_name: str
645
+ ) -> str:
646
+ """Generate README for a specific run."""
647
+
648
+ scales_str = ", ".join([str(s) for s in model_config.scales])
649
+
650
+ # V2 specific info
651
+ if isinstance(model_config, DavidBeansV2Config):
652
+ copies_str = ""
653
+ if model_config.scale_copies:
654
+ copies_str = f"\n| Scale Copies | {model_config.scale_copies} |"
655
+
656
+ routing_info = f"""
657
+ ## Wormhole Routing (V2)
658
+ | Parameter | Value |
659
+ |-----------|-------|
660
+ | Mode | {model_config.wormhole_mode} |
661
+ | Wormholes/Position | {model_config.num_wormholes} |
662
+ | Temperature | {model_config.wormhole_temperature} |
663
+ | Tiles | {model_config.num_tiles} |
664
+ | Tile Wormholes | {model_config.tile_wormholes} |
665
+
666
+ ## Crystal Head
667
+ | Parameter | Value |
668
+ |-----------|-------|
669
+ | Scales | [{scales_str}] |{copies_str}
670
+ | Weighting Mode | {model_config.weighting_mode} |
671
+ | Belly Layers | {model_config.belly_layers} |
672
+ | Belly Residual | {model_config.belly_residual} |
673
+ | Use Spine | {model_config.use_spine} |
674
+ | Use Collective | {model_config.use_collective} |
675
+ """
676
+ else:
677
+ routing_info = f"""
678
+ ## Routing (V1)
679
+ | Parameter | Value |
680
+ |-----------|-------|
681
+ | k_neighbors | {model_config.k_neighbors} |
682
+ | Cantor Weight | {model_config.cantor_weight} |
683
+ """
684
+
685
+ aug_info = f"""
686
+ ## Augmentation
687
+ | Parameter | Value |
688
+ |-----------|-------|
689
+ | Normalization | {train_config.normalization} |
690
+ | Mixup Alpha | {train_config.mixup_alpha} |
691
+ | CutMix Alpha | {train_config.cutmix_alpha} |
692
+ | AlphaMix | {train_config.use_alphamix} |
693
+ | Label Smoothing | {train_config.label_smoothing} |
694
+ """
695
+
696
+ return f"""# Run: {run_dir_name}
697
+
698
+ ## Results
699
+ - **Best Accuracy**: {best_acc:.2f}%
700
+ - **Dataset**: {train_config.dataset}
701
+ - **Epochs**: {train_config.epochs}
702
+ - **Model Version**: V{train_config.model_version}
703
+
704
+ ## Model Config
705
+ | Parameter | Value |
706
+ |-----------|-------|
707
+ | Dim | {model_config.dim} |
708
+ | Layers | {model_config.num_layers} |
709
+ | Heads | {model_config.num_heads} |
710
+ | Patch Size | {model_config.patch_size} |
711
+ {routing_info}
712
+
713
+ ## Training Config
714
+ | Parameter | Value |
715
+ |-----------|-------|
716
+ | Learning Rate | {train_config.learning_rate} |
717
+ | Weight Decay | {train_config.weight_decay} |
718
+ | Batch Size | {train_config.batch_size} |
719
+ | CE Weight | {train_config.ce_weight} |
720
+ | Contrast Weight | {train_config.contrast_weight} |
721
+ {aug_info}
722
+
723
+ ## Key Findings Applied
724
+ - Routing learns from task pressure (no auxiliary routing losses)
725
+ - Gradients verified to flow through router
726
+ - Cross-contrastive aligns patch↔scale features
727
+ """
728
+
729
+
730
+ def prepare_run_for_hub(
731
+ model: nn.Module,
732
+ model_config: Union[DavidBeansConfig, DavidBeansV2Config],
733
+ train_config: TrainingConfigV2,
734
+ best_acc: float,
735
+ output_dir: Path,
736
+ run_dir_name: str,
737
+ training_history: Optional[Dict] = None
738
+ ) -> Path:
739
+ """Prepare run files for upload to HuggingFace Hub."""
740
+
741
+ hub_dir = output_dir / "hub_upload"
742
+ run_hub_dir = hub_dir / "weights" / run_dir_name
743
+ run_hub_dir.mkdir(parents=True, exist_ok=True)
744
+
745
+ state_dict = {k: v.clone() for k, v in model.state_dict().items()}
746
+
747
+ if SAFETENSORS_AVAILABLE:
748
+ try:
749
+ save_safetensors(state_dict, run_hub_dir / "best.safetensors")
750
+ print(f" βœ“ Saved best.safetensors")
751
+ except Exception as e:
752
+ print(f" [!] Safetensors failed ({e}), using pytorch format")
753
+ torch.save(state_dict, run_hub_dir / "best.pt")
754
+ else:
755
+ torch.save(state_dict, run_hub_dir / "best.pt")
756
+
757
+ config_dict = {
758
+ "architecture": f"DavidBeans_V{train_config.model_version}",
759
+ "model_type": "david_beans_v2" if train_config.model_version == 2 else "david_beans",
760
+ **model_config.__dict__
761
+ }
762
+ with open(run_hub_dir / "config.json", "w") as f:
763
+ json.dump(config_dict, f, indent=2, default=str)
764
+
765
+ with open(run_hub_dir / "training_config.json", "w") as f:
766
+ json.dump(train_config.to_dict(), f, indent=2, default=str)
767
+
768
+ run_readme = generate_run_readme(model_config, train_config, best_acc, run_dir_name)
769
+ with open(run_hub_dir / "README.md", "w") as f:
770
+ f.write(run_readme)
771
+
772
+ if training_history:
773
+ with open(run_hub_dir / "training_history.json", "w") as f:
774
+ json.dump(training_history, f, indent=2)
775
+
776
+ tb_dir = output_dir / "tensorboard"
777
+ if tb_dir.exists():
778
+ hub_tb_dir = run_hub_dir / "tensorboard"
779
+ if hub_tb_dir.exists():
780
+ shutil.rmtree(hub_tb_dir)
781
+ shutil.copytree(tb_dir, hub_tb_dir)
782
+
783
+ return hub_dir
784
+
785
+
786
+ def push_run_to_hub(
787
+ hub_dir: Path,
788
+ repo_id: str,
789
+ run_dir_name: str,
790
+ private: bool = False,
791
+ commit_message: Optional[str] = None
792
+ ) -> str:
793
+ """Push run files to HuggingFace Hub."""
794
+
795
+ if not HF_HUB_AVAILABLE:
796
+ raise RuntimeError("huggingface_hub not installed")
797
+
798
+ api = HfApi()
799
+
800
+ try:
801
+ create_repo(repo_id, private=private, exist_ok=True)
802
+ except Exception as e:
803
+ print(f" [!] Repo creation note: {e}")
804
+
805
+ run_upload_dir = hub_dir / "weights" / run_dir_name
806
+
807
+ if commit_message is None:
808
+ commit_message = f"Update {run_dir_name} - {datetime.now().strftime('%Y-%m-%d %H:%M')}"
809
+
810
+ url = upload_folder(
811
+ folder_path=str(run_upload_dir),
812
+ repo_id=repo_id,
813
+ path_in_repo=f"weights/{run_dir_name}",
814
+ commit_message=commit_message
815
+ )
816
+
817
+ return url
818
+
819
+
820
+ # ============================================================================
821
+ # TRAINING LOOP V2
822
+ # ============================================================================
823
+
824
+ def train_epoch_v2(
825
+ model: nn.Module,
826
+ train_loader: DataLoader,
827
+ optimizer: torch.optim.Optimizer,
828
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
829
+ config: TrainingConfigV2,
830
+ epoch: int,
831
+ tracker: MetricsTracker,
832
+ routing_metrics: RoutingMetrics,
833
+ writer: Optional['SummaryWriter'] = None
834
+ ) -> Dict[str, float]:
835
+ """Train for one epoch with V2 routing metrics and AlphaMix support."""
836
+
837
+ model.train()
838
+ device = config.device
839
+ is_v2 = config.model_version == 2
840
+
841
+ total_loss = 0.0
842
+ total_correct = 0
843
+ total_samples = 0
844
+ global_step = epoch * len(train_loader)
845
+
846
+ routing_metrics.reset()
847
+
848
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True)
849
+
850
+ for batch_idx, (images, targets) in enumerate(pbar):
851
+ images = images.to(device, non_blocking=True)
852
+ targets = targets.to(device, non_blocking=True)
853
+
854
+ # Apply mixing augmentations
855
+ use_mixup = config.use_augmentation and config.mixup_alpha > 0
856
+ use_cutmix = config.use_augmentation and config.cutmix_alpha > 0
857
+ use_alphamix = config.use_alphamix
858
+
859
+ mixed = False
860
+ mix_type = None
861
+
862
+ if use_mixup or use_cutmix or use_alphamix:
863
+ r = torch.rand(1).item()
864
+
865
+ # Probability distribution for mix types
866
+ # If all three enabled: 40% none, 20% mixup, 20% cutmix, 20% alphamix
867
+ # Adjust based on what's enabled
868
+ thresholds = [0.4] # Base: 40% no mixing
869
+
870
+ enabled_mixes = []
871
+ if use_mixup:
872
+ enabled_mixes.append(('mixup', config.mixup_alpha))
873
+ if use_cutmix:
874
+ enabled_mixes.append(('cutmix', config.cutmix_alpha))
875
+ if use_alphamix:
876
+ enabled_mixes.append(('alphamix', None))
877
+
878
+ if enabled_mixes:
879
+ mix_prob = 0.6 / len(enabled_mixes) # Split remaining 60% among enabled
880
+
881
+ cumulative = 0.4
882
+ for i, (mix_name, _) in enumerate(enabled_mixes):
883
+ cumulative += mix_prob
884
+ thresholds.append(cumulative)
885
+
886
+ # Determine which mix to use
887
+ if r < 0.4:
888
+ pass # No mixing
889
+ else:
890
+ for i, (mix_name, mix_param) in enumerate(enabled_mixes):
891
+ if r < thresholds[i + 1]:
892
+ mix_type = mix_name
893
+ break
894
+
895
+ if mix_type == 'mixup':
896
+ images, targets_a, targets_b, lam = mixup_data(images, targets, config.mixup_alpha)
897
+ mixed = True
898
+ elif mix_type == 'cutmix':
899
+ images, targets_a, targets_b, lam = cutmix_data(images, targets, config.cutmix_alpha)
900
+ mixed = True
901
+ elif mix_type == 'alphamix':
902
+ images, targets_a, targets_b, lam = alphamix_data(
903
+ images, targets,
904
+ alpha_range=config.alphamix_alpha_range,
905
+ spatial_ratio=config.alphamix_spatial_ratio
906
+ )
907
+ mixed = True
908
+
909
+ # Forward pass
910
+ if is_v2:
911
+ result = model(
912
+ images,
913
+ targets=targets,
914
+ return_loss=True,
915
+ return_routing=(batch_idx % 10 == 0)
916
+ )
917
+ else:
918
+ result = model(images, targets=targets, return_loss=True)
919
+
920
+ losses = result['losses']
921
+
922
+ # Handle mixed CE loss
923
+ if mixed:
924
+ logits = result['logits']
925
+ ce_loss = lam * F.cross_entropy(logits, targets_a, label_smoothing=config.label_smoothing) + \
926
+ (1 - lam) * F.cross_entropy(logits, targets_b, label_smoothing=config.label_smoothing)
927
+ losses['ce'] = ce_loss
928
+
929
+ # Compute total loss (NO auxiliary routing loss - key finding!)
930
+ loss = (
931
+ config.ce_weight * losses['ce'] +
932
+ config.contrast_weight * losses.get('contrast', torch.tensor(0.0, device=device))
933
+ )
934
+
935
+ # Add scale CE losses (handle both regular and copy scales)
936
+ for key, val in losses.items():
937
+ if key.startswith('ce_') and key != 'ce':
938
+ if isinstance(val, torch.Tensor):
939
+ loss = loss + 0.1 * val
940
+
941
+ # Backward pass
942
+ optimizer.zero_grad()
943
+ loss.backward()
944
+
945
+ # Track routing gradient norms
946
+ if is_v2:
947
+ routing_metrics.update_grad_norms(model)
948
+
949
+ if config.gradient_clip > 0:
950
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
951
+ else:
952
+ grad_norm = 0.0
953
+
954
+ optimizer.step()
955
+
956
+ if scheduler is not None and config.scheduler == "onecycle":
957
+ scheduler.step()
958
+
959
+ # Update routing metrics
960
+ if is_v2 and result.get('routing'):
961
+ routing_metrics.update_from_routing_info(result['routing'], model)
962
+
963
+ # Compute accuracy
964
+ with torch.no_grad():
965
+ logits = result['logits']
966
+ preds = logits.argmax(dim=-1)
967
+
968
+ if mixed:
969
+ correct = (lam * (preds == targets_a).float() +
970
+ (1 - lam) * (preds == targets_b).float()).sum()
971
+ else:
972
+ correct = (preds == targets).sum()
973
+
974
+ total_correct += correct.item()
975
+ total_samples += targets.size(0)
976
+ total_loss += loss.item()
977
+
978
+ # Track metrics
979
+ def to_float(v):
980
+ return v.item() if isinstance(v, torch.Tensor) else float(v)
981
+
982
+ contrast_loss = to_float(losses.get('contrast', 0.0))
983
+ current_lr = optimizer.param_groups[0]['lr']
984
+
985
+ tracker.update(
986
+ loss=loss.item(),
987
+ ce=losses['ce'].item(),
988
+ contrast=contrast_loss,
989
+ lr=current_lr
990
+ )
991
+
992
+ # TensorBoard logging
993
+ if writer is not None and (batch_idx + 1) % config.log_interval == 0:
994
+ step = global_step + batch_idx
995
+ writer.add_scalar('train/loss_total', loss.item(), step)
996
+ writer.add_scalar('train/loss_ce', losses['ce'].item(), step)
997
+ writer.add_scalar('train/loss_contrast', contrast_loss, step)
998
+ writer.add_scalar('train/learning_rate', current_lr, step)
999
+ writer.add_scalar('train/grad_norm', to_float(grad_norm), step)
1000
+
1001
+ if is_v2 and config.log_routing:
1002
+ routing_summary = routing_metrics.get_summary()
1003
+ for k, v in routing_summary.items():
1004
+ writer.add_scalar(f'routing/{k}', v, step)
1005
+
1006
+ # Progress bar
1007
+ routing_summary = routing_metrics.get_summary()
1008
+ postfix = {
1009
+ 'loss': f"{tracker.get_ema('loss'):.3f}",
1010
+ 'acc': f"{100.0 * total_correct / total_samples:.1f}%",
1011
+ }
1012
+ if is_v2 and 'grad_query' in routing_summary:
1013
+ postfix['βˆ‡q'] = f"{routing_summary['grad_query']:.2f}"
1014
+ if 'route_entropy' in routing_summary:
1015
+ postfix['H'] = f"{routing_summary['route_entropy']:.2f}"
1016
+
1017
+ pbar.set_postfix(postfix)
1018
+
1019
+ if scheduler is not None and config.scheduler == "cosine":
1020
+ scheduler.step()
1021
+
1022
+ return {
1023
+ 'loss': total_loss / len(train_loader),
1024
+ 'acc': 100.0 * total_correct / total_samples,
1025
+ **routing_metrics.get_summary()
1026
+ }
1027
+
1028
+
1029
+ @torch.no_grad()
1030
+ def evaluate_v2(
1031
+ model: nn.Module,
1032
+ test_loader: DataLoader,
1033
+ config: TrainingConfigV2
1034
+ ) -> Dict[str, float]:
1035
+ """Evaluate on test set."""
1036
+
1037
+ model.eval()
1038
+ device = config.device
1039
+
1040
+ total_loss = 0.0
1041
+ total_correct = 0
1042
+ total_samples = 0
1043
+
1044
+ # Handle variable number of scale heads (including copies)
1045
+ num_heads = len(model.head.heads) if hasattr(model.head, 'heads') else len(model.config.scales)
1046
+ head_correct = [0] * num_heads
1047
+
1048
+ for images, targets in test_loader:
1049
+ images = images.to(device, non_blocking=True)
1050
+ targets = targets.to(device, non_blocking=True)
1051
+
1052
+ result = model(images, targets=targets, return_loss=True)
1053
+
1054
+ logits = result['logits']
1055
+ losses = result['losses']
1056
+
1057
+ loss = losses['total']
1058
+ preds = logits.argmax(dim=-1)
1059
+
1060
+ total_loss += loss.item() * targets.size(0)
1061
+ total_correct += (preds == targets).sum().item()
1062
+ total_samples += targets.size(0)
1063
+
1064
+ # Per-head accuracy
1065
+ for i, scale_logits in enumerate(result['scale_logits']):
1066
+ scale_preds = scale_logits.argmax(dim=-1)
1067
+ head_correct[i] += (scale_preds == targets).sum().item()
1068
+
1069
+ metrics = {
1070
+ 'loss': total_loss / total_samples,
1071
+ 'acc': 100.0 * total_correct / total_samples
1072
+ }
1073
+
1074
+ # Map head indices to scale names
1075
+ if hasattr(model.head, 'head_scale_map'):
1076
+ for i, (scale, copy_idx) in enumerate(model.head.head_scale_map):
1077
+ key = f'acc_{scale}' if copy_idx == 0 else f'acc_{scale}_c{copy_idx}'
1078
+ metrics[key] = 100.0 * head_correct[i] / total_samples
1079
+ else:
1080
+ for i, scale in enumerate(model.config.scales):
1081
+ metrics[f'acc_{scale}'] = 100.0 * head_correct[i] / total_samples
1082
+
1083
+ return metrics
1084
+
1085
+
1086
+ # ============================================================================
1087
+ # MAIN TRAINING FUNCTION V2
1088
+ # ============================================================================
1089
+
1090
+ def train_david_beans_v2(
1091
+ model_config: Optional[Union[DavidBeansConfig, DavidBeansV2Config]] = None,
1092
+ train_config: Optional[TrainingConfigV2] = None
1093
+ ):
1094
+ """Main training function for DavidBeans V1 or V2."""
1095
+
1096
+ print("=" * 70)
1097
+ print(" DAVID-BEANS V2.1 TRAINING: Wormhole Routing")
1098
+ print("=" * 70)
1099
+ print()
1100
+ print(" πŸŒ€ WORMHOLES: Learned sparse routing")
1101
+ print(" πŸ’Ž CRYSTALS: Multi-scale projection")
1102
+ print()
1103
+ print(" Key insight: When routing IS the task, routing learns structure")
1104
+ print()
1105
+ print("=" * 70)
1106
+
1107
+ if train_config is None:
1108
+ train_config = TrainingConfigV2()
1109
+
1110
+ base_output_dir = Path(train_config.output_dir)
1111
+ base_output_dir.mkdir(parents=True, exist_ok=True)
1112
+
1113
+ # Checkpoint resolution
1114
+ checkpoint_path = None
1115
+ run_dir = None
1116
+ run_dir_name = None
1117
+
1118
+ if train_config.resume_from:
1119
+ resume_path = Path(train_config.resume_from)
1120
+
1121
+ if resume_path.is_file():
1122
+ checkpoint_path = resume_path
1123
+ run_dir = checkpoint_path.parent
1124
+ run_dir_name = run_dir.name
1125
+ print(f"\nπŸ“‚ Found checkpoint file: {checkpoint_path.name}")
1126
+ elif resume_path.is_dir():
1127
+ checkpoint_path = find_latest_checkpoint(resume_path)
1128
+ if checkpoint_path:
1129
+ run_dir = resume_path
1130
+ run_dir_name = resume_path.name
1131
+ print(f"\nπŸ“‚ Found checkpoint in dir: {checkpoint_path.name}")
1132
+ else:
1133
+ possible_dir = base_output_dir / train_config.resume_from
1134
+ if possible_dir.is_dir():
1135
+ checkpoint_path = find_latest_checkpoint(possible_dir)
1136
+ if checkpoint_path:
1137
+ run_dir = possible_dir
1138
+ run_dir_name = possible_dir.name
1139
+ print(f"\nπŸ“‚ Found checkpoint in: {run_dir_name}")
1140
+
1141
+ if checkpoint_path is None:
1142
+ possible_file = base_output_dir / train_config.resume_from
1143
+ if possible_file.is_file():
1144
+ checkpoint_path = possible_file
1145
+ run_dir = checkpoint_path.parent
1146
+ run_dir_name = run_dir.name
1147
+ print(f"\nπŸ“‚ Found checkpoint: {checkpoint_path.name}")
1148
+
1149
+ if checkpoint_path is None:
1150
+ print(f"\n [!] Could not find checkpoint: {train_config.resume_from}")
1151
+ print(f" [!] Starting fresh run instead")
1152
+ else:
1153
+ print(f" βœ“ Will resume from: {checkpoint_path}")
1154
+
1155
+ # Create new run directory if not resuming
1156
+ if run_dir is None:
1157
+ run_number = train_config.run_number or get_next_run_number(base_output_dir)
1158
+ run_dir_name = generate_run_dir_name(run_number, train_config.run_name, train_config.model_version)
1159
+ run_dir = base_output_dir / run_dir_name
1160
+ run_dir.mkdir(parents=True, exist_ok=True)
1161
+ print(f"\nπŸ“ New run: {run_dir_name}")
1162
+ else:
1163
+ print(f"\nπŸ“ Resuming run: {run_dir_name}")
1164
+
1165
+ output_dir = run_dir
1166
+
1167
+ # Model config
1168
+ if checkpoint_path and checkpoint_path.exists() and model_config is None:
1169
+ try:
1170
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
1171
+ if 'model_config' in ckpt:
1172
+ saved_config = ckpt['model_config']
1173
+ print(f" βœ“ Loading model config from checkpoint")
1174
+ if train_config.model_version == 2:
1175
+ model_config = DavidBeansV2Config(**saved_config)
1176
+ else:
1177
+ model_config = DavidBeansConfig(**saved_config)
1178
+ except Exception as e:
1179
+ print(f" [!] Could not load config from checkpoint: {e}")
1180
+
1181
+ if model_config is None:
1182
+ if train_config.model_version == 2:
1183
+ model_config = DavidBeansV2Config(
1184
+ image_size=train_config.image_size,
1185
+ patch_size=4,
1186
+ dim=512,
1187
+ num_layers=4,
1188
+ num_heads=8,
1189
+ num_wormholes=8,
1190
+ wormhole_temperature=0.1,
1191
+ wormhole_mode="hybrid",
1192
+ num_tiles=16,
1193
+ tile_wormholes=4,
1194
+ scales=[64, 128, 256, 384, 512],
1195
+ num_classes=100,
1196
+ contrast_weight=train_config.contrast_weight,
1197
+ dropout=0.1
1198
+ )
1199
+ else:
1200
+ model_config = DavidBeansConfig(
1201
+ image_size=train_config.image_size,
1202
+ patch_size=4,
1203
+ dim=512,
1204
+ num_layers=4,
1205
+ num_heads=8,
1206
+ num_experts=5,
1207
+ k_neighbors=16,
1208
+ cantor_weight=0.3,
1209
+ scales=[64, 128, 256, 384, 512],
1210
+ num_classes=100,
1211
+ dropout=0.1
1212
+ )
1213
+
1214
+ device = train_config.device
1215
+ print(f"\nDevice: {device}")
1216
+ print(f"Model version: V{train_config.model_version}")
1217
+
1218
+ # Data
1219
+ print("\nLoading data...")
1220
+ train_loader, test_loader, num_classes = get_dataloaders(train_config)
1221
+ print(f" Dataset: {train_config.dataset}")
1222
+ print(f" Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
1223
+ print(f" Classes: {num_classes}")
1224
+
1225
+ model_config.num_classes = num_classes
1226
+
1227
+ # Model
1228
+ print("\nBuilding model...")
1229
+ if train_config.model_version == 2:
1230
+ model = DavidBeansV2(model_config)
1231
+ else:
1232
+ model = DavidBeans(model_config)
1233
+
1234
+ model = model.to(device)
1235
+ print(f"\n{model}")
1236
+
1237
+ num_params = sum(p.numel() for p in model.parameters())
1238
+ print(f"\nParameters: {num_params:,}")
1239
+
1240
+ # Optimizer
1241
+ print("\nSetting up optimizer...")
1242
+
1243
+ decay_params = []
1244
+ no_decay_params = []
1245
+
1246
+ for name, param in model.named_parameters():
1247
+ if not param.requires_grad:
1248
+ continue
1249
+ if 'bias' in name or 'norm' in name or 'embedding' in name:
1250
+ no_decay_params.append(param)
1251
+ else:
1252
+ decay_params.append(param)
1253
+
1254
+ optimizer = AdamW([
1255
+ {'params': decay_params, 'weight_decay': train_config.weight_decay},
1256
+ {'params': no_decay_params, 'weight_decay': 0.0}
1257
+ ], lr=train_config.learning_rate, betas=train_config.betas)
1258
+
1259
+ if train_config.scheduler == "cosine":
1260
+ scheduler = CosineAnnealingLR(
1261
+ optimizer,
1262
+ T_max=train_config.epochs - train_config.warmup_epochs,
1263
+ eta_min=train_config.min_lr
1264
+ )
1265
+ elif train_config.scheduler == "onecycle":
1266
+ scheduler = OneCycleLR(
1267
+ optimizer,
1268
+ max_lr=train_config.learning_rate,
1269
+ epochs=train_config.epochs,
1270
+ steps_per_epoch=len(train_loader),
1271
+ pct_start=train_config.warmup_epochs / train_config.epochs
1272
+ )
1273
+ else:
1274
+ scheduler = None
1275
+
1276
+ print(f" Optimizer: AdamW (lr={train_config.learning_rate}, wd={train_config.weight_decay})")
1277
+ print(f" Scheduler: {train_config.scheduler}")
1278
+
1279
+ # Print augmentation config
1280
+ print(f"\nAugmentation:")
1281
+ print(f" Mixup: {train_config.mixup_alpha if train_config.mixup_alpha > 0 else 'disabled'}")
1282
+ print(f" CutMix: {train_config.cutmix_alpha if train_config.cutmix_alpha > 0 else 'disabled'}")
1283
+ print(f" AlphaMix: {train_config.alphamix_alpha_range if train_config.use_alphamix else 'disabled'}")
1284
+
1285
+ tracker = MetricsTracker()
1286
+ routing_metrics = RoutingMetrics()
1287
+ best_acc = 0.0
1288
+ start_epoch = 0
1289
+
1290
+ # Load checkpoint
1291
+ if checkpoint_path and checkpoint_path.exists():
1292
+ start_epoch, best_acc = load_checkpoint(checkpoint_path, model, optimizer, device)
1293
+
1294
+ if scheduler is not None and train_config.scheduler == "cosine":
1295
+ for _ in range(start_epoch):
1296
+ scheduler.step()
1297
+ print(f" βœ“ Advanced scheduler to epoch {start_epoch}")
1298
+
1299
+ # TensorBoard
1300
+ writer = None
1301
+ if train_config.use_tensorboard and TENSORBOARD_AVAILABLE:
1302
+ tb_dir = output_dir / "tensorboard"
1303
+ tb_dir.mkdir(parents=True, exist_ok=True)
1304
+ writer = SummaryWriter(log_dir=str(tb_dir))
1305
+ print(f" TensorBoard: {tb_dir}")
1306
+
1307
+ # Save configs
1308
+ with open(output_dir / "config.json", "w") as f:
1309
+ json.dump({**model_config.__dict__, "architecture": f"DavidBeans_V{train_config.model_version}"},
1310
+ f, indent=2, default=str)
1311
+ with open(output_dir / "training_config.json", "w") as f:
1312
+ json.dump(train_config.to_dict(), f, indent=2, default=str)
1313
+
1314
+ # Training loop
1315
+ print("\n" + "=" * 70)
1316
+ print(" TRAINING")
1317
+ print("=" * 70)
1318
+
1319
+ for epoch in range(start_epoch, train_config.epochs):
1320
+ epoch_start = time.time()
1321
+
1322
+ # Warmup
1323
+ if epoch < train_config.warmup_epochs and train_config.scheduler == "cosine":
1324
+ warmup_lr = train_config.learning_rate * (epoch + 1) / train_config.warmup_epochs
1325
+ for param_group in optimizer.param_groups:
1326
+ param_group['lr'] = warmup_lr
1327
+
1328
+ train_metrics = train_epoch_v2(
1329
+ model, train_loader, optimizer, scheduler,
1330
+ train_config, epoch, tracker, routing_metrics, writer
1331
+ )
1332
+
1333
+ test_metrics = evaluate_v2(model, test_loader, train_config)
1334
+
1335
+ epoch_time = time.time() - epoch_start
1336
+
1337
+ # TensorBoard
1338
+ if writer is not None:
1339
+ writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
1340
+ writer.add_scalar('epoch/train_acc', train_metrics['acc'], epoch)
1341
+ writer.add_scalar('epoch/test_loss', test_metrics['loss'], epoch)
1342
+ writer.add_scalar('epoch/test_acc', test_metrics['acc'], epoch)
1343
+
1344
+ # Log all scale accuracies
1345
+ for key, val in test_metrics.items():
1346
+ if key.startswith('acc_'):
1347
+ writer.add_scalar(f'scales/{key}', val, epoch)
1348
+
1349
+ # Print summary - show primary scales only (not copies)
1350
+ primary_scale_accs = []
1351
+ for scale in model.config.scales:
1352
+ if f'acc_{scale}' in test_metrics:
1353
+ primary_scale_accs.append(f"{scale}:{test_metrics[f'acc_{scale}']:.1f}%")
1354
+ scale_accs = " | ".join(primary_scale_accs)
1355
+
1356
+ star = "β˜…" if test_metrics['acc'] > best_acc else ""
1357
+
1358
+ routing_info = ""
1359
+ if train_config.model_version == 2 and 'grad_query' in train_metrics:
1360
+ routing_info = f" | βˆ‡q:{train_metrics.get('grad_query', 0):.2f}"
1361
+
1362
+ print(f" β†’ Train: {train_metrics['acc']:.1f}% | Test: {test_metrics['acc']:.1f}% | "
1363
+ f"[{scale_accs}]{routing_info} | {epoch_time:.0f}s {star}")
1364
+
1365
+ # Save best model
1366
+ if test_metrics['acc'] > best_acc:
1367
+ best_acc = test_metrics['acc']
1368
+ torch.save({
1369
+ 'epoch': epoch,
1370
+ 'model_state_dict': model.state_dict(),
1371
+ 'optimizer_state_dict': optimizer.state_dict(),
1372
+ 'best_acc': best_acc,
1373
+ 'model_config': model_config.__dict__,
1374
+ 'train_config': train_config.to_dict()
1375
+ }, output_dir / "best_model.pt")
1376
+
1377
+ # Periodic checkpoint
1378
+ if (epoch + 1) % train_config.save_interval == 0:
1379
+ torch.save({
1380
+ 'epoch': epoch,
1381
+ 'model_state_dict': model.state_dict(),
1382
+ 'optimizer_state_dict': optimizer.state_dict(),
1383
+ 'best_acc': best_acc,
1384
+ 'model_config': model_config.__dict__,
1385
+ 'train_config': train_config.to_dict()
1386
+ }, output_dir / f"checkpoint_epoch_{epoch + 1}.pt")
1387
+
1388
+ if train_config.push_to_hub and HF_HUB_AVAILABLE:
1389
+ try:
1390
+ hub_dir = prepare_run_for_hub(
1391
+ model=model,
1392
+ model_config=model_config,
1393
+ train_config=train_config,
1394
+ best_acc=best_acc,
1395
+ output_dir=output_dir,
1396
+ run_dir_name=run_dir_name,
1397
+ training_history=tracker.get_history()
1398
+ )
1399
+ push_run_to_hub(
1400
+ hub_dir=hub_dir,
1401
+ repo_id=train_config.hub_repo_id,
1402
+ run_dir_name=run_dir_name,
1403
+ commit_message=f"Epoch {epoch + 1} - {best_acc:.2f}% acc"
1404
+ )
1405
+ print(f" πŸ“€ Uploaded to hub")
1406
+ except Exception as e:
1407
+ print(f" [!] Hub upload failed: {e}")
1408
+
1409
+ tracker.end_epoch()
1410
+
1411
+ # Final summary
1412
+ print("\n" + "=" * 70)
1413
+ print(" TRAINING COMPLETE")
1414
+ print("=" * 70)
1415
+ print(f"\n Best Test Accuracy: {best_acc:.2f}%")
1416
+ print(f" Model saved to: {output_dir / 'best_model.pt'}")
1417
+
1418
+ if writer is not None:
1419
+ writer.close()
1420
+
1421
+ return model, best_acc
1422
+
1423
+
1424
+ # ============================================================================
1425
+ # PRESETS
1426
+ # ============================================================================
1427
+
1428
+ def train_cifar100_v2_wormhole(
1429
+ run_name: str = "wormhole_base",
1430
+ push_to_hub: bool = False,
1431
+ resume: bool = False
1432
+ ):
1433
+ """CIFAR-100 with V2 wormhole routing."""
1434
+
1435
+ model_config = DavidBeansV2Config(
1436
+ image_size=32,
1437
+ patch_size=2,
1438
+ dim=512,
1439
+ num_layers=4,
1440
+ num_heads=16,
1441
+ # Wormhole routing parameters
1442
+ num_wormholes=16,
1443
+ wormhole_temperature=0.1,
1444
+ wormhole_mode="hybrid",
1445
+ # Tessellation parameters
1446
+ num_tiles=16,
1447
+ tile_wormholes=4,
1448
+ # Crystal head
1449
+ scales=[64, 128, 256, 512, 1024],
1450
+ num_classes=100,
1451
+ # V2.1 additions
1452
+ belly_layers=2,
1453
+ belly_residual=False,
1454
+ weighting_mode="learned",
1455
+ scale_copies=None,
1456
+ use_spine=False,
1457
+ use_collective=False,
1458
+ # Other
1459
+ contrast_temperature=0.07,
1460
+ contrast_weight=0.5,
1461
+ dropout=0.1
1462
+ )
1463
+
1464
+ train_config = TrainingConfigV2(
1465
+ run_name=run_name,
1466
+ model_version=2,
1467
+ dataset="cifar100",
1468
+ epochs=300,
1469
+ batch_size=512,
1470
+ learning_rate=3e-4,
1471
+ weight_decay=0.05,
1472
+ warmup_epochs=15,
1473
+ # Normalization
1474
+ normalization="standard",
1475
+ # Loss weights
1476
+ ce_weight=1.0,
1477
+ contrast_weight=0.5,
1478
+ # Augmentation
1479
+ label_smoothing=0.1,
1480
+ mixup_alpha=0.2,
1481
+ cutmix_alpha=1.0,
1482
+ # AlphaMix
1483
+ use_alphamix=True,
1484
+ alphamix_alpha_range=(0.3, 0.7),
1485
+ alphamix_spatial_ratio=0.25,
1486
+ # Output
1487
+ output_dir="./checkpoints/cifar100_v2",
1488
+ resume_from=None,
1489
+ # Hub
1490
+ push_to_hub=push_to_hub,
1491
+ hub_repo_id="AbstractPhil/geovit-david-beans",
1492
+ # Routing logging
1493
+ log_routing=True
1494
+ )
1495
+
1496
+ return train_david_beans_v2(model_config, train_config)
1497
+
1498
+
1499
+ def train_cifar100_v2_with_spine(
1500
+ run_name: str = "wormhole_spine",
1501
+ push_to_hub: bool = False,
1502
+ resume: bool = False
1503
+ ):
1504
+ """CIFAR-100 with V2 wormhole routing + conv spine."""
1505
+
1506
+ model_config = DavidBeansV2Config(
1507
+ image_size=32,
1508
+ patch_size=4,
1509
+ dim=512,
1510
+ num_layers=4,
1511
+ num_heads=8,
1512
+ num_wormholes=8,
1513
+ wormhole_temperature=0.1,
1514
+ wormhole_mode="hybrid",
1515
+ num_tiles=16,
1516
+ tile_wormholes=4,
1517
+ scales=[64, 128, 256, 384, 512],
1518
+ num_classes=100,
1519
+ # Enable spine
1520
+ use_spine=True,
1521
+ spine_channels=[64, 128, 256],
1522
+ spine_cross_attn=True,
1523
+ spine_gate_init=0.0,
1524
+ # Belly
1525
+ belly_layers=2,
1526
+ weighting_mode="geometric",
1527
+ contrast_temperature=0.07,
1528
+ contrast_weight=0.5,
1529
+ dropout=0.1
1530
+ )
1531
+
1532
+ train_config = TrainingConfigV2(
1533
+ run_name=run_name,
1534
+ model_version=2,
1535
+ dataset="cifar100",
1536
+ epochs=200,
1537
+ batch_size=128,
1538
+ learning_rate=3e-4,
1539
+ weight_decay=0.05,
1540
+ warmup_epochs=10,
1541
+ normalization="standard",
1542
+ ce_weight=1.0,
1543
+ contrast_weight=0.5,
1544
+ label_smoothing=0.1,
1545
+ mixup_alpha=0.2,
1546
+ cutmix_alpha=1.0,
1547
+ use_alphamix=True,
1548
+ output_dir="./checkpoints/cifar100_v2",
1549
+ push_to_hub=push_to_hub,
1550
+ hub_repo_id="AbstractPhil/geovit-david-beans",
1551
+ log_routing=True
1552
+ )
1553
+
1554
+ return train_david_beans_v2(model_config, train_config)
1555
+
1556
+
1557
+ def train_cifar100_v2_redundant_scales(
1558
+ run_name: str = "wormhole_redundant",
1559
+ push_to_hub: bool = False,
1560
+ resume: bool = False
1561
+ ):
1562
+ """CIFAR-100 with redundant small scales for ensemble effect."""
1563
+
1564
+ model_config = DavidBeansV2Config(
1565
+ image_size=32,
1566
+ patch_size=4,
1567
+ dim=512,
1568
+ num_layers=4,
1569
+ num_heads=8,
1570
+ num_wormholes=8,
1571
+ wormhole_temperature=0.1,
1572
+ wormhole_mode="hybrid",
1573
+ num_tiles=16,
1574
+ tile_wormholes=4,
1575
+ scales=[64, 128, 256, 512],
1576
+ # Redundant copies: 4x 64d, 2x 128d, 1x 256d, 1x 512d
1577
+ scale_copies=[4, 2, 1, 1],
1578
+ copy_theta_step=0.15,
1579
+ num_classes=100,
1580
+ weighting_mode="geometric",
1581
+ belly_layers=2,
1582
+ contrast_temperature=0.07,
1583
+ contrast_weight=0.5,
1584
+ dropout=0.1
1585
+ )
1586
+
1587
+ train_config = TrainingConfigV2(
1588
+ run_name=run_name,
1589
+ model_version=2,
1590
+ dataset="cifar100",
1591
+ epochs=200,
1592
+ batch_size=128,
1593
+ learning_rate=3e-4,
1594
+ weight_decay=0.05,
1595
+ warmup_epochs=10,
1596
+ normalization="standard",
1597
+ ce_weight=1.0,
1598
+ contrast_weight=0.5,
1599
+ label_smoothing=0.1,
1600
+ mixup_alpha=0.2,
1601
+ cutmix_alpha=1.0,
1602
+ use_alphamix=True,
1603
+ output_dir="./checkpoints/cifar100_v2",
1604
+ push_to_hub=push_to_hub,
1605
+ hub_repo_id="AbstractPhil/geovit-david-beans",
1606
+ log_routing=True
1607
+ )
1608
+
1609
+ return train_david_beans_v2(model_config, train_config)
1610
+
1611
+
1612
+ def train_cifar100_v2_no_norm(
1613
+ run_name: str = "wormhole_no_norm",
1614
+ push_to_hub: bool = False,
1615
+ resume: bool = False
1616
+ ):
1617
+ """CIFAR-100 with no normalization (raw pixels) for geometric components."""
1618
+
1619
+ model_config = DavidBeansV2Config(
1620
+ image_size=32,
1621
+ patch_size=4,
1622
+ dim=512,
1623
+ num_layers=4,
1624
+ num_heads=8,
1625
+ num_wormholes=8,
1626
+ wormhole_temperature=0.1,
1627
+ wormhole_mode="hybrid",
1628
+ num_tiles=16,
1629
+ tile_wormholes=4,
1630
+ scales=[64, 128, 256, 384, 512],
1631
+ num_classes=100,
1632
+ belly_layers=2,
1633
+ weighting_mode="learned",
1634
+ contrast_temperature=0.07,
1635
+ contrast_weight=0.5,
1636
+ dropout=0.1
1637
+ )
1638
+
1639
+ train_config = TrainingConfigV2(
1640
+ run_name=run_name,
1641
+ model_version=2,
1642
+ dataset="cifar100",
1643
+ epochs=200,
1644
+ batch_size=128,
1645
+ learning_rate=3e-4,
1646
+ weight_decay=0.05,
1647
+ warmup_epochs=10,
1648
+ # No normalization - raw [0,1] pixels
1649
+ normalization="none",
1650
+ ce_weight=1.0,
1651
+ contrast_weight=0.5,
1652
+ label_smoothing=0.1,
1653
+ mixup_alpha=0.2,
1654
+ cutmix_alpha=1.0,
1655
+ use_alphamix=True,
1656
+ output_dir="./checkpoints/cifar100_v2",
1657
+ push_to_hub=push_to_hub,
1658
+ hub_repo_id="AbstractPhil/geovit-david-beans",
1659
+ log_routing=True
1660
+ )
1661
+
1662
+ return train_david_beans_v2(model_config, train_config)
1663
+
1664
+
1665
+ def train_cifar100_v1_baseline(
1666
+ run_name: str = "v1_baseline",
1667
+ push_to_hub: bool = False,
1668
+ resume: bool = False
1669
+ ):
1670
+ """CIFAR-100 with V1 (fixed Cantor routing) for comparison."""
1671
+
1672
+ model_config = DavidBeansConfig(
1673
+ image_size=32,
1674
+ patch_size=4,
1675
+ dim=512,
1676
+ num_layers=4,
1677
+ num_heads=8,
1678
+ num_experts=5,
1679
+ k_neighbors=16,
1680
+ cantor_weight=0.3,
1681
+ scales=[64, 128, 256, 384, 512],
1682
+ num_classes=100,
1683
+ dropout=0.1
1684
+ )
1685
+
1686
+ train_config = TrainingConfigV2(
1687
+ run_name=run_name,
1688
+ model_version=1,
1689
+ dataset="cifar100",
1690
+ epochs=200,
1691
+ batch_size=128,
1692
+ learning_rate=3e-4,
1693
+ weight_decay=0.05,
1694
+ warmup_epochs=10,
1695
+ normalization="standard",
1696
+ ce_weight=1.0,
1697
+ contrast_weight=0.5,
1698
+ label_smoothing=0.1,
1699
+ mixup_alpha=0.2,
1700
+ cutmix_alpha=1.0,
1701
+ use_alphamix=False, # V1 doesn't benefit as much
1702
+ output_dir="./checkpoints/cifar100_v1",
1703
+ resume_from="latest" if resume else None,
1704
+ push_to_hub=push_to_hub,
1705
+ hub_repo_id="AbstractPhil/geovit-david-beans",
1706
+ log_routing=False
1707
+ )
1708
+
1709
+ return train_david_beans_v2(model_config, train_config)
1710
+
1711
+
1712
+ # ============================================================================
1713
+ # MAIN
1714
+ # ============================================================================
1715
+
1716
+ if __name__ == "__main__":
1717
+
1718
+ # =====================================================
1719
+ # CONFIGURATION
1720
+ # =====================================================
1721
+
1722
+ PRESET = "v2_wormhole" # Options: "v1_baseline", "v2_wormhole", "v2_spine", "v2_redundant", "v2_no_norm", "test"
1723
+ RESUME = False
1724
+ RUN_NAME = "5scale_2x2patch_alphamix_d512_4layer"
1725
+ PUSH_TO_HUB = True
1726
+
1727
+ # =====================================================
1728
+ # RUN
1729
+ # =====================================================
1730
+
1731
+ if PRESET == "test":
1732
+ print("πŸ§ͺ Quick test...")
1733
+ model_config = DavidBeansV2Config(
1734
+ image_size=32, patch_size=4, dim=128, num_layers=2,
1735
+ num_heads=4, num_wormholes=4, num_tiles=8,
1736
+ scales=[32, 64, 128], num_classes=10,
1737
+ belly_layers=2
1738
+ )
1739
+ train_config = TrainingConfigV2(
1740
+ run_name="test", model_version=2,
1741
+ epochs=2, batch_size=32,
1742
+ use_augmentation=False, mixup_alpha=0.0, cutmix_alpha=0.0,
1743
+ use_alphamix=False
1744
+ )
1745
+ model, acc = train_david_beans_v2(model_config, train_config)
1746
+
1747
+ elif PRESET == "v1_baseline":
1748
+ print("πŸ«˜πŸ’Ž Training DavidBeans V1 (Cantor routing)...")
1749
+ model, acc = train_cifar100_v1_baseline(
1750
+ run_name=RUN_NAME,
1751
+ push_to_hub=PUSH_TO_HUB,
1752
+ resume=RESUME
1753
+ )
1754
+
1755
+ elif PRESET == "v2_wormhole":
1756
+ print("πŸ’Ž Training DavidBeans V2 (Wormhole routing)...")
1757
+ model, acc = train_cifar100_v2_wormhole(
1758
+ run_name=RUN_NAME,
1759
+ push_to_hub=PUSH_TO_HUB,
1760
+ resume=RESUME
1761
+ )
1762
+
1763
+ elif PRESET == "v2_spine":
1764
+ print("πŸ’ŽπŸ¦΄ Training DavidBeans V2 with Conv Spine...")
1765
+ model, acc = train_cifar100_v2_with_spine(
1766
+ run_name=RUN_NAME,
1767
+ push_to_hub=PUSH_TO_HUB,
1768
+ resume=RESUME
1769
+ )
1770
+
1771
+ elif PRESET == "v2_redundant":
1772
+ print("πŸ’Žβœ–οΈ Training DavidBeans V2 with Redundant Scales...")
1773
+ model, acc = train_cifar100_v2_redundant_scales(
1774
+ run_name=RUN_NAME,
1775
+ push_to_hub=PUSH_TO_HUB,
1776
+ resume=RESUME
1777
+ )
1778
+
1779
+ elif PRESET == "v2_no_norm":
1780
+ print("πŸ’ŽπŸ“· Training DavidBeans V2 with No Normalization...")
1781
+ model, acc = train_cifar100_v2_no_norm(
1782
+ run_name=RUN_NAME,
1783
+ push_to_hub=PUSH_TO_HUB,
1784
+ resume=RESUME
1785
+ )
1786
+
1787
+ else:
1788
+ raise ValueError(f"Unknown preset: {PRESET}")
1789
+
1790
+ print(f"\nπŸŽ‰ Done! Best accuracy: {acc:.2f}%")