"""Hugging Face tokenizer wrapper for nanochat's rustbpe+tiktoken vocabulary."""

from __future__ import annotations

import os
import pickle
from typing import Dict, List, Optional, Tuple

import tiktoken
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer

try:
    from .configuration_nanochat import NanoChatConfig
except ImportError:
    # Fall back to a flat import when this module is loaded outside a package
    # (e.g. via trust_remote_code from a checkpoint directory).
    from configuration_nanochat import NanoChatConfig


# Special tokens used by nanochat's chat format.
SPECIAL_TOKENS = [
    "<|bos|>",
    "<|user_start|>",
    "<|user_end|>",
    "<|assistant_start|>",
    "<|assistant_end|>",
    "<|python_start|>",
    "<|python_end|>",
    "<|output_start|>",
    "<|output_end|>",
]
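
# apply_chat_template (below) renders a conversation as <|bos|> followed by
# alternating <|user_start|>...<|user_end|> / <|assistant_start|>...<|assistant_end|>
# blocks, with optional <|python_start|>...<|python_end|> (code) and
# <|output_start|>...<|output_end|> (captured output) spans inside assistant turns.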


class NanoChatTokenizer(PreTrainedTokenizer):
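    """Slow (pure-Python) tokenizer backed by a pickled ``tiktoken.Encoding``.

    Byte-pair encoding and decoding are delegated to tiktoken; this class only
    adapts nanochat's vocabulary and chat format to the ``PreTrainedTokenizer``
    interface.
    """
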
    vocab_files_names = {"tokenizer_file": "tokenizer/tokenizer.pkl"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, tokenizer_file: Optional[str] = None, **kwargs):
        if tokenizer_file is None:
            raise ValueError("tokenizer_file must be provided")

        # Keep an untouched copy of kwargs for the hub-resolution hints used
        # below; pop the special-token overrides so they are consumed here.
        # nanochat uses <|bos|> as the default for all four special-token roles.
        init_kwargs = dict(kwargs)
        bos_token = kwargs.pop("bos_token", "<|bos|>")
        eos_token = kwargs.pop("eos_token", "<|bos|>")
        unk_token = kwargs.pop("unk_token", "<|bos|>")
        pad_token = kwargs.pop("pad_token", "<|bos|>")

        # Resolve the tokenizer file: use it directly when it is a local path,
        # otherwise try to fetch it from the Hub repo this tokenizer belongs to.
        resolved_path = tokenizer_file
        if not os.path.isfile(resolved_path):
            repo_id = init_kwargs.get("name_or_path") or init_kwargs.get("pretrained_model_name_or_path")
            if repo_id:
                from huggingface_hub import hf_hub_download

                resolved_path = hf_hub_download(
                    repo_id,
                    tokenizer_file,
                    revision=init_kwargs.get("revision"),
                    subfolder=init_kwargs.get("subfolder"),
                    cache_dir=init_kwargs.get("cache_dir"),
                    token=init_kwargs.get("token"),
                )
        if not os.path.isfile(resolved_path):
            raise FileNotFoundError(f"Cannot locate tokenizer state at {tokenizer_file}")

        with open(resolved_path, "rb") as handle:
            self._encoding: tiktoken.Encoding = pickle.load(handle)

        # Build the string-level vocab that PreTrainedTokenizer expects.
        # tiktoken operates on bytes, so these string forms are best-effort
        # (errors="replace") and used only for token-level introspection;
        # decoding always goes through tiktoken (see _decode).
        self._id_to_token: List[str] = [
            token_bytes.decode("utf-8", errors="replace")
            for token_bytes in self._encoding.token_byte_values()
        ]
        self.vocab: Dict[str, int] = {token: idx for idx, token in enumerate(self._id_to_token)}

        # token_byte_values() covers only the mergeable ranks; special tokens
        # sit above them, so grow the id -> token table as needed.
        self._special_token_ids: Dict[str, int] = {}
        for special in SPECIAL_TOKENS:
            special_id = self._encoding.encode_single_token(special)
            if special_id >= len(self._id_to_token):
                self._id_to_token.extend([""] * (special_id - len(self._id_to_token) + 1))
            self._id_to_token[special_id] = special
            self.vocab[special] = special_id
            self._special_token_ids[special] = special_id

        # PreTrainedTokenizer.__init__ consults the vocab when registering
        # special tokens, so the tables above must be in place first.
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            **kwargs,
        )

        self.bos_token = bos_token or "<|bos|>"
        self.eos_token = eos_token or "<|bos|>"
        self.unk_token = unk_token or "<|bos|>"
        self.pad_token = pad_token or "<|bos|>"

        self.bos_token_id = self.vocab[self.bos_token]
        self.eos_token_id = self.vocab[self.eos_token]
        self.unk_token_id = self.vocab[self.unk_token]
        self.pad_token_id = self.vocab[self.pad_token]

    @property
    def vocab_size(self) -> int:
        return len(self._id_to_token)

    def get_vocab(self) -> Dict[str, int]:
        return dict(self.vocab)

    def _tokenize(self, text: str) -> List[str]:
        # Encode with tiktoken, then map ids to their (lossy) string forms.
        token_ids = self._encoding.encode(text, allowed_special=set())
        return [self._id_to_token[token_id] for token_id in token_ids]

    def _convert_token_to_id(self, token: str) -> int:
        return self.vocab.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index: int) -> str:
        return self._id_to_token[index]

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        if token_ids_1 is not None:
            raise ValueError("nanochat tokenizer only supports single sequences")
        return [self.bos_token_id] + token_ids_0

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        del token_ids_1  # pairs are unsupported; see build_inputs_with_special_tokens
        # The +1 accounts for the prepended <|bos|> token.
        return [0] * (len(token_ids_0) + 1)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Persist the pickled tiktoken.Encoding under tokenizer/, matching the
        # layout declared in vocab_files_names.
        target_dir = os.path.join(save_directory, "tokenizer")
        os.makedirs(target_dir, exist_ok=True)
        filename = (filename_prefix + "-" if filename_prefix else "") + "tokenizer.pkl"
        dest_file = os.path.join(target_dir, filename)
        with open(dest_file, "wb") as handle:
            pickle.dump(self._encoding, handle)
        return (dest_file,)

    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        spaces_between_special_tokens: bool = True,
        **kwargs,
    ) -> str:
        del clean_up_tokenization_spaces, spaces_between_special_tokens, kwargs
        if skip_special_tokens:
            # Drop the registered specials plus the chat-format markers, which
            # are not covered by all_special_ids.
            special_ids = set(self.all_special_ids) | set(self._special_token_ids.values())
            token_ids = [tid for tid in token_ids if tid not in special_ids]
        # Decode through tiktoken so byte-level tokens round-trip exactly.
        return self._encoding.decode(token_ids)

    def apply_chat_template(
        self,
        conversation,
        add_generation_prompt: bool = False,
        tokenize: bool = False,
        return_tensors: Optional[str] = None,
        **kwargs,
    ):
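        """Render a conversation in nanochat's chat format.

        ``conversation`` alternates user and assistant messages, optionally
        preceded by a single system message, which is folded into the first
        user turn. Assistant content may be a plain string or a list of parts
        of the form ``{"type": "text" | "python" | "python_output", "text": str}``.
        Returns decoded text by default, or token ids when ``tokenize=True``.
        """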
        if not isinstance(conversation, list) or not conversation:
            raise ValueError("conversation must be a non-empty list of messages")

        # nanochat has no dedicated system slot: fold a leading system prompt
        # into the first user message.
        messages = conversation
        if messages[0]["role"] == "system":
            if len(messages) < 2 or messages[1]["role"] != "user":
                raise ValueError("system prompt must be followed by a user message")
            merged = messages[0]["content"] + "\n\n" + messages[1]["content"]
            messages = [dict(messages[1], content=merged)] + messages[2:]

        # Every conversation starts with <|bos|>.
        token_ids: List[int] = [self.bos_token_id]

        def encode_text(text: str) -> List[int]:
            return self._encoding.encode(text, allowed_special=set()) if text else []

        user_start = self._special_token_ids["<|user_start|>"]
        user_end = self._special_token_ids["<|user_end|>"]
        assistant_start = self._special_token_ids["<|assistant_start|>"]
        assistant_end = self._special_token_ids["<|assistant_end|>"]
        python_start = self._special_token_ids["<|python_start|>"]
        python_end = self._special_token_ids["<|python_end|>"]
        output_start = self._special_token_ids["<|output_start|>"]
        output_end = self._special_token_ids["<|output_end|>"]

        for idx, message in enumerate(messages):
            # Turns must strictly alternate user / assistant.
            expected_role = "user" if idx % 2 == 0 else "assistant"
            if message["role"] != expected_role:
                raise ValueError(f"Message {idx} should be from {expected_role}, got {message['role']}")
            content = message["content"]
            if message["role"] == "user":
                if not isinstance(content, str):
                    raise ValueError("User messages must be plain strings")
                token_ids.append(user_start)
                token_ids.extend(encode_text(content))
                token_ids.append(user_end)
            else:
                token_ids.append(assistant_start)
                if isinstance(content, str):
                    token_ids.extend(encode_text(content))
                elif isinstance(content, list):
                    # Structured assistant turns: text interleaved with python
                    # code blocks and their captured output.
                    for part in content:
                        part_type = part.get("type", "text")
                        value = part.get("text", "")
                        if part_type == "text":
                            token_ids.extend(encode_text(value))
                        elif part_type == "python":
                            token_ids.append(python_start)
                            token_ids.extend(encode_text(value))
                            token_ids.append(python_end)
                        elif part_type == "python_output":
                            token_ids.append(output_start)
                            token_ids.extend(encode_text(value))
                            token_ids.append(output_end)
                        else:
                            raise ValueError(f"Unsupported assistant part type: {part_type}")
                else:
                    raise ValueError(f"Assistant content must be str or list, got {type(content)}")
                token_ids.append(assistant_end)

        if add_generation_prompt:
            # Leave the assistant turn open so generation can continue it.
            token_ids.append(assistant_start)

        if tokenize:
            if return_tensors and return_tensors != "pt":
                raise ValueError("Only return_tensors='pt' is supported")
            if return_tensors == "pt":
                import torch

                return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
            # Follow the HF convention: plain ids when no tensor type is asked for.
            return token_ids
        return self._encoding.decode(token_ids)


# Register with the Auto classes so AutoTokenizer resolves checkpoints that use
# NanoChatConfig to this tokenizer class.
AutoTokenizer.register(NanoChatConfig, NanoChatTokenizer) |
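
# A minimal usage sketch (hypothetical checkpoint path; assumes the checkpoint
# ships this module plus tokenizer/tokenizer.pkl and is loaded with
# trust_remote_code=True):
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("path/to/nanochat-checkpoint", trust_remote_code=True)
#   ids = tok.apply_chat_template(
#       [{"role": "user", "content": "Hi!"}],
#       add_generation_prompt=True,
#       tokenize=True,
#       return_tensors="pt",
#   )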