chenguittiMaroua commited on
Commit
cbf5a05
·
verified ·
1 Parent(s): ea0f16e

Update main.py

Browse files
Files changed (1):
  1. main.py +92 -47
main.py CHANGED
@@ -132,45 +132,85 @@ def get_summarizer():
132
 
133
 
134
MODEL_CHOICES = [
    "patrickvonplaten/t5-tiny-random",  # Tiny test model (always works)
    "google/flan-t5-small",             # 300MB
    "google/flan-t5-base",              # 900MB
    "facebook/bart-large-cnn",          # 1.6GB
]


class QAService:
    """Lazily-initialized question-answering service.

    ``initialize()`` walks ``MODEL_CHOICES`` in order until one checkpoint
    loads, so the app degrades gracefully when the larger models are
    unavailable (e.g. on a small CPU-only host).
    """

    def __init__(self):
        # Loaded transformers pipeline; stays None until initialize() succeeds.
        self.model = None
        # Name of the checkpoint that actually loaded.
        self.model_name = None
        # transformers pipeline device convention: 0 = first GPU, -1 = CPU.
        self.device = 0 if torch.cuda.is_available() else -1

    def initialize(self):
        """Try loading models until one succeeds.

        Returns:
            bool: True if some model loaded, False if every candidate failed.
        """
        for model_name in MODEL_CHOICES:
            try:
                logger.info(f"Attempting to load {model_name}")

                # Lightweight pipeline initialization; fp16 only on GPU.
                self.model = pipeline(
                    "text2text-generation",
                    model=model_name,
                    device=self.device,
                    torch_dtype=torch.float16 if self.device == 0 else torch.float32,
                )
                self.model_name = model_name
                logger.info(f"Successfully loaded {model_name}")
                return True

            except Exception as e:
                # Best-effort fallback: log and try the next candidate.
                logger.warning(f"Failed to load {model_name}: {str(e)}")
                continue

        logger.error("All model loading attempts failed")
        return False


# Global service instance shared by the request handlers.
qa_service = QAService()
173
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
 
176
 
@@ -863,55 +903,60 @@ async def summarize_document(request: Request, file: UploadFile = File(...)):
863
  from typing import Optional
864
 
865
  @app.post("/qa")
866
- async def handle_qa_request(
867
  question: str = Form(...),
868
- file: Optional[UploadFile] = File(None)
 
869
  ):
870
- # Initialize service if needed
871
  if not qa_service.model:
872
- if not qa_service.initialize():
873
- raise HTTPException(
874
- status_code=500,
875
- detail={
876
- "error": "System unavailable",
877
- "status": "Model initialization failed",
878
- "recovery_suggestion": "Retry in 30 seconds or contact support"
879
- }
880
- )
881
 
882
  try:
883
- # Process input
 
 
 
 
884
  context = None
885
  if file:
886
- file_ext, content = await process_uploaded_file(file)
887
- context = extract_text(content, file_ext)[:2000] # Strict limit
888
-
889
- # Generate response
 
 
 
 
 
 
 
890
  try:
891
- input_text = f"question: {question}" + (f" context: {context}" if context else "")
892
- result = qa_service.model(
893
- input_text,
894
- max_length=150,
895
- num_beams=2,
896
- early_stopping=True
897
- )
898
 
899
  return {
900
  "question": question,
901
- "answer": result[0]["generated_text"],
902
  "model": qa_service.model_name,
903
- "context_used": bool(context)
 
904
  }
905
 
906
  except Exception as e:
907
- logger.error(f"Generation failed: {str(e)}")
908
  raise HTTPException(
909
- status_code=500,
910
  detail={
911
  "error": "Answer generation failed",
912
  "model": qa_service.model_name,
913
- "input_size": len(input_text) if 'input_text' in locals() else None,
914
- "suggestion": "Simplify your question or reduce document size"
915
  }
916
  )
917
 
 
132
 
133
 
134
MODEL_CHOICES = [
    "google/flan-t5-small",     # ~300MB (English)
    "google/flan-t5-base",      # ~900MB (English)
    "cmarkea/flan-t5-base-fr",  # French-optimized
]


