HarleyCooper committed
Commit 48373f2 (verified) · Parent: 52b43ff

chore: restore custom modeling

Files changed (1): modeling_nanochat.py (+365 -0)
modeling_nanochat.py ADDED
@@ -0,0 +1,365 @@
+"""
+Hugging Face-compatible nanochat Transformer implementation.
+
+This file mirrors the architecture used during training (RoPE, RMSNorm,
+multi-query attention, relu^2 MLP, untied embeddings, logits softcap) while
+presenting the familiar `PreTrainedModel` interface so that checkpoints can be
+served directly from the Hugging Face Hub.
+"""
+
+from __future__ import annotations
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers import AutoConfig, AutoModelForCausalLM
+
+logger = logging.get_logger(__name__)
+
+
+class NanoChatConfig(PretrainedConfig):
+    model_type = "nanochat"
+
+    def __init__(
+        self,
+        vocab_size=65536,
+        sequence_len=2048,
+        n_layer=20,
+        n_head=10,
+        n_kv_head=10,
+        n_embd=1280,
+        rotary_dim=None,
+        activation_function="relu_squared",
+        use_rope=True,
+        use_qk_norm=True,
+        tie_word_embeddings=False,
+        softcap=15.0,
+        bos_token_id=1,
+        eos_token_id=1,
+        pad_token_id=None,
+        **kwargs,
+    ):
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.sequence_len = sequence_len
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.n_kv_head = n_kv_head
+        self.n_embd = n_embd
+        self.rotary_dim = rotary_dim or (n_embd // n_head)
+        self.activation_function = activation_function
+        self.use_rope = use_rope
+        self.use_qk_norm = use_qk_norm
+        self.tie_word_embeddings = tie_word_embeddings
+        self.softcap = softcap
+
+        # Aliases for transformers compatibility
+        self.num_hidden_layers = n_layer
+        self.hidden_size = n_embd
+        self.num_attention_heads = n_head
+        self.num_key_value_heads = n_kv_head
+
+
+def rms_norm(x: Tensor) -> Tensor:
+    return F.rms_norm(x, (x.size(-1),))
+
+
+def relu_squared(x: Tensor) -> Tensor:
+    return F.relu(x) ** 2
+
+
+def rotate_half(x: Tensor) -> Tensor:
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_emb(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor) -> Tuple[Tensor, Tensor]:
+    q = (q * cos) + (rotate_half(q) * sin)
+    k = (k * cos) + (rotate_half(k) * sin)
+    return q, k
+
+
+def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
+    if n_rep == 1:
+        return x
+    b, n_kv_heads, seq_len, head_dim = x.shape
+    x = x[:, :, None, :, :].expand(b, n_kv_heads, n_rep, seq_len, head_dim)
+    return x.reshape(b, n_kv_heads * n_rep, seq_len, head_dim)
+
+
+class NanoChatAttention(nn.Module):
+    def __init__(self, config: NanoChatConfig):
+        super().__init__()
+        self.config = config
+        self.n_head = config.n_head
+        self.n_kv_head = config.n_kv_head
+        self.head_dim = config.n_embd // config.n_head
+        if config.n_embd % config.n_head != 0:
+            raise ValueError("Embedding dimension must be divisible by number of heads")
+
+        self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
+        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        cos: Tensor,
+        sin: Tensor,
+        past_key_value: Optional[Tuple[Tensor, Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]:
+        bsz, q_len, _ = hidden_states.shape
+
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+
+        query = query.view(bsz, q_len, self.n_head, self.head_dim).transpose(1, 2)
+        key = key.view(bsz, q_len, self.n_kv_head, self.head_dim).transpose(1, 2)
+        value = value.view(bsz, q_len, self.n_kv_head, self.head_dim).transpose(1, 2)
+
+        query, key = apply_rotary_emb(query, key, cos, sin)
+        if self.config.use_qk_norm:
+            query = rms_norm(query)
+            key = rms_norm(key)
+
+        if past_key_value is not None:
+            past_k, past_v = past_key_value
+            if past_k is not None and past_v is not None:
+                key = torch.cat([past_k, key], dim=2)
+                value = torch.cat([past_v, value], dim=2)
+
+        present = (key, value) if use_cache else None
+
+        key_for_scores = repeat_kv(key, self.n_head // self.n_kv_head)
+        value_for_scores = repeat_kv(value, self.n_head // self.n_kv_head)
+
+        attn_scores = torch.matmul(query, key_for_scores.transpose(-1, -2)) / math.sqrt(self.head_dim)
+        attn_scores = attn_scores.to(torch.float32)
+
+        # causal mask that accounts for the prefix introduced by past key values
+        if attn_scores.size(-1) != q_len:
+            total_k = attn_scores.size(-1)
+            past_len = total_k - q_len
+            mask = torch.arange(total_k, device=attn_scores.device)
+            causal = mask.unsqueeze(0) <= (mask.new_tensor(past_len) + torch.arange(q_len, device=mask.device).unsqueeze(1))
+            attn_scores = attn_scores.masked_fill(~causal, torch.finfo(attn_scores.dtype).min)
+        else:
+            mask = torch.triu(torch.ones((q_len, q_len), device=attn_scores.device, dtype=torch.bool), diagonal=1)
+            attn_scores = attn_scores.masked_fill(mask, torch.finfo(attn_scores.dtype).min)
+
+        attn_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32)
+        attn_output = torch.matmul(attn_weights.to(value_for_scores.dtype), value_for_scores)  # cast before the matmul so half-precision values do not hit a dtype mismatch
+
+        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, present
+
+
+class NanoChatMLP(nn.Module):
+    def __init__(self, config: NanoChatConfig):
+        super().__init__()
+        hidden_dim = config.n_embd * 4
+        self.fc = nn.Linear(config.n_embd, hidden_dim, bias=False)
+        self.proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.proj(relu_squared(self.fc(x)))
+
+
+class NanoChatBlock(nn.Module):
+    def __init__(self, config: NanoChatConfig):
+        super().__init__()
+        self.attn = NanoChatAttention(config)
+        self.mlp = NanoChatMLP(config)
+
+    def forward(
+        self,
+        x: Tensor,
+        cos: Tensor,
+        sin: Tensor,
+        past_key_value: Optional[Tuple[Tensor, Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[Tensor, Optional[Tuple[Tensor, Tensor]]]:
+        residual = x
+        attn_input = rms_norm(x)
+        attn_output, present = self.attn(attn_input, cos, sin, past_key_value, use_cache)
+        x = residual + attn_output
+        mlp_input = rms_norm(x)
+        x = x + self.mlp(mlp_input)
+        return x, present
+
+
+class NanoChatModel(nn.Module):
+    def __init__(self, config: NanoChatConfig):
+        super().__init__()
+        self.config = config
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.blocks = nn.ModuleList([NanoChatBlock(config) for _ in range(config.n_layer)])
+
+        self.softcap = config.softcap
+        self._rope_cache: Optional[Tuple[Tensor, Tensor]] = None
+        self._rope_cache_length = 0
+
+    def _build_rope_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> Tuple[Tensor, Tensor]:
+        if self._rope_cache is not None and self._rope_cache_length >= seq_len and self._rope_cache[0].device == device:
+            return self._rope_cache
+
+        head_dim = self.config.n_embd // self.config.n_head
+        theta = 10000.0 ** (-torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim)
+        position_ids = torch.arange(seq_len, device=device, dtype=torch.float32)
+        freqs = torch.einsum("i,j->ij", position_ids, theta)
+        cos = freqs.cos()[None, None, :, :]
+        sin = freqs.sin()[None, None, :, :]
+        # Expand to full head_dim (from head_dim/2 to head_dim)
+        cos = torch.repeat_interleave(cos, repeats=2, dim=-1)
+        sin = torch.repeat_interleave(sin, repeats=2, dim=-1)
+        cos = cos.to(dtype=dtype)
+        sin = sin.to(dtype=dtype)
+
+        self._rope_cache = (cos, sin)
+        self._rope_cache_length = seq_len
+        return cos, sin
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        past_key_values: Optional[Tuple[Tuple[Tensor, Tensor], ...]] = None,
+        attention_mask: Optional[Tensor] = None,
+        labels: Optional[Tensor] = None,
+        use_cache: bool = False,
+    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tuple[Tensor, Tensor], ...]]]:
+        del attention_mask  # attention masking is handled implicitly via causal masking
+        bsz, seq_len = input_ids.shape
+        device = input_ids.device
+        dtype = self.embed_tokens.weight.dtype
+
+        inputs_embeds = self.embed_tokens(input_ids)
+        x = inputs_embeds
+
+        past_key_values = past_key_values or tuple([None] * len(self.blocks))
+        # Handle DynamicCache which may have (None, None) tuples
+        past_length = 0
+        if past_key_values and past_key_values[0] is not None:
+            if past_key_values[0][0] is not None:
+                past_length = past_key_values[0][0].size(2)
+
+        cos_full, sin_full = self._build_rope_cache(seq_len + past_length, device, dtype)
+        cos = cos_full[:, :, past_length : past_length + seq_len, :]  # slice exactly the query window; the cache may be longer
+        sin = sin_full[:, :, past_length : past_length + seq_len, :]
+        new_past_key_values = [] if use_cache else None
+
+        for layer, block in enumerate(self.blocks):
+            past = past_key_values[layer] if past_key_values[layer] is not None else None
+            x, present = block(x, cos, sin, past, use_cache)
+            if use_cache:
+                new_past_key_values.append(present)
+
+        x = rms_norm(x)
+        logits = self.lm_head(x)
+
+        if self.softcap is not None and self.softcap > 0:
+            logits = self.softcap * torch.tanh(logits / self.softcap)
+
+        loss = None
+        if labels is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-1)
+
+        return logits, loss, tuple(new_past_key_values) if use_cache else None
+
+
+class NanoChatForCausalLM(PreTrainedModel):
+    config_class = NanoChatConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = False
+
+    def __init__(self, config: NanoChatConfig):
+        super().__init__(config)
+        self.model = NanoChatModel(config)
+        if config.tie_word_embeddings:
+            self.tie_weights()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value: nn.Embedding) -> None:
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self) -> nn.Linear:
+        return self.model.lm_head
+
+    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+        self.model.lm_head = new_embeddings
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: Tensor,
+        past_key_values: Optional[Tuple[Tuple[Tensor, Tensor], ...]] = None,
+        **kwargs,
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+        return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache", True)}
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        reordered = []
+        for layer_past in past_key_values:
+            reordered.append(
+                (
+                    layer_past[0].index_select(0, beam_idx),
+                    layer_past[1].index_select(0, beam_idx),
+                )
+            )
+        return tuple(reordered)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[Tensor, Tensor], ...]] = None,
+        labels: Optional[Tensor] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> CausalLMOutputWithPast:
+        logits, loss, new_past = self.model(
+            input_ids=input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            labels=labels,
+            use_cache=use_cache,
+        )
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=new_past,
+        )
+
+
+try:
+    AutoConfig.register("nanochat", NanoChatConfig)
+except ValueError:
+    # Transformers build already provides this registration (e.g., nanochat branch); reuse it.
+    pass
+
+try:
+    AutoModelForCausalLM.register(NanoChatConfig, NanoChatForCausalLM)
+except ValueError:
+    pass
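
Because the module registers NanoChatConfig and NanoChatForCausalLM with the Auto classes, a Hub repository that ships this modeling_nanochat.py can be loaded with trust_remote_code=True. A minimal loading sketch follows; the repository id is a placeholder and the prompt is passed as raw token ids, since the matching tokenizer is outside the scope of this file.

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    repo_id = "<user>/<nanochat-checkpoint>"  # placeholder, substitute the actual Hub repo

    # trust_remote_code=True lets transformers import modeling_nanochat.py shipped with the checkpoint.
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    # One forward pass over dummy token ids (replace with real tokenizer output);
    # the softcapped logits at the last position give the greedy next token.
    input_ids = torch.tensor([[config.bos_token_id, 42, 7, 99]])
    with torch.no_grad():
        out = model(input_ids)
    next_token = out.logits[:, -1, :].argmax(dim=-1)
    print(next_token)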