# Turing / sample.py
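"""Text-generation demo for the ChatGCLM character-level model.

Finds a Turing_*.pt checkpoint in the current directory, loads it into the
ChatGCLM architecture defined in model.py, samples a few canned prompts, then
drops into an interactive prompt loop. Run with: python3 sample.py
"""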
import os
import torch
import torch.nn.functional as F
from collections import OrderedDict
import string
from model import ChatGCLM, MAX_SEQ_LEN
# Locate the first checkpoint matching Turing_*.pt in the current directory.
MODEL_PATH = None
for f in os.listdir("."):
    if f.startswith("Turing_") and f.endswith(".pt"):
        MODEL_PATH = f
        break

if MODEL_PATH is None:
    print("Error: No model checkpoint found!")
    print("Please train the model first with: python3 train.py")
    exit(1)
# Token IDs 0..2 are reserved for special tokens (EOS is 2; 0 and 1 are
# presumably pad/BOS in train.py); printable characters start at OFFSET.
EOS_ID = 2
OFFSET = 3
CHARS = string.printable

def encode(text):
    """Map a string to token IDs, silently dropping unencodable characters."""
    return [CHARS.index(c) + OFFSET for c in text if c in CHARS]

def decode(ids):
    """Map token IDs back to a string, skipping special (reserved) IDs."""
    return "".join([CHARS[i - OFFSET] for i in ids if i >= OFFSET])
def load_model(device):
    vocab_size = len(CHARS) + OFFSET
    model = ChatGCLM(vocab_size).to(device)
    if os.path.exists(MODEL_PATH) and os.path.getsize(MODEL_PATH) > 0:
        print(f"Loading model from: {MODEL_PATH}")
        ckpt = torch.load(MODEL_PATH, map_location=device)
        # Checkpoints may be a bare state_dict or a dict wrapping one under a
        # common key; probe the usual layouts.
        if isinstance(ckpt, dict):
            if 'model_state_dict' in ckpt:
                state_dict = ckpt['model_state_dict']
            elif 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            else:
                state_dict = ckpt
        else:
            state_dict = ckpt

        def _strip_module_prefix(sd):
            # Checkpoints saved from nn.DataParallel / DDP-wrapped models
            # prefix every key with 'module.'; strip it so the keys match.
            keys = list(sd.keys())
            if any(k.startswith('module.') for k in keys):
                new_sd = OrderedDict()
                for k, v in sd.items():
                    new_key = k[len('module.'):] if k.startswith('module.') else k
                    new_sd[new_key] = v
                return new_sd
            return sd

        state_dict = _strip_module_prefix(state_dict)
        # strict=False tolerates partial mismatches; surface them as warnings.
        res = model.load_state_dict(state_dict, strict=False)
        missing = getattr(res, 'missing_keys', None)
        unexpected = getattr(res, 'unexpected_keys', None)
        if missing:
            print(f"Warning: missing keys when loading state_dict: {missing}")
        if unexpected:
            print(f"Warning: unexpected keys in state_dict: {unexpected}")
        model.eval()
        return model
    else:
        print(f"Error: Could not load model from {MODEL_PATH}")
        return None
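# For reference, a checkpoint this loader accepts could be written from
# train.py in any of the layouts probed above -- a sketch, assuming the
# training script saves a wrapped state dict (the filename is illustrative):
#
#   torch.save({"model_state_dict": model.state_dict()}, "Turing_demo.pt")
#
# A bare torch.save(model.state_dict(), path) also works.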
@torch.no_grad()
def generate(model, prompt, device, max_new_tokens=200, temperature=0.8, top_k=50):
    model.eval()
    input_ids = encode(prompt)
    if not input_ids:
        print("Prompt contains no encodable characters.")
        return ""
    x = torch.tensor([input_ids], dtype=torch.long, device=device)
    print(f"\n{'='*70}")
    print(f"PROMPT: {prompt}")
    print(f"{'='*70}")
    print("GENERATED TEXT:")
    print(prompt, end="", flush=True)
    generated_tokens = []
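    # Sampling recap: dividing logits by temperature < 1 sharpens the
    # distribution (values > 1 flatten it), and top-k masks everything but the
    # k highest logits before softmax. E.g. logits [2.0, 1.0, 0.0] give
    # softmax weights ~[0.67, 0.24, 0.09]; at temperature 0.5 they become
    # [4.0, 2.0, 0.0] -> ~[0.87, 0.12, 0.02].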
    for _ in range(max_new_tokens):
        # Crop the context to the model's maximum sequence length.
        ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
        logits = model(ctx)
        next_token_logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # Keep only the top_k logits; push the rest to -inf so softmax
            # assigns them zero probability.
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
            next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()
        if idx == EOS_ID:
            break
        x = torch.cat((x, next_token), dim=1)
        generated_tokens.append(idx)
        token_text = decode([idx])
        print(token_text, end="", flush=True)
    print(f"\n{'='*70}\n")
    return decode(generated_tokens)
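# generate() streams tokens to stdout as they are sampled and also returns the
# completion as a plain string, so it can be reused programmatically, e.g.:
#
#   text = generate(model, "Once upon a time", device, max_new_tokens=50)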
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")
    model = load_model(device)
    if model is None:
        exit(1)

    test_prompts = [
        "Once upon a time",
        "The future of AI is",
        "In a world where",
    ]
    print("\n" + "="*70)
    print("ChatGCLM Text Generation Demo")
    print("="*70)
    for prompt in test_prompts:
        generate(model, prompt, device, max_new_tokens=150, temperature=0.8, top_k=50)

    print("\n" + "="*70)
    print("Interactive Mode - Enter your own prompts!")
    print("="*70)
    while True:
        try:
            user_prompt = input("\nEnter prompt (or 'exit' to quit): ")
        except (EOFError, KeyboardInterrupt):
            # Treat Ctrl-D / Ctrl-C as a request to quit rather than crashing.
            break
        if user_prompt.strip().lower() == 'exit':
            break
        if user_prompt.strip():
            generate(model, user_prompt, device, max_new_tokens=200, temperature=0.8, top_k=50)