SmitaGautam committed on
Commit
6b35b80
·
verified ·
1 Parent(s): 6e6d27b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -11
app.py CHANGED
@@ -1,12 +1,24 @@
 
1
  import torch
2
  import gradio as gr
3
  from train import CharTokenizer, Seq2Seq, Encoder, Decoder, TransformerTransliterator
4
 
 
 
 
 
 
5
 
 
 
6
 
 
 
7
 
8
- # ----------------------
9
- # 1️⃣ Load LSTM checkpoint
 
 
10
  NUM_LAYERS_MODEL = 2
11
  DROPOUT = 0.3
12
 
@@ -21,8 +33,18 @@ lstm_model.eval()
21
  print("✅ LSTM model loaded")
22
 
23
  # ----------------------
24
- # 2️⃣ Load Transformer checkpoint
25
  # ----------------------
 
 
 
 
 
 
 
 
 
 
26
  dim_feedforward=512,
27
  dropout=0.1,
28
  max_len=100
@@ -33,7 +55,7 @@ transformer_model.eval()
33
  print("✅ Transformer model loaded")
34
 
35
  # ----------------------
36
- # 3️⃣ Load lightweight LLM (DistilBERT-based or small model)
37
  # ----------------------
38
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
39
 
@@ -46,12 +68,12 @@ try:
46
  print("✅ LLM model loaded (Flan-T5 Small)")
47
  has_llm = True
48
  except Exception as e:
49
- print(f"⚠️ LLM loading failed: {e}")
50
- print("⚠️ Will use only LSTM and Transformer models")
51
  has_llm = False
52
 
53
  # ----------------------
54
- # 4️⃣ Transliteration Function
55
  # ----------------------
56
  @torch.no_grad()
57
  def transliterate(word):
@@ -78,7 +100,7 @@ def transliterate(word):
78
  # LLM prediction (lightweight T5)
79
  if has_llm:
80
  try:
81
- prompt = f"Transliterate the romanized Hindi word to Devanagari script: {word}"
82
  inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
83
  output_ids = llm_model.generate(
84
  **inputs,
@@ -97,7 +119,7 @@ def transliterate(word):
97
  return lstm_pred, transformer_pred, llm_pred
98
 
99
  # ----------------------
100
- # 5️⃣ Gradio Interface
101
  # ----------------------
102
  demo = gr.Interface(
103
  fn=transliterate,
@@ -124,11 +146,11 @@ demo = gr.Interface(
124
  allow_flagging="never"
125
  )
126
 
127
- if __name__ == "__main__":
128
  print("🚀 Starting Gradio interface...")
129
  demo.launch(
130
  share=False,
131
  debug=False,
132
  server_name="0.0.0.0",
133
  server_port=7860
134
- )
 
1
+ import os
2
  import torch
3
  import gradio as gr
4
  from train import CharTokenizer, Seq2Seq, Encoder, Decoder, TransformerTransliterator
5
 
6
+ # ----------------------
7
+ # Load LSTM checkpoint
8
+ # ----------------------
9
+ lstm_ckpt_path = "lstm_transliterator.pt"
10
+ lstm_ckpt = torch.load(lstm_ckpt_path, map_location='cpu')
11
 
12
+ src_vocab = lstm_ckpt['src_vocab']
13
+ tgt_vocab = lstm_ckpt['tgt_vocab']
14
 
15
+ src_tokenizer = CharTokenizer(vocab=src_vocab)
16
+ tgt_tokenizer = CharTokenizer(vocab=tgt_vocab)
17
 
18
+ # Reconstruct LSTM model architecture
19
+ EMBED_DIM = 256
20
+ ENC_HIDDEN_DIM = 256
21
+ DEC_HIDDEN_DIM = 256
22
  NUM_LAYERS_MODEL = 2
23
  DROPOUT = 0.3
24
 
 
33
  print("✅ LSTM model loaded")
34
 
35
  # ----------------------
36
+ # Load Transformer checkpoint
37
  # ----------------------
38
+ transformer_ckpt_path = "transformer_transliterator.pt"
39
+ transformer_ckpt = torch.load(transformer_ckpt_path, map_location='cpu')
40
+
41
+ transformer_model = TransformerTransliterator(
42
+ src_vocab_size=len(src_tokenizer),
43
+ tgt_vocab_size=len(tgt_tokenizer),
44
+ d_model=256,
45
+ nhead=8,
46
+ num_encoder_layers=2,
47
+ num_decoder_layers=2,
48
  dim_feedforward=512,
49
  dropout=0.1,
50
  max_len=100
 
55
  print("✅ Transformer model loaded")
56
 
57
  # ----------------------
58
+ # Load lightweight seq2seq LLM (Flan-T5 Small)
59
  # ----------------------
60
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
61
 
 
68
  print("✅ LLM model loaded (Flan-T5 Small)")
69
  has_llm = True
70
  except Exception as e:
71
+ print(f" LLM loading failed: {e}")
72
+ print(" Will use only LSTM and Transformer models")
73
  has_llm = False
74
 
75
  # ----------------------
76
+ # Transliteration Function
77
  # ----------------------
78
  @torch.no_grad()
79
  def transliterate(word):
 
100
  # LLM prediction (lightweight T5)
101
  if has_llm:
102
  try:
103
+ prompt = f"Transliterate the Romanized Hindi word to Devanagari script: {word}"
104
  inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
105
  output_ids = llm_model.generate(
106
  **inputs,
 
119
  return lstm_pred, transformer_pred, llm_pred
120
 
121
  # ----------------------
122
+ # Gradio Interface
123
  # ----------------------
124
  demo = gr.Interface(
125
  fn=transliterate,
 
146
  allow_flagging="never"
147
  )
148
 
149
# Script entry point: start the Gradio UI only when this file is executed
# directly, not when it is imported as a module.
# The guard must use the double-underscore names `__name__` / "__main__";
# the single-underscore spelling `_name_` is an undefined variable and
# raises NameError as soon as the module is loaded.
if __name__ == "__main__":
    print("🚀 Starting Gradio interface...")
    demo.launch(
        share=False,
        debug=False,
        # NOTE(review): "0.0.0.0" presumably exposes the server on all
        # network interfaces (e.g. for a container / hosted Space) — confirm.
        server_name="0.0.0.0",
        server_port=7860,
    )