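"""Text-generation demo for ChatGCLM.

Loads a Turing_*.pt checkpoint from the current directory, generates text for
a few sample prompts, then enters an interactive loop that samples from
user-supplied prompts until 'exit' is entered.
"""
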
import os
import string
from collections import OrderedDict

import torch
import torch.nn.functional as F

from model import ChatGCLM, MAX_SEQ_LEN

# Pick up the first checkpoint matching Turing_*.pt in the current directory.
MODEL_PATH = None
for f in os.listdir("."):
    if f.startswith("Turing_") and f.endswith(".pt"):
        MODEL_PATH = f
        break

if MODEL_PATH is None:
    print("Error: No model checkpoint found!")
    print("Please train the model first with: python3 train.py")
    exit(1)

# Character-level vocabulary: IDs below OFFSET are reserved for special
# tokens (2 is end-of-sequence); printable characters start at OFFSET.
EOS_ID = 2
OFFSET = 3
CHARS = string.printable


def encode(text):
    """Map a string to token IDs, dropping characters outside CHARS."""
    return [CHARS.index(c) + OFFSET for c in text if c in CHARS]


def decode(ids):
    """Map token IDs back to a string, skipping reserved special tokens."""
    return "".join([CHARS[i - OFFSET] for i in ids if i >= OFFSET])

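# Note: the encode/decode round trip is lossless for printable text, e.g.
# decode(encode("Hello")) == "Hello"; characters outside string.printable are
# silently dropped by encode().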


def load_model(device):
    vocab_size = len(CHARS) + OFFSET
    model = ChatGCLM(vocab_size).to(device)

    if os.path.exists(MODEL_PATH) and os.path.getsize(MODEL_PATH) > 0:
        print(f"Loading model from: {MODEL_PATH}")
        ckpt = torch.load(MODEL_PATH, map_location=device)

        # Checkpoints may be a bare state_dict or a dict that wraps one.
        if isinstance(ckpt, dict):
            if 'model_state_dict' in ckpt:
                state_dict = ckpt['model_state_dict']
            elif 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            else:
                state_dict = ckpt
        else:
            state_dict = ckpt

        def _strip_module_prefix(sd):
            # Checkpoints saved via nn.DataParallel / DDP prefix keys with 'module.'.
            if any(k.startswith('module.') for k in sd.keys()):
                new_sd = OrderedDict()
                for k, v in sd.items():
                    new_key = k[len('module.'):] if k.startswith('module.') else k
                    new_sd[new_key] = v
                return new_sd
            return sd

        state_dict = _strip_module_prefix(state_dict)

        # strict=False tolerates minor architecture changes; report any mismatches.
        res = model.load_state_dict(state_dict, strict=False)
        missing = getattr(res, 'missing_keys', None)
        unexpected = getattr(res, 'unexpected_keys', None)
        if missing:
            print(f"Warning: missing keys when loading state_dict: {missing}")
        if unexpected:
            print(f"Warning: unexpected keys in state_dict: {unexpected}")

        model.eval()
        return model
    else:
        print(f"Error: Could not load model from {MODEL_PATH}")
        return None


@torch.no_grad()
def generate(model, prompt, device, max_new_tokens=200, temperature=0.8, top_k=50):
    model.eval()
    input_ids = encode(prompt)
    x = torch.tensor([input_ids], dtype=torch.long, device=device)

    print(f"\n{'='*70}")
    print(f"PROMPT: {prompt}")
    print(f"{'='*70}")
    print("GENERATED TEXT:")
    print(prompt, end="", flush=True)

    generated_tokens = []
    for _ in range(max_new_tokens):
        # Crop the context to the model's maximum sequence length.
        ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
        logits = model(ctx)
        # Temperature < 1 sharpens the next-token distribution, > 1 flattens it.
        next_token_logits = logits[:, -1, :] / temperature

        # Top-k filtering: mask out everything below the k-th largest logit.
        if top_k is not None:
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
            next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()

        if idx == EOS_ID:
            break

        x = torch.cat((x, next_token), dim=1)
        generated_tokens.append(idx)
        token_text = decode([idx])
        print(token_text, end="", flush=True)

    print(f"\n{'='*70}\n")
    return decode(generated_tokens)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    model = load_model(device)
    if model is None:
        exit(1)

    test_prompts = [
        "Once upon a time",
        "The future of AI is",
        "In a world where",
    ]

    print("\n" + "="*70)
    print("ChatGCLM Text Generation Demo")
    print("="*70)

    for prompt in test_prompts:
        generate(model, prompt, device, max_new_tokens=150, temperature=0.8, top_k=50)

    print("\n" + "="*70)
    print("Interactive Mode - Enter your own prompts!")
    print("="*70)

    while True:
        user_prompt = input("\nEnter prompt (or 'exit' to quit): ")
        if user_prompt.lower() == 'exit':
            break
        if user_prompt.strip():
            generate(model, user_prompt, device, max_new_tokens=200, temperature=0.8, top_k=50)