AbstractPhil committed
Commit 6640107 · verified · 1 Parent(s): ff31041

Create trainer_v2_wormhole_routing.py

Files changed (1)
  1. trainer_v2_wormhole_routing.py +1415 -0
trainer_v2_wormhole_routing.py ADDED
@@ -0,0 +1,1415 @@
1
+ """
2
+ Train DavidBeans V2: Wormhole Routing Architecture
3
+ ===================================================
4
+
+      ┌─────────────────┐
+      │    BEANS V2     │    "I learn where to look..."
+      │ (Wormhole ViT)  │
+      │  🌀 → 🌀 → 🌀   │    Learned sparse routing
+      └────────┬────────┘
+               │
+               ▼
+      ┌─────────────────┐
+      │      DAVID      │    "I know the crystals..."
+      │  (Classifier)   │
+      │  💎 → 💎 → 💎   │    Multi-scale projection
+      └────────┬────────┘
+               │
+               ▼
+         [Prediction]
20
+
21
+ Key findings from wormhole experiments:
22
+ 1. When routing IS the task, routing learns structure
23
+ 2. Auxiliary losses can be gamed - removed in V2
24
+ 3. Gradient flow through router is critical - verified
25
+ 4. Cross-contrastive aligns patch↔scale features
26
+
27
+ Author: AbstractPhil
28
+ Date: November 29, 2025
29
+ """
30
+
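+ # Quick-start sketch (assumes the geofractal package is installed and this file is
+ # importable as a module; names below are the presets defined at the bottom):
+ #   from trainer_v2_wormhole_routing import train_cifar100_v2_wormhole
+ #   model, best_acc = train_cifar100_v2_wormhole(run_name="demo", push_to_hub=False)
+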
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+ from torch.utils.data import DataLoader
35
+ from torch.optim import AdamW
36
+ from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
37
+ from tqdm.auto import tqdm
38
+ import time
39
+ import math
40
+ from pathlib import Path
41
+ from typing import Dict, Optional, Tuple, List, Union
42
+ from dataclasses import dataclass, field
43
+ import json
44
+ from datetime import datetime
45
+ import os
46
+ import shutil
47
+
+ # Pull the HF token from Colab secrets when running in Colab; otherwise fall back
+ # to an already-set HF_TOKEN environment variable.
+ try:
+     from google.colab import userdata
+     HF_TOKEN = userdata.get('HF_TOKEN')
+     os.environ['HF_TOKEN'] = HF_TOKEN
+ except Exception:
+     HF_TOKEN = os.environ.get('HF_TOKEN')
59
+
60
+ # Import both model versions
61
+ from geofractal.model.david_beans.model import DavidBeans, DavidBeansConfig
62
+ from geofractal.model.david_beans.model_v2 import DavidBeansV2, DavidBeansV2Config
63
+
64
+ # HuggingFace Hub integration
65
+ try:
66
+ from huggingface_hub import HfApi, create_repo, upload_folder
67
+ HF_HUB_AVAILABLE = True
68
+ except ImportError:
69
+ HF_HUB_AVAILABLE = False
70
+ print(" [!] huggingface_hub not installed. Run: pip install huggingface_hub")
71
+
72
+ # Safetensors support
73
+ try:
74
+ from safetensors.torch import save_file as save_safetensors
75
+ SAFETENSORS_AVAILABLE = True
76
+ except ImportError:
77
+ SAFETENSORS_AVAILABLE = False
78
+
79
+ # TensorBoard support
80
+ try:
81
+ from torch.utils.tensorboard import SummaryWriter
82
+ TENSORBOARD_AVAILABLE = True
83
+ except ImportError:
84
+ TENSORBOARD_AVAILABLE = False
85
+ print(" [!] tensorboard not installed. Run: pip install tensorboard")
86
+
87
+ import numpy as np
88
+
89
+ # ============================================================================
90
+ # TRAINING CONFIGURATION V2
91
+ # ============================================================================
92
+
93
+ @dataclass
94
+ class TrainingConfigV2:
95
+ """Training configuration for DavidBeans V2 with wormhole routing."""
96
+
97
+ # Run identification
98
+ run_name: str = "default"
99
+ run_number: Optional[int] = None
100
+
101
+ # Model version
102
+ model_version: int = 2 # 1 = original, 2 = wormhole
103
+
104
+ # Data
105
+ dataset: str = "cifar100"
106
+ image_size: int = 32
107
+ batch_size: int = 128
108
+ num_workers: int = 4
109
+
110
+ # Training schedule
111
+ epochs: int = 200
112
+ warmup_epochs: int = 10
113
+
114
+ # Optimizer
115
+ learning_rate: float = 3e-4
116
+ weight_decay: float = 0.05
117
+ betas: Tuple[float, float] = (0.9, 0.999)
118
+
119
+ # Learning rate schedule
120
+ scheduler: str = "cosine"
121
+ min_lr: float = 1e-6
122
+
123
+ # Loss weights (based on experimental findings)
124
+ ce_weight: float = 1.0
125
+ contrast_weight: float = 0.5
126
+ # NOTE: No auxiliary routing loss - routing learns from task pressure
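+ # The total loss assembled in train_epoch_v2 is therefore just:
+ #   loss = ce_weight * CE + contrast_weight * contrast + 0.1 * sum(per-scale CE)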
127
+
128
+ # Regularization
129
+ gradient_clip: float = 1.0
130
+ label_smoothing: float = 0.1
131
+
132
+ # Augmentation
133
+ use_augmentation: bool = True
134
+ mixup_alpha: float = 0.2
135
+ cutmix_alpha: float = 1.0
136
+
137
+ # Checkpointing
138
+ save_interval: int = 10
139
+ output_dir: str = "./checkpoints"
140
+ resume_from: Optional[str] = None
141
+
142
+ # TensorBoard
143
+ use_tensorboard: bool = True
144
+ log_interval: int = 50
145
+ log_routing: bool = True # Log routing patterns
146
+
147
+ # HuggingFace Hub
148
+ push_to_hub: bool = False
149
+ hub_repo_id: str = "AbstractPhil/geovit-david-beans"
150
+ hub_private: bool = False
151
+
152
+ # Device
153
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
154
+
155
+ def to_dict(self) -> Dict:
156
+ return {k: v for k, v in self.__dict__.items()}
157
+
158
+
159
+ # ============================================================================
160
+ # ROUTING METRICS
161
+ # ============================================================================
162
+
163
+ class RoutingMetrics:
164
+ """Track and analyze wormhole routing patterns."""
165
+
166
+ def __init__(self):
167
+ self.reset()
168
+
169
+ def reset(self):
170
+ self.route_entropies = []
171
+ self.route_diversities = []
172
+ self.grad_norms = {'query': [], 'key': []}
173
+
174
+ @torch.no_grad()
175
+ def compute_route_entropy(self, soft_routes: torch.Tensor) -> float:
176
+ """Compute average entropy of routing distributions."""
177
+ # soft_routes: [B, P, K] or [B, T, K]
178
+ # Higher entropy = more diverse routing
179
+ eps = 1e-8
180
+ entropy = -(soft_routes * (soft_routes + eps).log()).sum(dim=-1)
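+ # Worked example (K=4 destinations): uniform routing [0.25, 0.25, 0.25, 0.25] gives
+ # H = ln(4) ≈ 1.386, while a collapsed one-hot route gives H ≈ 0, so a falling
+ # average entropy means the router is specializing rather than hedging.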
181
+ return entropy.mean().item()
182
+
183
+ @torch.no_grad()
184
+ def compute_route_diversity(self, routes: torch.Tensor, num_positions: int) -> float:
185
+ """Compute how many unique destinations are used."""
186
+ # routes: [B, P, K] indices
187
+ unique_per_sample = []
188
+ for b in range(routes.shape[0]):
189
+ unique = routes[b].unique().numel()
190
+ unique_per_sample.append(unique / num_positions)
191
+ return sum(unique_per_sample) / len(unique_per_sample)
192
+
193
+ def update_from_routing_info(self, routing_info: List[Dict], model: nn.Module):
194
+ """Extract metrics from routing info returned by V2 model."""
195
+ if not routing_info:
196
+ return
197
+
198
+ for layer_info in routing_info:
199
+ # Attention routing
200
+ if layer_info.get('attention'):
201
+ attn = layer_info['attention']
202
+ if attn.get('weights') is not None:
203
+ entropy = self.compute_route_entropy(attn['weights'])
204
+ self.route_entropies.append(entropy)
205
+ if attn.get('routes') is not None:
206
+ P = attn['routes'].shape[1]
207
+ diversity = self.compute_route_diversity(attn['routes'], P)
208
+ self.route_diversities.append(diversity)
209
+
210
+ # Expert routing
211
+ if layer_info.get('expert'):
212
+ exp = layer_info['expert']
213
+ if exp.get('weights') is not None:
214
+ entropy = self.compute_route_entropy(exp['weights'])
215
+ self.route_entropies.append(entropy)
216
+
217
+ def update_grad_norms(self, model: nn.Module):
218
+ """Track gradient norms through router projections."""
219
+ for name, param in model.named_parameters():
220
+ if param.grad is not None:
221
+ if 'query_proj' in name and 'weight' in name:
222
+ self.grad_norms['query'].append(param.grad.norm().item())
223
+ elif 'key_proj' in name and 'weight' in name:
224
+ self.grad_norms['key'].append(param.grad.norm().item())
225
+
226
+ def get_summary(self) -> Dict[str, float]:
227
+ """Get summary statistics."""
228
+ summary = {}
229
+
230
+ if self.route_entropies:
231
+ summary['route_entropy'] = sum(self.route_entropies) / len(self.route_entropies)
232
+ if self.route_diversities:
233
+ summary['route_diversity'] = sum(self.route_diversities) / len(self.route_diversities)
234
+ if self.grad_norms['query']:
235
+ summary['grad_query'] = sum(self.grad_norms['query']) / len(self.grad_norms['query'])
236
+ if self.grad_norms['key']:
237
+ summary['grad_key'] = sum(self.grad_norms['key']) / len(self.grad_norms['key'])
238
+
239
+ return summary
240
+
241
+
242
+ # ============================================================================
243
+ # DATA LOADING (unchanged from V1)
244
+ # ============================================================================
245
+
246
+ def get_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]:
247
+ """Get train and test dataloaders."""
248
+
249
+ try:
250
+ import torchvision
251
+ import torchvision.transforms as T
252
+
253
+ if config.dataset == "cifar10":
254
+ mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
255
+ num_classes = 10
256
+ DatasetClass = torchvision.datasets.CIFAR10
257
+ elif config.dataset == "cifar100":
258
+ mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
259
+ num_classes = 100
260
+ DatasetClass = torchvision.datasets.CIFAR100
261
+ else:
262
+ raise ValueError(f"Unknown dataset: {config.dataset}")
263
+
264
+ if config.use_augmentation:
265
+ train_transform = T.Compose([
266
+ T.RandomCrop(32, padding=4),
267
+ T.RandomHorizontalFlip(),
268
+ T.AutoAugment(T.AutoAugmentPolicy.CIFAR10),
269
+ T.ToTensor(),
270
+ T.Normalize(mean, std)
271
+ ])
272
+ else:
273
+ train_transform = T.Compose([
274
+ T.ToTensor(),
275
+ T.Normalize(mean, std)
276
+ ])
277
+
278
+ test_transform = T.Compose([
279
+ T.ToTensor(),
280
+ T.Normalize(mean, std)
281
+ ])
282
+
283
+ train_dataset = DatasetClass(
284
+ root='./data', train=True, download=True, transform=train_transform
285
+ )
286
+ test_dataset = DatasetClass(
287
+ root='./data', train=False, download=True, transform=test_transform
288
+ )
289
+
290
+ train_loader = DataLoader(
291
+ train_dataset,
292
+ batch_size=config.batch_size,
293
+ shuffle=True,
294
+ num_workers=config.num_workers,
295
+ pin_memory=True,
296
+ persistent_workers=config.num_workers > 0,
297
+ drop_last=True
298
+ )
299
+ test_loader = DataLoader(
300
+ test_dataset,
301
+ batch_size=config.batch_size,
302
+ shuffle=False,
303
+ num_workers=config.num_workers,
304
+ pin_memory=True,
305
+ persistent_workers=config.num_workers > 0
306
+ )
307
+
308
+ return train_loader, test_loader, num_classes
309
+
310
+ except ImportError:
311
+ print(" [!] torchvision not available, using synthetic data")
312
+ return get_synthetic_dataloaders(config)
313
+
314
+
315
+ def get_synthetic_dataloaders(config: TrainingConfigV2) -> Tuple[DataLoader, DataLoader, int]:
316
+ """Fallback synthetic data for testing."""
317
+
318
+ class SyntheticDataset(torch.utils.data.Dataset):
319
+ def __init__(self, size: int, image_size: int, num_classes: int):
320
+ self.size = size
321
+ self.image_size = image_size
322
+ self.num_classes = num_classes
323
+
324
+ def __len__(self):
325
+ return self.size
326
+
327
+ def __getitem__(self, idx):
328
+ x = torch.randn(3, self.image_size, self.image_size)
329
+ y = idx % self.num_classes
330
+ return x, y
331
+
332
+ num_classes = 100 if config.dataset == "cifar100" else 10
333
+ train_dataset = SyntheticDataset(5000, config.image_size, num_classes)
334
+ test_dataset = SyntheticDataset(1000, config.image_size, num_classes)
335
+
336
+ train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
337
+ test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
338
+
339
+ return train_loader, test_loader, num_classes
340
+
341
+
342
+ # ============================================================================
343
+ # MIXUP / CUTMIX AUGMENTATION
344
+ # ============================================================================
345
+
346
+ def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 0.2):
347
+ """Mixup augmentation."""
348
+ if alpha > 0:
349
+ lam = torch.distributions.Beta(alpha, alpha).sample().item()
350
+ else:
351
+ lam = 1.0
352
+
353
+ batch_size = x.size(0)
354
+ index = torch.randperm(batch_size, device=x.device)
355
+
356
+ mixed_x = lam * x + (1 - lam) * x[index]
357
+ y_a, y_b = y, y[index]
358
+
359
+ return mixed_x, y_a, y_b, lam
360
+
361
+
362
+ def cutmix_data(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
363
+ """CutMix augmentation."""
364
+ if alpha > 0:
365
+ lam = torch.distributions.Beta(alpha, alpha).sample().item()
366
+ else:
367
+ lam = 1.0
368
+
369
+ batch_size = x.size(0)
370
+ index = torch.randperm(batch_size, device=x.device)
371
+
372
+ _, _, H, W = x.shape
373
+
374
+ cut_ratio = math.sqrt(1 - lam)
375
+ cut_h = int(H * cut_ratio)
376
+ cut_w = int(W * cut_ratio)
377
+
378
+ cx = torch.randint(0, H, (1,)).item()
379
+ cy = torch.randint(0, W, (1,)).item()
380
+
381
+ x1 = max(0, cx - cut_h // 2)
382
+ x2 = min(H, cx + cut_h // 2)
383
+ y1 = max(0, cy - cut_w // 2)
384
+ y2 = min(W, cy + cut_w // 2)
385
+
386
+ mixed_x = x.clone()
387
+ mixed_x[:, :, x1:x2, y1:y2] = x[index, :, x1:x2, y1:y2]
388
+
389
+ lam = 1 - ((x2 - x1) * (y2 - y1)) / (H * W)
390
+
391
+ y_a, y_b = y, y[index]
392
+
393
+ return mixed_x, y_a, y_b, lam
394
+
395
+
396
+ # ============================================================================
397
+ # METRICS TRACKER
398
+ # ============================================================================
399
+
400
+ class MetricsTracker:
401
+ """Track training metrics with EMA smoothing."""
402
+
403
+ def __init__(self, ema_decay: float = 0.9):
404
+ self.ema_decay = ema_decay
405
+ self.metrics = {}
406
+ self.ema_metrics = {}
407
+ self.history = {}
408
+
409
+ def update(self, **kwargs):
410
+ for k, v in kwargs.items():
411
+ if isinstance(v, torch.Tensor):
412
+ v = v.item()
413
+
414
+ if k not in self.metrics:
415
+ self.metrics[k] = []
416
+ self.ema_metrics[k] = v
417
+ self.history[k] = []
418
+
419
+ self.metrics[k].append(v)
420
+ self.ema_metrics[k] = self.ema_decay * self.ema_metrics[k] + (1 - self.ema_decay) * v
421
+
422
+ def get_ema(self, key: str) -> float:
423
+ return self.ema_metrics.get(key, 0.0)
424
+
425
+ def get_epoch_mean(self, key: str) -> float:
426
+ values = self.metrics.get(key, [])
427
+ return sum(values) / len(values) if values else 0.0
428
+
429
+ def end_epoch(self):
430
+ for k, v in self.metrics.items():
431
+ if v:
432
+ self.history[k].append(sum(v) / len(v))
433
+ self.metrics = {k: [] for k in self.metrics}
434
+
435
+ def get_history(self) -> Dict:
436
+ return self.history
437
+
438
+
439
+ # ============================================================================
440
+ # CHECKPOINT UTILITIES
441
+ # ============================================================================
442
+
443
+ def find_latest_checkpoint(output_dir: Path) -> Optional[Path]:
444
+ """Find the most recent checkpoint in output directory."""
445
+ checkpoints = list(output_dir.glob("checkpoint_epoch_*.pt"))
446
+
447
+ if not checkpoints:
448
+ best_model = output_dir / "best_model.pt"
449
+ if best_model.exists():
450
+ return best_model
451
+ return None
452
+
+ def get_epoch(p):
+     try:
+         return int(p.stem.split("_")[-1])
+     except (IndexError, ValueError):
+         return 0
458
+
459
+ checkpoints.sort(key=get_epoch, reverse=True)
460
+ return checkpoints[0]
461
+
462
+
463
+ def get_next_run_number(base_dir: Path) -> int:
464
+ """Get the next run number by scanning existing run directories."""
465
+ if not base_dir.exists():
466
+ return 1
467
+
468
+ max_num = 0
469
+ for d in base_dir.iterdir():
470
+ if d.is_dir() and d.name.startswith("run_"):
471
+ try:
472
+ num = int(d.name.split("_")[1])
473
+ max_num = max(max_num, num)
474
+ except (IndexError, ValueError):
475
+ continue
476
+
477
+ return max_num + 1
478
+
479
+
480
+ def generate_run_dir_name(run_number: int, run_name: str, version: int = 2) -> str:
481
+ """Generate a run directory name."""
482
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
483
+ safe_name = "".join(c if c.isalnum() or c == "_" else "_" for c in run_name.lower())
484
+ safe_name = "_".join(filter(None, safe_name.split("_")))
485
+ return f"run_{run_number:03d}_v{version}_{safe_name}_{timestamp}"
486
+
487
+
488
+ def find_latest_run_dir(base_dir: Path) -> Optional[Path]:
489
+ """Find the most recent run directory."""
490
+ if not base_dir.exists():
491
+ return None
492
+
493
+ run_dirs = [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("run_")]
494
+
495
+ if not run_dirs:
496
+ return None
497
+
498
+ run_dirs.sort(key=lambda d: d.stat().st_mtime, reverse=True)
499
+ return run_dirs[0]
500
+
501
+
502
+ def load_checkpoint(
503
+ checkpoint_path: Path,
504
+ model: nn.Module,
505
+ optimizer: Optional[torch.optim.Optimizer] = None,
506
+ device: str = "cuda"
507
+ ) -> Tuple[int, float]:
508
+ """Load checkpoint and return (start_epoch, best_acc)."""
509
+ print(f"\n📂 Loading checkpoint: {checkpoint_path}")
510
+ checkpoint = torch.load(checkpoint_path, map_location=device)
511
+
512
+ model.load_state_dict(checkpoint['model_state_dict'])
513
+ print(f" ✓ Loaded model weights")
514
+
515
+ if optimizer is not None and 'optimizer_state_dict' in checkpoint:
516
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
517
+ print(f" ✓ Loaded optimizer state")
518
+
519
+ epoch = checkpoint.get('epoch', 0)
520
+ best_acc = checkpoint.get('best_acc', 0.0)
521
+
522
+ print(f" ✓ Resuming from epoch {epoch + 1}, best_acc={best_acc:.2f}%")
523
+
524
+ return epoch + 1, best_acc
525
+
526
+
527
+ # ============================================================================
528
+ # HUGGINGFACE HUB INTEGRATION
529
+ # ============================================================================
530
+
531
+ def generate_run_readme(
532
+ model_config: Union[DavidBeansConfig, DavidBeansV2Config],
533
+ train_config: TrainingConfigV2,
534
+ best_acc: float,
535
+ run_dir_name: str
536
+ ) -> str:
537
+ """Generate README for a specific run."""
538
+
539
+ scales_str = ", ".join([str(s) for s in model_config.scales])
540
+
541
+ # V2 specific info
542
+ if isinstance(model_config, DavidBeansV2Config):
543
+ routing_info = f"""
544
+ ## Wormhole Routing (V2)
545
+ | Parameter | Value |
546
+ |-----------|-------|
547
+ | Mode | {model_config.wormhole_mode} |
548
+ | Wormholes/Position | {model_config.num_wormholes} |
549
+ | Temperature | {model_config.wormhole_temperature} |
550
+ | Tiles | {model_config.num_tiles} |
551
+ | Tile Wormholes | {model_config.tile_wormholes} |
552
+ """
553
+ else:
+ routing_info = f"""
555
+ ## Routing (V1)
556
+ | Parameter | Value |
557
+ |-----------|-------|
558
+ | k_neighbors | {model_config.k_neighbors} |
559
+ | Cantor Weight | {model_config.cantor_weight} |
560
+ """
561
+
562
+ return f"""# Run: {run_dir_name}
563
+
564
+ ## Results
565
+ - **Best Accuracy**: {best_acc:.2f}%
566
+ - **Dataset**: {train_config.dataset}
567
+ - **Epochs**: {train_config.epochs}
568
+ - **Model Version**: V{train_config.model_version}
569
+
570
+ ## Model Config
571
+ | Parameter | Value |
572
+ |-----------|-------|
573
+ | Dim | {model_config.dim} |
574
+ | Layers | {model_config.num_layers} |
575
+ | Heads | {model_config.num_heads} |
576
+ | Scales | [{scales_str}] |
577
+ {routing_info}
578
+
579
+ ## Training Config
580
+ | Parameter | Value |
581
+ |-----------|-------|
582
+ | Learning Rate | {train_config.learning_rate} |
583
+ | Weight Decay | {train_config.weight_decay} |
584
+ | Batch Size | {train_config.batch_size} |
585
+ | CE Weight | {train_config.ce_weight} |
586
+ | Contrast Weight | {train_config.contrast_weight} |
587
+
588
+ ## Key Findings Applied
589
+ - Routing learns from task pressure (no auxiliary routing losses)
590
+ - Gradients verified to flow through router
591
+ - Cross-contrastive aligns patch↔scale features
592
+ """
593
+
594
+
595
+ def prepare_run_for_hub(
596
+ model: nn.Module,
597
+ model_config: Union[DavidBeansConfig, DavidBeansV2Config],
598
+ train_config: TrainingConfigV2,
599
+ best_acc: float,
600
+ output_dir: Path,
601
+ run_dir_name: str,
602
+ training_history: Optional[Dict] = None
603
+ ) -> Path:
604
+ """Prepare run files for upload to HuggingFace Hub."""
605
+
606
+ hub_dir = output_dir / "hub_upload"
607
+ run_hub_dir = hub_dir / "weights" / run_dir_name
608
+ run_hub_dir.mkdir(parents=True, exist_ok=True)
609
+
610
+ # Save best model weights
611
+ state_dict = {k: v.clone() for k, v in model.state_dict().items()}
612
+
613
+ if SAFETENSORS_AVAILABLE:
614
+ try:
615
+ save_safetensors(state_dict, run_hub_dir / "best.safetensors")
616
+ print(f" ✓ Saved best.safetensors")
617
+ except Exception as e:
618
+ print(f" [!] Safetensors failed ({e}), using pytorch format")
619
+ torch.save(state_dict, run_hub_dir / "best.pt")
620
+ else:
621
+ torch.save(state_dict, run_hub_dir / "best.pt")
622
+
623
+ # Save model config
624
+ config_dict = {
625
+ "architecture": f"DavidBeans_V{train_config.model_version}",
626
+ "model_type": "david_beans_v2" if train_config.model_version == 2 else "david_beans",
627
+ **model_config.__dict__
628
+ }
629
+ with open(run_hub_dir / "config.json", "w") as f:
630
+ json.dump(config_dict, f, indent=2, default=str)
631
+
632
+ # Save training config
633
+ with open(run_hub_dir / "training_config.json", "w") as f:
634
+ json.dump(train_config.to_dict(), f, indent=2, default=str)
635
+
636
+ # Generate README
637
+ run_readme = generate_run_readme(model_config, train_config, best_acc, run_dir_name)
638
+ with open(run_hub_dir / "README.md", "w") as f:
639
+ f.write(run_readme)
640
+
641
+ # Save training history
642
+ if training_history:
643
+ with open(run_hub_dir / "training_history.json", "w") as f:
644
+ json.dump(training_history, f, indent=2)
645
+
646
+ # Copy TensorBoard logs
647
+ tb_dir = output_dir / "tensorboard"
648
+ if tb_dir.exists():
649
+ hub_tb_dir = run_hub_dir / "tensorboard"
650
+ if hub_tb_dir.exists():
651
+ shutil.rmtree(hub_tb_dir)
652
+ shutil.copytree(tb_dir, hub_tb_dir)
653
+
654
+ return hub_dir
655
+
656
+
657
+ def push_run_to_hub(
658
+ hub_dir: Path,
659
+ repo_id: str,
660
+ run_dir_name: str,
661
+ private: bool = False,
662
+ commit_message: Optional[str] = None
663
+ ) -> str:
664
+ """Push run files to HuggingFace Hub."""
665
+
666
+ if not HF_HUB_AVAILABLE:
667
+ raise RuntimeError("huggingface_hub not installed")
668
+
669
+ api = HfApi()
670
+
671
+ try:
672
+ create_repo(repo_id, private=private, exist_ok=True)
673
+ except Exception as e:
674
+ print(f" [!] Repo creation note: {e}")
675
+
676
+ run_upload_dir = hub_dir / "weights" / run_dir_name
677
+
678
+ if commit_message is None:
679
+ commit_message = f"Update {run_dir_name} - {datetime.now().strftime('%Y-%m-%d %H:%M')}"
680
+
681
+ url = upload_folder(
682
+ folder_path=str(run_upload_dir),
683
+ repo_id=repo_id,
684
+ path_in_repo=f"weights/{run_dir_name}",
685
+ commit_message=commit_message
686
+ )
687
+
688
+ return url
689
+
690
+
691
+ # ============================================================================
692
+ # TRAINING LOOP V2
693
+ # ============================================================================
694
+
695
+ def train_epoch_v2(
696
+ model: nn.Module,
697
+ train_loader: DataLoader,
698
+ optimizer: torch.optim.Optimizer,
699
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
700
+ config: TrainingConfigV2,
701
+ epoch: int,
702
+ tracker: MetricsTracker,
703
+ routing_metrics: RoutingMetrics,
704
+ writer: Optional['SummaryWriter'] = None
705
+ ) -> Dict[str, float]:
706
+ """Train for one epoch with V2 routing metrics."""
707
+
708
+ model.train()
709
+ device = config.device
710
+ is_v2 = config.model_version == 2
711
+
712
+ total_loss = 0.0
713
+ total_correct = 0
714
+ total_samples = 0
715
+ global_step = epoch * len(train_loader)
716
+
717
+ routing_metrics.reset()
718
+
719
+ pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}", leave=True)
720
+
721
+ for batch_idx, (images, targets) in enumerate(pbar):
722
+ images = images.to(device, non_blocking=True)
723
+ targets = targets.to(device, non_blocking=True)
724
+
725
+ # Apply mixup/cutmix
726
+ use_mixup = config.use_augmentation and config.mixup_alpha > 0
727
+ use_cutmix = config.use_augmentation and config.cutmix_alpha > 0
728
+
729
+ mixed = False
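+ # Roughly half the batches stay unmixed; the remainder are split between mixup and cutmix.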
730
+ if use_mixup or use_cutmix:
731
+ r = torch.rand(1).item()
732
+ if r < 0.5:
733
+ pass
734
+ elif r < 0.75 and use_mixup:
735
+ images, targets_a, targets_b, lam = mixup_data(images, targets, config.mixup_alpha)
736
+ mixed = True
737
+ elif use_cutmix:
738
+ images, targets_a, targets_b, lam = cutmix_data(images, targets, config.cutmix_alpha)
739
+ mixed = True
740
+
741
+ # Forward pass
742
+ if is_v2:
743
+ result = model(
744
+ images,
745
+ targets=targets,
746
+ return_loss=True,
747
+ return_routing=(batch_idx % 10 == 0) # Sample routing every 10 batches
748
+ )
749
+ else:
750
+ result = model(images, targets=targets, return_loss=True)
751
+
752
+ losses = result['losses']
753
+
754
+ # Handle mixup CE loss
755
+ if mixed:
756
+ logits = result['logits']
757
+ ce_loss = lam * F.cross_entropy(logits, targets_a, label_smoothing=config.label_smoothing) + \
758
+ (1 - lam) * F.cross_entropy(logits, targets_b, label_smoothing=config.label_smoothing)
759
+ losses['ce'] = ce_loss
760
+
761
+ # Compute total loss (NO auxiliary routing loss - key finding!)
762
+ loss = (
763
+ config.ce_weight * losses['ce'] +
764
+ config.contrast_weight * losses.get('contrast', torch.tensor(0.0, device=device))
765
+ )
766
+
767
+ # Add scale CE losses
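+ # (each crystal scale gets a lightly weighted CE term so every projection head
+ #  receives direct supervision)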
768
+ for scale in model.config.scales:
769
+ scale_ce = losses.get(f'ce_{scale}', 0.0)
770
+ if isinstance(scale_ce, torch.Tensor):
771
+ loss = loss + 0.1 * scale_ce
772
+
773
+ # Backward pass
774
+ optimizer.zero_grad()
775
+ loss.backward()
776
+
777
+ # Track routing gradient norms (verify gradients flow!)
778
+ if is_v2:
779
+ routing_metrics.update_grad_norms(model)
780
+
781
+ if config.gradient_clip > 0:
782
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
783
+ else:
784
+ grad_norm = 0.0
785
+
786
+ optimizer.step()
787
+
788
+ if scheduler is not None and config.scheduler == "onecycle":
789
+ scheduler.step()
790
+
791
+ # Update routing metrics from forward pass
792
+ if is_v2 and result.get('routing'):
793
+ routing_metrics.update_from_routing_info(result['routing'], model)
794
+
795
+ # Compute accuracy
796
+ with torch.no_grad():
797
+ logits = result['logits']
798
+ preds = logits.argmax(dim=-1)
799
+
800
+ if mixed:
801
+ correct = (lam * (preds == targets_a).float() +
802
+ (1 - lam) * (preds == targets_b).float()).sum()
803
+ else:
804
+ correct = (preds == targets).sum()
805
+
806
+ total_correct += correct.item()
807
+ total_samples += targets.size(0)
808
+ total_loss += loss.item()
809
+
810
+ # Track metrics
811
+ def to_float(v):
812
+ return v.item() if isinstance(v, torch.Tensor) else float(v)
813
+
814
+ contrast_loss = to_float(losses.get('contrast', 0.0))
815
+ current_lr = optimizer.param_groups[0]['lr']
816
+
817
+ tracker.update(
818
+ loss=loss.item(),
819
+ ce=losses['ce'].item(),
820
+ contrast=contrast_loss,
821
+ lr=current_lr
822
+ )
823
+
824
+ # TensorBoard logging
825
+ if writer is not None and (batch_idx + 1) % config.log_interval == 0:
826
+ step = global_step + batch_idx
827
+ writer.add_scalar('train/loss_total', loss.item(), step)
828
+ writer.add_scalar('train/loss_ce', losses['ce'].item(), step)
829
+ writer.add_scalar('train/loss_contrast', contrast_loss, step)
830
+ writer.add_scalar('train/learning_rate', current_lr, step)
831
+ writer.add_scalar('train/grad_norm', to_float(grad_norm), step)
832
+
833
+ # Log routing metrics for V2
834
+ if is_v2 and config.log_routing:
835
+ routing_summary = routing_metrics.get_summary()
836
+ for k, v in routing_summary.items():
837
+ writer.add_scalar(f'routing/{k}', v, step)
838
+
839
+ # Progress bar
840
+ routing_summary = routing_metrics.get_summary()
841
+ postfix = {
842
+ 'loss': f"{tracker.get_ema('loss'):.3f}",
843
+ 'acc': f"{100.0 * total_correct / total_samples:.1f}%",
844
+ }
845
+ if is_v2 and 'grad_query' in routing_summary:
846
+ postfix['∇q'] = f"{routing_summary['grad_query']:.2f}"
847
+ if 'route_entropy' in routing_summary:
848
+ postfix['H'] = f"{routing_summary['route_entropy']:.2f}"
849
+
850
+ pbar.set_postfix(postfix)
851
+
852
+ if scheduler is not None and config.scheduler == "cosine":
853
+ scheduler.step()
854
+
855
+ return {
856
+ 'loss': total_loss / len(train_loader),
857
+ 'acc': 100.0 * total_correct / total_samples,
858
+ **routing_metrics.get_summary()
859
+ }
860
+
861
+
862
+ @torch.no_grad()
863
+ def evaluate_v2(
864
+ model: nn.Module,
865
+ test_loader: DataLoader,
866
+ config: TrainingConfigV2
867
+ ) -> Dict[str, float]:
868
+ """Evaluate on test set."""
869
+
870
+ model.eval()
871
+ device = config.device
872
+
873
+ total_loss = 0.0
874
+ total_correct = 0
875
+ total_samples = 0
876
+ scale_correct = {s: 0 for s in model.config.scales}
877
+
878
+ for images, targets in test_loader:
879
+ images = images.to(device, non_blocking=True)
880
+ targets = targets.to(device, non_blocking=True)
881
+
882
+ result = model(images, targets=targets, return_loss=True)
883
+
884
+ logits = result['logits']
885
+ losses = result['losses']
886
+
887
+ loss = losses['total']
888
+ preds = logits.argmax(dim=-1)
889
+
890
+ total_loss += loss.item() * targets.size(0)
891
+ total_correct += (preds == targets).sum().item()
892
+ total_samples += targets.size(0)
893
+
894
+ for i, scale in enumerate(model.config.scales):
895
+ scale_logits = result['scale_logits'][i]
896
+ scale_preds = scale_logits.argmax(dim=-1)
897
+ scale_correct[scale] += (scale_preds == targets).sum().item()
898
+
899
+ metrics = {
900
+ 'loss': total_loss / total_samples,
901
+ 'acc': 100.0 * total_correct / total_samples
902
+ }
903
+
904
+ for scale, correct in scale_correct.items():
905
+ metrics[f'acc_{scale}'] = 100.0 * correct / total_samples
906
+
907
+ return metrics
908
+
909
+
910
+ # ============================================================================
911
+ # MAIN TRAINING FUNCTION V2
912
+ # ============================================================================
913
+
914
+ def train_david_beans_v2(
915
+ model_config: Optional[Union[DavidBeansConfig, DavidBeansV2Config]] = None,
916
+ train_config: Optional[TrainingConfigV2] = None
917
+ ):
918
+ """Main training function for DavidBeans V1 or V2."""
919
+
920
+ print("=" * 70)
921
+ print(" DAVID-BEANS V2 TRAINING: Wormhole Routing")
922
+ print("=" * 70)
923
+ print()
924
+ print(" 🌀 WORMHOLES: Learned sparse routing")
925
+ print(" 💎 CRYSTALS: Multi-scale projection")
926
+ print()
927
+ print(" Key insight: When routing IS the task, routing learns structure")
928
+ print()
929
+ print("=" * 70)
930
+
931
+ if train_config is None:
932
+ train_config = TrainingConfigV2()
933
+
934
+ base_output_dir = Path(train_config.output_dir)
935
+ base_output_dir.mkdir(parents=True, exist_ok=True)
936
+
937
+ # =========================================================================
938
+ # FIXED: Proper checkpoint resolution
939
+ # =========================================================================
940
+ checkpoint_path = None
941
+ run_dir = None
942
+ run_dir_name = None
943
+
+ if train_config.resume_from:
+     # "latest" resumes the most recently modified run directory under output_dir
+     if train_config.resume_from == "latest":
+         latest_run = find_latest_run_dir(base_output_dir)
+         if latest_run is not None:
+             train_config.resume_from = str(latest_run)
+     resume_path = Path(train_config.resume_from)
946
+
947
+ # Case 1: Direct absolute/relative file path
948
+ if resume_path.is_file():
949
+ checkpoint_path = resume_path
950
+ run_dir = checkpoint_path.parent
951
+ run_dir_name = run_dir.name
952
+ print(f"\n📂 Found checkpoint file: {checkpoint_path.name}")
953
+
954
+ # Case 2: Directory path - find best/latest checkpoint inside
955
+ elif resume_path.is_dir():
956
+ checkpoint_path = find_latest_checkpoint(resume_path)
957
+ if checkpoint_path:
958
+ run_dir = resume_path
959
+ run_dir_name = resume_path.name
960
+ print(f"\n📂 Found checkpoint in dir: {checkpoint_path.name}")
961
+
962
+ # Case 3: Try as path relative to base_output_dir
963
+ else:
964
+ # Try as subdirectory name
965
+ possible_dir = base_output_dir / train_config.resume_from
966
+ if possible_dir.is_dir():
967
+ checkpoint_path = find_latest_checkpoint(possible_dir)
968
+ if checkpoint_path:
969
+ run_dir = possible_dir
970
+ run_dir_name = possible_dir.name
971
+ print(f"\n📂 Found checkpoint in: {run_dir_name}")
972
+
973
+ # Try as relative file path
974
+ if checkpoint_path is None:
975
+ possible_file = base_output_dir / train_config.resume_from
976
+ if possible_file.is_file():
977
+ checkpoint_path = possible_file
978
+ run_dir = checkpoint_path.parent
979
+ run_dir_name = run_dir.name
980
+ print(f"\n📂 Found checkpoint: {checkpoint_path.name}")
981
+
982
+ # Report if not found
983
+ if checkpoint_path is None:
984
+ print(f"\n [!] Could not find checkpoint: {train_config.resume_from}")
985
+ print(f" [!] Checked:")
986
+ print(f" - As file: {resume_path}")
987
+ print(f" - As dir: {resume_path}")
988
+ print(f" - Under {base_output_dir}")
989
+ print(f" [!] Starting fresh run instead")
990
+ else:
991
+ print(f" ✓ Will resume from: {checkpoint_path}")
992
+
993
+ # Create new run directory if not resuming
994
+ if run_dir is None:
995
+ run_number = train_config.run_number or get_next_run_number(base_output_dir)
996
+ run_dir_name = generate_run_dir_name(run_number, train_config.run_name, train_config.model_version)
997
+ run_dir = base_output_dir / run_dir_name
998
+ run_dir.mkdir(parents=True, exist_ok=True)
999
+ print(f"\n📁 New run: {run_dir_name}")
1000
+ else:
1001
+ print(f"\n📁 Resuming run: {run_dir_name}")
1002
+
1003
+ output_dir = run_dir
1004
+
1005
+ # =========================================================================
1006
+ # Model config - load from checkpoint if resuming, else use provided/default
1007
+ # =========================================================================
1008
+ if checkpoint_path and checkpoint_path.exists() and model_config is None:
1009
+ # Try to load config from checkpoint
1010
+ try:
1011
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
1012
+ if 'model_config' in ckpt:
1013
+ saved_config = ckpt['model_config']
1014
+ print(f" ✓ Loading model config from checkpoint")
1015
+ if train_config.model_version == 2:
1016
+ model_config = DavidBeansV2Config(**saved_config)
1017
+ else:
1018
+ model_config = DavidBeansConfig(**saved_config)
1019
+ except Exception as e:
1020
+ print(f" [!] Could not load config from checkpoint: {e}")
1021
+
1022
+ # Create default config if still None
1023
+ if model_config is None:
1024
+ if train_config.model_version == 2:
1025
+ model_config = DavidBeansV2Config(
1026
+ image_size=train_config.image_size,
1027
+ patch_size=4,
1028
+ dim=512,
1029
+ num_layers=4,
1030
+ num_heads=8,
1031
+ num_wormholes=8,
1032
+ wormhole_temperature=0.1,
1033
+ wormhole_mode="hybrid",
1034
+ num_tiles=16,
1035
+ tile_wormholes=4,
1036
+ scales=[64, 128, 256, 384, 512],
1037
+ num_classes=100,
1038
+ contrast_weight=train_config.contrast_weight,
1039
+ dropout=0.1
1040
+ )
1041
+ else:
1042
+ model_config = DavidBeansConfig(
1043
+ image_size=train_config.image_size,
1044
+ patch_size=4,
1045
+ dim=512,
1046
+ num_layers=4,
1047
+ num_heads=8,
1048
+ num_experts=5,
1049
+ k_neighbors=16,
1050
+ cantor_weight=0.3,
1051
+ scales=[64, 128, 256, 384, 512],
1052
+ num_classes=100,
1053
+ dropout=0.1
1054
+ )
1055
+
1056
+ device = train_config.device
1057
+ print(f"\nDevice: {device}")
1058
+ print(f"Model version: V{train_config.model_version}")
1059
+
1060
+ # Data
1061
+ print("\nLoading data...")
1062
+ train_loader, test_loader, num_classes = get_dataloaders(train_config)
1063
+ print(f" Dataset: {train_config.dataset}")
1064
+ print(f" Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")
1065
+ print(f" Classes: {num_classes}")
1066
+
1067
+ model_config.num_classes = num_classes
1068
+
1069
+ # Model
1070
+ print("\nBuilding model...")
1071
+ if train_config.model_version == 2:
1072
+ model = DavidBeansV2(model_config)
1073
+ else:
1074
+ model = DavidBeans(model_config)
1075
+
1076
+ model = model.to(device)
1077
+ print(f"\n{model}")
1078
+
1079
+ num_params = sum(p.numel() for p in model.parameters())
1080
+ print(f"\nParameters: {num_params:,}")
1081
+
1082
+ # Optimizer
1083
+ print("\nSetting up optimizer...")
1084
+
1085
+ decay_params = []
1086
+ no_decay_params = []
1087
+
1088
+ for name, param in model.named_parameters():
1089
+ if not param.requires_grad:
1090
+ continue
1091
+ if 'bias' in name or 'norm' in name or 'embedding' in name:
1092
+ no_decay_params.append(param)
1093
+ else:
1094
+ decay_params.append(param)
1095
+
1096
+ optimizer = AdamW([
1097
+ {'params': decay_params, 'weight_decay': train_config.weight_decay},
1098
+ {'params': no_decay_params, 'weight_decay': 0.0}
1099
+ ], lr=train_config.learning_rate, betas=train_config.betas)
1100
+
1101
+ if train_config.scheduler == "cosine":
1102
+ scheduler = CosineAnnealingLR(
1103
+ optimizer,
1104
+ T_max=train_config.epochs - train_config.warmup_epochs,
1105
+ eta_min=train_config.min_lr
1106
+ )
1107
+ elif train_config.scheduler == "onecycle":
1108
+ scheduler = OneCycleLR(
1109
+ optimizer,
1110
+ max_lr=train_config.learning_rate,
1111
+ epochs=train_config.epochs,
1112
+ steps_per_epoch=len(train_loader),
1113
+ pct_start=train_config.warmup_epochs / train_config.epochs
1114
+ )
1115
+ else:
1116
+ scheduler = None
1117
+
1118
+ print(f" Optimizer: AdamW (lr={train_config.learning_rate}, wd={train_config.weight_decay})")
1119
+ print(f" Scheduler: {train_config.scheduler}")
1120
+
1121
+ tracker = MetricsTracker()
1122
+ routing_metrics = RoutingMetrics()
1123
+ best_acc = 0.0
1124
+ start_epoch = 0
1125
+
1126
+ # =========================================================================
1127
+ # Load checkpoint weights and optimizer state
1128
+ # =========================================================================
1129
+ if checkpoint_path and checkpoint_path.exists():
1130
+ start_epoch, best_acc = load_checkpoint(checkpoint_path, model, optimizer, device)
1131
+
1132
+ # Advance scheduler to correct position
1133
+ if scheduler is not None and train_config.scheduler == "cosine":
1134
+ for _ in range(start_epoch):
1135
+ scheduler.step()
1136
+ print(f" ✓ Advanced scheduler to epoch {start_epoch}")
1137
+
1138
+ # TensorBoard
1139
+ writer = None
1140
+ if train_config.use_tensorboard and TENSORBOARD_AVAILABLE:
1141
+ tb_dir = output_dir / "tensorboard"
1142
+ tb_dir.mkdir(parents=True, exist_ok=True)
1143
+ writer = SummaryWriter(log_dir=str(tb_dir))
1144
+ print(f" TensorBoard: {tb_dir}")
1145
+
1146
+ # Save configs
1147
+ with open(output_dir / "config.json", "w") as f:
1148
+ json.dump({**model_config.__dict__, "architecture": f"DavidBeans_V{train_config.model_version}"},
1149
+ f, indent=2, default=str)
1150
+ with open(output_dir / "training_config.json", "w") as f:
1151
+ json.dump(train_config.to_dict(), f, indent=2, default=str)
1152
+
1153
+ # Training loop
1154
+ print("\n" + "=" * 70)
1155
+ print(" TRAINING")
1156
+ print("=" * 70)
1157
+
1158
+ for epoch in range(start_epoch, train_config.epochs):
1159
+ epoch_start = time.time()
1160
+
1161
+ # Warmup
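+ # (manual linear ramp to the base LR over warmup_epochs; afterwards the cosine
+ #  schedule, stepped once per epoch inside train_epoch_v2, takes over)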
1162
+ if epoch < train_config.warmup_epochs and train_config.scheduler == "cosine":
1163
+ warmup_lr = train_config.learning_rate * (epoch + 1) / train_config.warmup_epochs
1164
+ for param_group in optimizer.param_groups:
1165
+ param_group['lr'] = warmup_lr
1166
+
1167
+ train_metrics = train_epoch_v2(
1168
+ model, train_loader, optimizer, scheduler,
1169
+ train_config, epoch, tracker, routing_metrics, writer
1170
+ )
1171
+
1172
+ test_metrics = evaluate_v2(model, test_loader, train_config)
1173
+
1174
+ epoch_time = time.time() - epoch_start
1175
+
1176
+ # TensorBoard
1177
+ if writer is not None:
1178
+ writer.add_scalar('epoch/train_loss', train_metrics['loss'], epoch)
1179
+ writer.add_scalar('epoch/train_acc', train_metrics['acc'], epoch)
1180
+ writer.add_scalar('epoch/test_loss', test_metrics['loss'], epoch)
1181
+ writer.add_scalar('epoch/test_acc', test_metrics['acc'], epoch)
1182
+
1183
+ for scale in model.config.scales:
1184
+ writer.add_scalar(f'scales/acc_{scale}', test_metrics[f'acc_{scale}'], epoch)
1185
+
1186
+ # Print summary - show ALL scales
1187
+ scale_accs = " | ".join([f"{s}:{test_metrics[f'acc_{s}']:.1f}%" for s in model.config.scales])
1188
+ star = "★" if test_metrics['acc'] > best_acc else ""
1189
+
1190
+ routing_info = ""
1191
+ if train_config.model_version == 2 and 'grad_query' in train_metrics:
1192
+ routing_info = f" | ∇q:{train_metrics.get('grad_query', 0):.2f}"
1193
+
1194
+ print(f" → Train: {train_metrics['acc']:.1f}% | Test: {test_metrics['acc']:.1f}% | "
1195
+ f"[{scale_accs}]{routing_info} | {epoch_time:.0f}s {star}")
1196
+
1197
+ # Save best model
1198
+ if test_metrics['acc'] > best_acc:
1199
+ best_acc = test_metrics['acc']
1200
+ torch.save({
1201
+ 'epoch': epoch,
1202
+ 'model_state_dict': model.state_dict(),
1203
+ 'optimizer_state_dict': optimizer.state_dict(),
1204
+ 'best_acc': best_acc,
1205
+ 'model_config': model_config.__dict__,
1206
+ 'train_config': train_config.to_dict()
1207
+ }, output_dir / "best_model.pt")
1208
+
1209
+ # Periodic checkpoint
1210
+ if (epoch + 1) % train_config.save_interval == 0:
1211
+ torch.save({
1212
+ 'epoch': epoch,
1213
+ 'model_state_dict': model.state_dict(),
1214
+ 'optimizer_state_dict': optimizer.state_dict(),
1215
+ 'best_acc': best_acc,
1216
+ 'model_config': model_config.__dict__,
1217
+ 'train_config': train_config.to_dict()
1218
+ }, output_dir / f"checkpoint_epoch_{epoch + 1}.pt")
1219
+
1220
+ # Upload to hub
1221
+ if train_config.push_to_hub and HF_HUB_AVAILABLE:
1222
+ try:
1223
+ hub_dir = prepare_run_for_hub(
1224
+ model=model,
1225
+ model_config=model_config,
1226
+ train_config=train_config,
1227
+ best_acc=best_acc,
1228
+ output_dir=output_dir,
1229
+ run_dir_name=run_dir_name,
1230
+ training_history=tracker.get_history()
1231
+ )
1232
+ push_run_to_hub(
1233
+ hub_dir=hub_dir,
1234
+ repo_id=train_config.hub_repo_id,
1235
+ run_dir_name=run_dir_name,
1236
+ commit_message=f"Epoch {epoch + 1} - {best_acc:.2f}% acc"
1237
+ )
1238
+ print(f" 📤 Uploaded to hub")
1239
+ except Exception as e:
1240
+ print(f" [!] Hub upload failed: {e}")
1241
+
1242
+ tracker.end_epoch()
1243
+
1244
+ # Final summary
1245
+ print("\n" + "=" * 70)
1246
+ print(" TRAINING COMPLETE")
1247
+ print("=" * 70)
1248
+ print(f"\n Best Test Accuracy: {best_acc:.2f}%")
1249
+ print(f" Model saved to: {output_dir / 'best_model.pt'}")
1250
+
1251
+ if writer is not None:
1252
+ writer.close()
1253
+
1254
+ return model, best_acc
1255
+
1256
+
1257
+ # ============================================================================
1258
+ # PRESETS
1259
+ # ============================================================================
1260
+
1261
+ def train_cifar100_v2_wormhole(
1262
+ run_name: str = "wormhole_base",
1263
+ push_to_hub: bool = False,
1264
+ resume: bool = False
1265
+ ):
1266
+ """CIFAR-100 with V2 wormhole routing."""
1267
+
1268
+ model_config = DavidBeansV2Config(
1269
+ image_size=32,
1270
+ patch_size=2,
1271
+ dim=512,
1272
+ num_layers=4,
1273
+ num_heads=16,
1274
+ # Wormhole routing parameters
1275
+ num_wormholes=16,
1276
+ wormhole_temperature=0.1,
1277
+ wormhole_mode="hybrid",
1278
+ # Tessellation parameters
1279
+ num_tiles=16,
1280
+ tile_wormholes=4,
1281
+ # Crystal head
1282
+ scales=[64, 128, 256, 512, 1024],
1283
+ num_classes=100,
1284
+ contrast_temperature=0.07,
1285
+ contrast_weight=0.5,
1286
+ dropout=0.1
1287
+ )
1288
+
1289
+ train_config = TrainingConfigV2(
1290
+ run_name=run_name,
1291
+ model_version=2,
1292
+ dataset="cifar100",
1293
+ epochs=300,
1294
+ batch_size=512,
1295
+ learning_rate=3e-4,
1296
+ weight_decay=0.05,
1297
+ warmup_epochs=15,
1298
+ # Loss weights (no auxiliary routing loss!)
1299
+ ce_weight=1.0,
1300
+ contrast_weight=0.5,
1301
+ # Augmentation
1302
+ label_smoothing=0.1,
1303
+ mixup_alpha=0.2,
1304
+ cutmix_alpha=1.0,
1305
+ # Output
1306
+ output_dir="./checkpoints/cifar100_v2",
+ resume_from="latest" if resume else None,  # or a specific run, e.g. "./checkpoints/cifar100_v2/run_002_v2_16patch_4tilewormholes_d768_4layer_20251130_045437/best_model.pt"
1308
+ # Hub
1309
+ push_to_hub=push_to_hub,
1310
+ hub_repo_id="AbstractPhil/geovit-david-beans",
1311
+ # Routing logging
1312
+ log_routing=True
1313
+ )
1314
+
1315
+ return train_david_beans_v2(model_config, train_config)
1316
+
1317
+
1318
+ def train_cifar100_v1_baseline(
1319
+ run_name: str = "v1_baseline",
1320
+ push_to_hub: bool = False,
1321
+ resume: bool = False
1322
+ ):
1323
+ """CIFAR-100 with V1 (fixed Cantor routing) for comparison."""
1324
+
1325
+ model_config = DavidBeansConfig(
1326
+ image_size=32,
1327
+ patch_size=4,
1328
+ dim=512,
1329
+ num_layers=4,
1330
+ num_heads=8,
1331
+ num_experts=5,
1332
+ k_neighbors=16,
1333
+ cantor_weight=0.3,
1334
+ scales=[64, 128, 256, 384, 512],
1335
+ num_classes=100,
1336
+ dropout=0.1
1337
+ )
1338
+
1339
+ train_config = TrainingConfigV2(
1340
+ run_name=run_name,
1341
+ model_version=1,
1342
+ dataset="cifar100",
1343
+ epochs=200,
1344
+ batch_size=128,
1345
+ learning_rate=3e-4,
1346
+ weight_decay=0.05,
1347
+ warmup_epochs=10,
1348
+ ce_weight=1.0,
1349
+ contrast_weight=0.5,
1350
+ label_smoothing=0.1,
1351
+ mixup_alpha=0.2,
1352
+ cutmix_alpha=1.0,
1353
+ output_dir="./checkpoints/cifar100_v1",
1354
+ resume_from="latest" if resume else None,
1355
+ push_to_hub=push_to_hub,
1356
+ hub_repo_id="AbstractPhil/geovit-david-beans",
1357
+ log_routing=False
1358
+ )
1359
+
1360
+ return train_david_beans_v2(model_config, train_config)
1361
+
1362
+
1363
+ # ============================================================================
1364
+ # MAIN
1365
+ # ============================================================================
1366
+
1367
+ if __name__ == "__main__":
1368
+
1369
+ # =====================================================
1370
+ # CONFIGURATION
1371
+ # =====================================================
1372
+
1373
+ PRESET = "v2_wormhole" # "v1_baseline", "v2_wormhole", "test"
1374
+ RESUME = False
1375
+ RUN_NAME = "5scale_2x2patch_4tilewormholes_d512_4layer"
1376
+ PUSH_TO_HUB = True
1377
+
1378
+ # =====================================================
1379
+ # RUN
1380
+ # =====================================================
1381
+
1382
+ if PRESET == "test":
1383
+ print("🧪 Quick test...")
1384
+ model_config = DavidBeansV2Config(
1385
+ image_size=32, patch_size=4, dim=128, num_layers=2,
1386
+ num_heads=4, num_wormholes=4, num_tiles=8,
1387
+ scales=[32, 64, 128], num_classes=10
1388
+ )
1389
+ train_config = TrainingConfigV2(
1390
+ run_name="test", model_version=2,
1391
+ epochs=2, batch_size=32,
1392
+ use_augmentation=False, mixup_alpha=0.0, cutmix_alpha=0.0
1393
+ )
1394
+ model, acc = train_david_beans_v2(model_config, train_config)
1395
+
1396
+ elif PRESET == "v1_baseline":
1397
+ print("🫘💎 Training DavidBeans V1 (Cantor routing)...")
1398
+ model, acc = train_cifar100_v1_baseline(
1399
+ run_name=RUN_NAME,
1400
+ push_to_hub=PUSH_TO_HUB,
1401
+ resume=RESUME
1402
+ )
1403
+
1404
+ elif PRESET == "v2_wormhole":
1405
+ print("💎 Training DavidBeans V2 (Wormhole routing)...")
1406
+ model, acc = train_cifar100_v2_wormhole(
1407
+ run_name=RUN_NAME,
1408
+ push_to_hub=PUSH_TO_HUB,
1409
+ resume=RESUME
1410
+ )
1411
+
1412
+ else:
1413
+ raise ValueError(f"Unknown preset: {PRESET}")
1414
+
1415
+ print(f"\n🎉 Done! Best accuracy: {acc:.2f}%")