chenguittiMaroua committed
Commit 2e8f5e6 · verified · 1 Parent(s): cbf5a05

Update main.py

Files changed (1):
  1. main.py  +57 −85

main.py CHANGED
@@ -132,78 +132,44 @@ def get_summarizer():
 
 
 MODEL_CHOICES = [
-    "google/flan-t5-small",    # ~300MB (English)
-    "google/flan-t5-base",     # ~900MB (English)
-    "cmarkea/flan-t5-base-fr"  # French-optimized
 ]
 
-class QAService:
-    def __init__(self):
-        self.model = None
-        self.tokenizer = None
-        self.model_name = None
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    def initialize(self):
-        """Initialize with fallback support"""
-        for model_name in MODEL_CHOICES:
-            try:
-                logger.info(f"Loading {model_name}")
-
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-                self.model = AutoModelForSeq2SeqLM.from_pretrained(
-                    model_name,
-                    device_map="auto",
-                    torch_dtype=torch.float16 if "cuda" in self.device else torch.float32
-                )
-                self.model_name = model_name
-                logger.info(f"Successfully loaded {model_name} on {self.device}")
-                return True
-
-            except Exception as e:
-                logger.warning(f"Failed to load {model_name}: {str(e)}")
-                continue
-
-        logger.error("All models failed to load")
-        return False
 
-    def generate_answer(self, question: str, context: Optional[str] = None):
-        """Generate answer with proper text generation parameters"""
         try:
-            input_text = f"question: {question}"
-            if context:
-                input_text += f" context: {context[:2000]}"  # Limit context size
 
-            inputs = self.tokenizer(
-                input_text,
-                return_tensors="pt",
-                truncation=True,
-                max_length=512
-            ).to(self.device)
-
-            outputs = self.model.generate(
-                **inputs,
-                max_new_tokens=150,
-                num_beams=3,
-                early_stopping=True,
-                temperature=0.7,
-                repetition_penalty=2.5,
-                no_repeat_ngram_size=3
             )
 
-            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
         except Exception as e:
-            logger.error(f"Generation failed: {str(e)}")
-            raise
-
-# Initialize service
-qa_service = QAService()
 
 @app.on_event("startup")
 async def startup_event():
-    if not qa_service.initialize():
-        logger.error("QA service failed to initialize")
 
 
 
@@ -903,59 +869,65 @@ async def summarize_document(request: Request, file: UploadFile = File(...)):
 from typing import Optional
 
 @app.post("/qa")
-async def handle_qa(
     question: str = Form(...),
-    file: Optional[UploadFile] = File(None),
-    language: str = Form("fr")
 ):
-    """Handle QA requests with file upload support"""
-    if not qa_service.model:
         raise HTTPException(
-            503,
             detail={
-                "error": "Service unavailable",
                 "supported_models": MODEL_CHOICES,
-                "suggestion": "Try again later or contact support"
             }
         )
 
     try:
-        # Validate question
-        if not question.strip():
-            raise HTTPException(400, "Question cannot be empty")
-
-        # Process file if provided
         context = None
         if file:
             try:
-                file_ext, content = await process_uploaded_file(file)
-                context = extract_text(content, file_ext)
-                context = re.sub(r'\s+', ' ', context).strip()[:2000]  # Clean and limit
             except HTTPException:
                 raise
             except Exception as e:
                 logger.error(f"File processing failed: {str(e)}")
                 raise HTTPException(422, "File processing error")
 
-        # Generate answer
         try:
-            answer = qa_service.generate_answer(question, context)
 
             return {
                 "question": question,
-                "answer": answer,
-                "model": qa_service.model_name,
-                "context_used": context is not None,
-                "language": language
             }
 
         except Exception as e:
-            logger.error(f"Answer generation failed: {str(e)}")
             raise HTTPException(
-                500,
                 detail={
                     "error": "Answer generation failed",
-                    "model": qa_service.model_name,
                     "suggestion": "Try simplifying your question or reducing document size"
                 }
             )
@@ -132,78 +132,44 @@ def get_summarizer():
 
 
 MODEL_CHOICES = [
+    "mrm8488/t5-base-finetuned-question-generation-ap",  # Small QA model (140MB)
+    "google/flan-t5-small",                               # Official small model (300MB)
+    "hello-simpleai/chatbot"                              # Very small fallback
 ]
 
+qa_pipeline = None
+current_model = None
 
+def initialize_qa():
+    global qa_pipeline, current_model
+
+    # Try each model in order
+    for model_name in MODEL_CHOICES:
         try:
+            logger.info(f"Attempting to load {model_name}")
 
+            qa_pipeline = pipeline(
+                "text2text-generation",
+                model=model_name,
+                device=0 if torch.cuda.is_available() else -1,
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
             )
 
+            current_model = model_name
+            logger.info(f"Successfully loaded {model_name}")
+            return True
 
         except Exception as e:
+            logger.warning(f"Failed to load {model_name}: {str(e)}")
+            continue
+
+    logger.error("All model loading attempts failed")
+    return False
 
 @app.on_event("startup")
 async def startup_event():
+    if not initialize_qa():
+        logger.error("QA system failed to initialize")
 
 
 
 
869
  from typing import Optional
870
 
871
  @app.post("/qa")
872
+ async def question_answering(
873
  question: str = Form(...),
874
+ file: Optional[UploadFile] = File(None)
 
875
  ):
876
+ """Handle QA requests with optional file context"""
877
+ if qa_pipeline is None:
878
  raise HTTPException(
879
+ status_code=503,
880
  detail={
881
+ "error": "QA system unavailable",
882
+ "status": "No working model could be loaded",
883
  "supported_models": MODEL_CHOICES,
884
+ "recovery_suggestion": "Please try again later"
885
  }
886
  )
887
 
888
  try:
889
+ # Process input
 
 
 
 
890
  context = None
891
  if file:
892
  try:
893
+ _, content = await process_uploaded_file(file)
894
+ context = extract_text(content, file.filename.split('.')[-1])
895
+ context = re.sub(r'\s+', ' ', context).strip()[:1000] # Clean and limit context
896
  except HTTPException:
897
  raise
898
  except Exception as e:
899
  logger.error(f"File processing failed: {str(e)}")
900
  raise HTTPException(422, "File processing error")
901
 
902
+ # Generate response
903
  try:
904
+ input_text = f"question: {question}"
905
+ if context:
906
+ input_text += f" context: {context}"
907
+
908
+ result = qa_pipeline(
909
+ input_text,
910
+ max_length=100,
911
+ num_beams=2,
912
+ temperature=0.7,
913
+ repetition_penalty=2.0,
914
+ no_repeat_ngram_size=3
915
+ )
916
 
917
  return {
918
  "question": question,
919
+ "answer": result[0]["generated_text"],
920
+ "model": current_model,
921
+ "context_used": context is not None
 
922
  }
923
 
924
  except Exception as e:
925
+ logger.error(f"Generation failed: {str(e)}")
926
  raise HTTPException(
927
+ status_code=500,
928
  detail={
929
  "error": "Answer generation failed",
930
+ "model": current_model,
931
  "suggestion": "Try simplifying your question or reducing document size"
932
  }
933
  )
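For reference, a minimal client-side sketch of how the updated /qa endpoint could be called once the app is running. The base URL and file name are placeholders for illustration, and the accepted file types depend on extract_text(), which is not shown in this diff:

# Illustrative client call; host/port and file are assumptions, not part of the commit
import requests

BASE_URL = "http://localhost:8000"  # assumed address where the FastAPI app is served

# Question only (no context document)
resp = requests.post(f"{BASE_URL}/qa", data={"question": "What is the document about?"})
print(resp.json())  # expected keys: question, answer, model, context_used

# Question plus an uploaded file used as context
with open("report.pdf", "rb") as f:  # placeholder file; supported types depend on extract_text()
    resp = requests.post(
        f"{BASE_URL}/qa",
        data={"question": "Summarize the main findings."},
        files={"file": ("report.pdf", f, "application/pdf")},
    )
print(resp.json())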