"""Text generation demo for ChatGCLM.

Loads the first "Turing_*.pt" checkpoint found in the current directory,
samples from a few canned prompts, then drops into an interactive prompt loop.
"""

import os
import string
from collections import OrderedDict

import torch
import torch.nn.functional as F

from model import ChatGCLM, MAX_SEQ_LEN

# Locate the first checkpoint matching "Turing_*.pt" in the working directory.
MODEL_PATH = None
for f in os.listdir("."):
    if f.startswith("Turing_") and f.endswith(".pt"):
        MODEL_PATH = f
        break

if MODEL_PATH is None:
    print("Error: No model checkpoint found!")
    print("Please train the model first with: python3 train.py")
    exit(1)

# Character-level vocabulary: IDs 0-2 are reserved (2 is end-of-sequence),
# so printable characters are shifted up by OFFSET.
EOS_ID = 2
OFFSET = 3
CHARS = string.printable


def encode(text):
    """Map a string to token IDs, silently dropping unencodable characters."""
    return [CHARS.index(c) + OFFSET for c in text if c in CHARS]


def decode(ids):
    """Map token IDs back to a string, skipping reserved IDs below OFFSET."""
    return "".join(CHARS[i - OFFSET] for i in ids if i >= OFFSET)


def load_model(device):
    """Build a ChatGCLM and load weights from MODEL_PATH; return None on failure."""
    vocab_size = len(CHARS) + OFFSET
    model = ChatGCLM(vocab_size).to(device)
    if not (os.path.exists(MODEL_PATH) and os.path.getsize(MODEL_PATH) > 0):
        print(f"Error: Could not load model from {MODEL_PATH}")
        return None

    print(f"Loading model from: {MODEL_PATH}")
    ckpt = torch.load(MODEL_PATH, map_location=device)

    # The checkpoint may be a bare state_dict or a dict wrapping one.
    if isinstance(ckpt, dict):
        if 'model_state_dict' in ckpt:
            state_dict = ckpt['model_state_dict']
        elif 'state_dict' in ckpt:
            state_dict = ckpt['state_dict']
        else:
            state_dict = ckpt
    else:
        state_dict = ckpt

    def _strip_module_prefix(sd):
        """Drop the 'module.' prefix left behind by DataParallel/DDP training."""
        if any(k.startswith('module.') for k in sd):
            new_sd = OrderedDict()
            for k, v in sd.items():
                new_key = k[len('module.'):] if k.startswith('module.') else k
                new_sd[new_key] = v
            return new_sd
        return sd

    state_dict = _strip_module_prefix(state_dict)
    res = model.load_state_dict(state_dict, strict=False)
    missing = getattr(res, 'missing_keys', None)
    unexpected = getattr(res, 'unexpected_keys', None)
    if missing:
        print(f"Warning: missing keys when loading state_dict: {missing}")
    if unexpected:
        print(f"Warning: unexpected keys in state_dict: {unexpected}")
    model.eval()
    return model


@torch.no_grad()
def generate(model, prompt, device, max_new_tokens=200, temperature=0.8, top_k=50):
    """Sample up to max_new_tokens characters from the model, streaming to stdout."""
    model.eval()
    input_ids = encode(prompt)
    if not input_ids:
        # Every character was unencodable; the model needs at least one token.
        print("Prompt contains no encodable characters; skipping.")
        return ""
    x = torch.tensor([input_ids], dtype=torch.long, device=device)

    print(f"\n{'='*70}")
    print(f"PROMPT: {prompt}")
    print(f"{'='*70}")
    print("GENERATED TEXT:")
    print(prompt, end="", flush=True)

    generated_tokens = []
    for _ in range(max_new_tokens):
        # Truncate the context to the model's maximum sequence length.
        ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
        logits = model(ctx)
        next_token_logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # Mask out everything below the k-th largest logit.
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
            next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()
        if idx == EOS_ID:
            break
        x = torch.cat((x, next_token), dim=1)
        generated_tokens.append(idx)
        print(decode([idx]), end="", flush=True)

    print(f"\n{'='*70}\n")
    return decode(generated_tokens)


if __name__ == "__main__":
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model = load_model(device)
    if model is None:
        exit(1)

    test_prompts = [
        "Once upon a time",
        "The future of AI is",
        "In a world where",
    ]

    print("\n" + "="*70)
    print("ChatGCLM Text Generation Demo")
    print("="*70)
    for prompt in test_prompts:
        generate(model, prompt, device, max_new_tokens=150, temperature=0.8, top_k=50)

    print("\n" + "="*70)
    print("Interactive Mode - Enter your own prompts!")
    print("="*70)
    while True:
        try:
            user_prompt = input("\nEnter prompt (or 'exit' to quit): ")
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of a traceback.
            break
        if user_prompt.lower() == 'exit':
            break
        if user_prompt.strip():
            generate(model, user_prompt, device, max_new_tokens=200,
                     temperature=0.8, top_k=50)
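
# ---------------------------------------------------------------------------
# A minimal sanity-check sketch, not part of the original script: encode()
# and decode() above form a round trip for any string built from
# string.printable characters, while characters outside that set are silently
# dropped. Assuming only the definitions in this file, something like the
# following should hold:
#
#     assert decode(encode("Hello, world!")) == "Hello, world!"
#     assert decode(encode("café")) == "caf"   # 'é' is not in string.printable
#
# This can run without a trained checkpoint, which makes it a cheap smoke
# test for the codec before loading any model weights.
# ---------------------------------------------------------------------------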