class QAService:
    """Seq2seq question-answering service with model fallback.

    ``initialize()`` walks ``MODEL_CHOICES`` until a checkpoint loads;
    ``generate_answer()`` runs deterministic beam-search generation over an
    optional document context.
    """

    def __init__(self):
        self.model = None       # AutoModelForSeq2SeqLM once loaded
        self.tokenizer = None   # matching AutoTokenizer
        self.model_name = None  # checkpoint that actually loaded
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def initialize(self):
        """Initialize with fallback support.

        Returns:
            bool: True on the first checkpoint that loads, False if all fail.
        """
        for model_name in MODEL_CHOICES:
            try:
                logger.info(f"Loading {model_name}")

                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_name,
                    device_map="auto",
                    # fp16 only makes sense on GPU.
                    torch_dtype=torch.float16 if "cuda" in self.device else torch.float32,
                )
                self.model_name = model_name
                logger.info(f"Successfully loaded {model_name} on {self.device}")
                return True

            except Exception as e:
                # Best-effort fallback: log and try the next candidate.
                logger.warning(f"Failed to load {model_name}: {str(e)}")
                continue

        logger.error("All models failed to load")
        return False

    def generate_answer(self, question: str, context: Optional[str] = None) -> str:
        """Generate an answer for *question*, optionally grounded in *context*.

        Raises:
            Exception: re-raises whatever the tokenizer/model raised, after
            logging it, so the endpoint can map the failure to an HTTP error.
        """
        try:
            input_text = f"question: {question}"
            if context:
                input_text += f" context: {context[:2000]}"  # Limit context size

            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            ).to(self.device)
            # NOTE(review): assumes device_map="auto" placed the model on
            # self.device — confirm on multi-GPU hosts.

            # Beam search is deterministic (do_sample defaults to False), so the
            # original `temperature=0.7` was a no-op that transformers warns
            # about; it is dropped here without changing outputs.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=150,
                num_beams=3,
                early_stopping=True,
                repetition_penalty=2.5,
                no_repeat_ngram_size=3,
            )

            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            raise


# Global service instance shared by the request handlers.
qa_service = QAService()
202
 
203
@app.on_event("startup")
async def startup_event():
    """Eagerly load a model at app startup so the first /qa call isn't slow."""
    # The endpoint can still retry initialization on demand, so a failure
    # here is logged rather than raised (the app stays up).
    if not qa_service.initialize():
        logger.error("QA service failed to initialize")
207
+
208
+
209
+
210
+
211
+
212
+
213
+
214
 
215
 
216
 
 
903
  from typing import Optional
904
 
905
  @app.post("/qa")
906
+ async def handle_qa(
907
  question: str = Form(...),
908
+ file: Optional[UploadFile] = File(None),
909
+ language: str = Form("fr")
910
  ):
911
+ """Handle QA requests with file upload support"""
912
  if not qa_service.model:
913
+ raise HTTPException(
914
+ 503,
915
+ detail={
916
+ "error": "Service unavailable",
917
+ "supported_models": MODEL_CHOICES,
918
+ "suggestion": "Try again later or contact support"
919
+ }
920
+ )
 
921
 
922
  try:
923
+ # Validate question
924
+ if not question.strip():
925
+ raise HTTPException(400, "Question cannot be empty")
926
+
927
+ # Process file if provided
928
  context = None
929
  if file:
930
+ try:
931
+ file_ext, content = await process_uploaded_file(file)
932
+ context = extract_text(content, file_ext)
933
+ context = re.sub(r'\s+', ' ', context).strip()[:2000] # Clean and limit
934
+ except HTTPException:
935
+ raise
936
+ except Exception as e:
937
+ logger.error(f"File processing failed: {str(e)}")
938
+ raise HTTPException(422, "File processing error")
939
+
940
+ # Generate answer
941
  try:
942
+ answer = qa_service.generate_answer(question, context)
 
 
 
 
 
 
943
 
944
  return {
945
  "question": question,
946
+ "answer": answer,
947
  "model": qa_service.model_name,
948
+ "context_used": context is not None,
949
+ "language": language
950
  }
951
 
952
  except Exception as e:
953
+ logger.error(f"Answer generation failed: {str(e)}")
954
  raise HTTPException(
955
+ 500,
956
  detail={
957
  "error": "Answer generation failed",
958
  "model": qa_service.model_name,
959
+ "suggestion": "Try simplifying your question or reducing document size"
 
960
  }
961
  )
962