import torch
import gradio as gr
from train import CharTokenizer, Seq2Seq, Encoder, Decoder, TransformerTransliterator

# ----------------------
# 1️⃣ Load LSTM checkpoint
# ----------------------
lstm_ckpt_path = "lstm_transliterator.pt"
lstm_ckpt = torch.load(lstm_ckpt_path, map_location='cpu')

src_vocab = lstm_ckpt['src_vocab']
tgt_vocab = lstm_ckpt['tgt_vocab']
src_tokenizer = CharTokenizer(vocab=src_vocab)
tgt_tokenizer = CharTokenizer(vocab=tgt_vocab)

# Reconstruct LSTM model architecture
EMBED_DIM = 256
ENC_HIDDEN_DIM = 256
DEC_HIDDEN_DIM = 256
NUM_LAYERS_MODEL = 2
DROPOUT = 0.3

device = 'cuda' if torch.cuda.is_available() else 'cpu'

encoder = Encoder(len(src_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
decoder = Decoder(len(tgt_tokenizer), EMBED_DIM, ENC_HIDDEN_DIM, DEC_HIDDEN_DIM, NUM_LAYERS_MODEL, DROPOUT)
lstm_model = Seq2Seq(encoder, decoder, device=device).to(device)
lstm_model.load_state_dict(lstm_ckpt['model_state_dict'])
lstm_model.eval()
print("✅ LSTM model loaded")

# ----------------------
# 2️⃣ Load Transformer checkpoint
# ----------------------
transformer_ckpt_path = "transformer_transliterator.pt"
transformer_ckpt = torch.load(transformer_ckpt_path, map_location='cpu')

transformer_model = TransformerTransliterator(
    src_vocab_size=len(src_tokenizer),
    tgt_vocab_size=len(tgt_tokenizer),
    d_model=256,
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=512,
    dropout=0.1,
    max_len=100
).to(device)
transformer_model.load_state_dict(transformer_ckpt['model_state_dict'])
transformer_model.eval()
print("✅ Transformer model loaded")

# ----------------------
# 3️⃣ Load TinyLLaMA
# ----------------------
from transformers import AutoTokenizer, AutoModelForCausalLM

try:
    llm_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name, trust_remote_code=True)
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
        trust_remote_code=True,
    )
    if device != "cuda":
        llm_model = llm_model.to(device)
    llm_model.eval()
    print("✅ TinyLLaMA model loaded")
    has_llm = True
except Exception as e:
    print(f"⚠️ TinyLLaMA loading failed: {e}")
    print("⚠️ Will use only LSTM and Transformer models")
    has_llm = False

# ----------------------
# 4️⃣ Transliteration Function
# ----------------------
@torch.no_grad()
def transliterate(word):
    word = word.strip()
    if not word:
        return "❌ Empty input", "❌ Empty input", "❌ Empty input"

    try:
        # LSTM prediction
        lstm_pred = lstm_model.translate(word, src_tokenizer, tgt_tokenizer)
    except Exception as e:
        lstm_pred = f"Error: {str(e)[:50]}"

    try:
        # Transformer prediction (greedy)
        transformer_pred = transformer_model.translate(
            word, src_tokenizer, tgt_tokenizer, device=device, decoding="greedy"
        )
    except Exception as e:
        transformer_pred = f"Error: {str(e)[:50]}"

    # TinyLLaMA prediction
    if has_llm:
        try:
            prompt = f"Transliterate the following English word to Hindi (Devanagari script).\nEnglish word: {word}\nHindi transliteration:"
            inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
            output_ids = llm_model.generate(
                **inputs,
                max_new_tokens=30,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=llm_tokenizer.eos_token_id,
                eos_token_id=llm_tokenizer.eos_token_id,
            )
            generated = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
            llm_pred = generated.replace(prompt, "").strip()
            # Keep only non-ASCII-letter, non-whitespace characters (i.e., the Devanagari output)
            llm_pred = ''.join(c for c in llm_pred if not (c.isascii() and c.isalpha()) and c.strip())
        except Exception as e:
            llm_pred = f"Error: {str(e)[:50]}"
    else:
        llm_pred = "TinyLLaMA model not loaded (insufficient memory)"

    return lstm_pred, transformer_pred, llm_pred

# ----------------------
# 5️⃣ Gradio Interface
# ----------------------
demo = gr.Interface(
    fn=transliterate,
    inputs=gr.Textbox(
        label="Input Hindi Roman Word",
        placeholder="e.g., namaste, dhanyavaad, bharat",
        lines=1
    ),
    outputs=[
        gr.Textbox(label="LSTM Prediction", interactive=False),
        gr.Textbox(label="Transformer Prediction", interactive=False),
        gr.Textbox(label="TinyLLaMA Prediction", interactive=False)
    ],
    title="Hindi Roman to Devanagari Transliteration",
    description="Compare three models: LSTM, Transformer, and TinyLLaMA.\nEnter a Hindi Roman word to get transliteration predictions.",
    examples=[
        ["namaste"],
        ["dhanyavaad"],
        ["bharat"],
        ["mumbai"],
        ["hindustan"],
        ["pranaam"]
    ],
    allow_flagging="never"
)

demo.launch(
    share=False,
    debug=False,
    server_name="0.0.0.0",
    server_port=7860
)
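
# ----------------------
# 6️⃣ (Optional) Querying the running app
# ----------------------
# A minimal client-side sketch, assuming the app above is serving on
# localhost:7860 and gradio_client is installed (pip install gradio_client).
# Run it from a separate process, since demo.launch() blocks; "/predict" is
# Gradio's default endpoint name for a single-function Interface.
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860/")
#   lstm_out, transformer_out, llm_out = client.predict("namaste", api_name="/predict")
#   print(lstm_out, transformer_out, llm_out)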