"""
Hugging Face tokenizer wrapper for nanochat's rustbpe+tiktoken vocabulary.
"""
from __future__ import annotations
import os
import pickle
from typing import Dict, List, Optional, Tuple
import tiktoken
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers import AutoTokenizer
try:
from .configuration_nanochat import NanoChatConfig # type: ignore
except ImportError:
from configuration_nanochat import NanoChatConfig # type: ignore
SPECIAL_TOKENS = [
"<|bos|>",
"<|user_start|>",
"<|user_end|>",
"<|assistant_start|>",
"<|assistant_end|>",
"<|python_start|>",
"<|python_end|>",
"<|output_start|>",
"<|output_end|>",
]
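# These markers delimit chat turns and tool-use spans; apply_chat_template below
# stitches them around the encoded message content.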
class NanoChatTokenizer(PreTrainedTokenizer):
vocab_files_names = {"tokenizer_file": "tokenizer/tokenizer.pkl"}
model_input_names = ["input_ids", "attention_mask"]
def __init__(self, tokenizer_file: Optional[str] = None, **kwargs):
if tokenizer_file is None:
raise ValueError("tokenizer_file must be provided")
# Keep a copy of the kwargs in case we need hub metadata.
init_kwargs = dict(kwargs)
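        # nanochat reuses <|bos|> for eos/unk/pad by default; callers may override via kwargs.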
bos_token = kwargs.pop("bos_token", "<|bos|>")
eos_token = kwargs.pop("eos_token", "<|bos|>")
unk_token = kwargs.pop("unk_token", "<|bos|>")
pad_token = kwargs.pop("pad_token", "<|bos|>")
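        # Resolve the tokenizer pickle: prefer a local path, otherwise try to fetch it
        # from the Hub repo the tokenizer is being loaded from.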
resolved_path = tokenizer_file
if not os.path.isfile(resolved_path):
repo_id = init_kwargs.get("name_or_path") or init_kwargs.get("pretrained_model_name_or_path")
if repo_id:
from huggingface_hub import hf_hub_download
resolved_path = hf_hub_download(
repo_id,
tokenizer_file,
revision=init_kwargs.get("revision"),
subfolder=init_kwargs.get("subfolder"),
cache_dir=init_kwargs.get("cache_dir"),
token=init_kwargs.get("token"),
)
if not os.path.isfile(resolved_path):
raise FileNotFoundError(f"Cannot locate tokenizer state at {tokenizer_file}")
with open(resolved_path, "rb") as handle:
self._encoding: tiktoken.Encoding = pickle.load(handle)
self._id_to_token: List[str] = []
        for token_bytes in self._encoding.token_byte_values():
            token = token_bytes.decode("utf-8", errors="replace")
            self._id_to_token.append(token)
self.vocab: Dict[str, int] = {token: idx for idx, token in enumerate(self._id_to_token)}
self._special_token_ids: Dict[str, int] = {}
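        # Register the special tokens. Their ids sit past the end of the byte-level
        # vocabulary, so pad _id_to_token with empty placeholders where needed.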
for special in SPECIAL_TOKENS:
special_id = self._encoding.encode_single_token(special)
if special_id >= len(self._id_to_token):
self._id_to_token.extend([""] * (special_id - len(self._id_to_token) + 1))
self._id_to_token[special_id] = special
self.vocab[special] = special_id
self._special_token_ids[special] = special_id
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
**kwargs,
)
        # Re-apply nanochat's special-token defaults even if the hub config supplied different values.
self.bos_token = bos_token or "<|bos|>"
self.eos_token = eos_token or "<|bos|>"
self.unk_token = unk_token or "<|bos|>"
self.pad_token = pad_token or "<|bos|>"
self.bos_token_id = self.vocab[self.bos_token]
self.eos_token_id = self.vocab[self.eos_token]
self.unk_token_id = self.vocab[self.unk_token]
self.pad_token_id = self.vocab[self.pad_token]
@property
def vocab_size(self) -> int:
return len(self._id_to_token)
def get_vocab(self) -> Dict[str, int]:
return dict(self.vocab)
def _tokenize(self, text: str) -> List[str]:
token_ids = self._encoding.encode(text, allowed_special=set())
return [self._id_to_token[token_id] for token_id in token_ids]
def _convert_token_to_id(self, token: str) -> int:
return self.vocab.get(token, self.unk_token_id)
def _convert_id_to_token(self, index: int) -> str:
return self._id_to_token[index]
def build_inputs_with_special_tokens(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
) -> List[int]:
if token_ids_1 is not None:
raise ValueError("nanochat tokenizer only supports single sequences")
return [self.bos_token_id] + token_ids_0
def create_token_type_ids_from_sequences(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
) -> List[int]:
del token_ids_1
return [0] * (len(token_ids_0) + 1) # +1 for BOS
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
target_dir = os.path.join(save_directory, "tokenizer")
os.makedirs(target_dir, exist_ok=True)
filename = (filename_prefix + "-" if filename_prefix else "") + "tokenizer.pkl"
dest_file = os.path.join(target_dir, filename)
with open(dest_file, "wb") as handle:
pickle.dump(self._encoding, handle)
return (dest_file,)
def _decode(
self,
token_ids: List[int],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: Optional[bool] = None,
spaces_between_special_tokens: bool = True,
**kwargs,
) -> str:
del clean_up_tokenization_spaces, spaces_between_special_tokens, kwargs
if skip_special_tokens:
token_ids = [tid for tid in token_ids if tid not in self.all_special_ids]
return self._encoding.decode(token_ids)
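    # Rendered chat layout (as produced by the token handling below):
    #   <|bos|><|user_start|>...<|user_end|><|assistant_start|>...<|assistant_end|>...
    # Assistant turns may additionally wrap tool calls in <|python_start|>/<|python_end|>
    # and tool results in <|output_start|>/<|output_end|>.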
def apply_chat_template(
self,
conversation,
add_generation_prompt: bool = False,
tokenize: bool = False,
return_tensors: Optional[str] = None,
**kwargs,
):
if not isinstance(conversation, list) or not conversation:
raise ValueError("conversation must be a non-empty list of messages")
messages = conversation
if messages[0]["role"] == "system":
if len(messages) < 2 or messages[1]["role"] != "user":
raise ValueError("system prompt must be followed by a user message")
merged = messages[0]["content"] + "\n\n" + messages[1]["content"]
messages = [dict(messages[1], content=merged)] + messages[2:]
token_ids: List[int] = [self.bos_token_id]
def encode_text(text: str) -> List[int]:
return self._encoding.encode(text, allowed_special=set()) if text else []
user_start = self._special_token_ids["<|user_start|>"]
user_end = self._special_token_ids["<|user_end|>"]
assistant_start = self._special_token_ids["<|assistant_start|>"]
assistant_end = self._special_token_ids["<|assistant_end|>"]
python_start = self._special_token_ids["<|python_start|>"]
python_end = self._special_token_ids["<|python_end|>"]
output_start = self._special_token_ids["<|output_start|>"]
output_end = self._special_token_ids["<|output_end|>"]
for idx, message in enumerate(messages):
expected_role = "user" if idx % 2 == 0 else "assistant"
if message["role"] != expected_role:
raise ValueError(f"Message {idx} should be from {expected_role}, got {message['role']}")
content = message["content"]
if message["role"] == "user":
if not isinstance(content, str):
raise ValueError("User messages must be plain strings")
token_ids.append(user_start)
token_ids.extend(encode_text(content))
token_ids.append(user_end)
else: # assistant
token_ids.append(assistant_start)
if isinstance(content, str):
token_ids.extend(encode_text(content))
elif isinstance(content, list):
for part in content:
part_type = part.get("type", "text")
value = part.get("text", "")
if part_type == "text":
token_ids.extend(encode_text(value))
elif part_type == "python":
token_ids.append(python_start)
token_ids.extend(encode_text(value))
token_ids.append(python_end)
elif part_type == "python_output":
token_ids.append(output_start)
token_ids.extend(encode_text(value))
token_ids.append(output_end)
else:
raise ValueError(f"Unsupported assistant part type: {part_type}")
else:
raise ValueError(f"Assistant content must be str or list, got {type(content)}")
token_ids.append(assistant_end)
if add_generation_prompt:
token_ids.append(assistant_start)
if tokenize:
if return_tensors and return_tensors != "pt":
raise ValueError("Only return_tensors='pt' is supported")
            import torch  # lazy import: torch is only required when tensors are requested
return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
return self._encoding.decode(token_ids)
# Register the tokenizer so AutoTokenizer can locate it via NanoChatConfig.
AutoTokenizer.register(NanoChatConfig, NanoChatTokenizer)
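# Chat-template usage sketch (illustrative, continuing the `tok` instance from the usage
# sketch at the top of the file; roles must alternate user/assistant, with an optional
# leading system message that gets merged into the first user turn):
#
#     messages = [{"role": "user", "content": "What is 2 + 2?"}]
#     prompt_ids = tok.apply_chat_template(
#         messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
#     )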