chenguittiMaroua commited on
Commit
0940d8b
·
verified ·
1 Parent(s): b176e46

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +52 -262
main.py CHANGED
@@ -64,87 +64,42 @@ def get_image_captioning():
64
  @lru_cache()
65
  def get_translator():
66
  return pipeline("translation", model="facebook/nllb-200-distilled-600M")
67
-
68
  @lru_cache()
69
  def get_qa_model():
70
- model_name = "deepset/roberta-base-squad2"
71
- tokenizer = AutoTokenizer.from_pretrained(model_name)
72
- model = AutoModelForQuestionAnswering.from_pretrained(model_name)
73
- return tokenizer, model
74
-
75
- # Helper Functions
76
- def answer_question(question: str, context: str) -> dict:
77
- tokenizer, model = get_qa_model()
78
-
79
  try:
80
- # Try with the full context first
81
- inputs = tokenizer(
82
- question,
83
- context,
84
- max_length=512,
85
- truncation="only_second",
86
- padding="max_length",
87
- return_tensors="pt"
88
- )
89
-
90
- with torch.no_grad():
91
- outputs = model(
92
- input_ids=inputs["input_ids"],
93
- attention_mask=inputs["attention_mask"]
94
- )
95
-
96
- answer_start = torch.argmax(outputs.start_logits)
97
- answer_end = torch.argmax(outputs.end_logits) + 1
98
- answer = tokenizer.decode(
99
- inputs["input_ids"][0][answer_start:answer_end],
100
- skip_special_tokens=True
101
- ).strip()
102
-
103
- # Calculate confidence
104
- start_score = torch.max(torch.nn.functional.softmax(outputs.start_logits, dim=1)).item()
105
- end_score = torch.max(torch.nn.functional.softmax(outputs.end_logits, dim=1)).item()
106
- confidence = (start_score + end_score) / 2
107
-
108
- # If no answer found, try sentence by sentence
109
- if not answer or confidence < 0.5:
110
- sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 10]
111
- for sentence in sentences:
112
- if any(word in sentence.lower() for word in question.lower().split()):
113
- inputs = tokenizer(
114
- question,
115
- sentence,
116
- max_length=512,
117
- truncation="only_second",
118
- padding="max_length",
119
- return_tensors="pt"
120
- )
121
- with torch.no_grad():
122
- outputs = model(
123
- input_ids=inputs["input_ids"],
124
- attention_mask=inputs["attention_mask"]
125
- )
126
- temp_start = torch.argmax(outputs.start_logits)
127
- temp_end = torch.argmax(outputs.end_logits) + 1
128
- temp_answer = tokenizer.decode(
129
- inputs["input_ids"][0][temp_start:temp_end],
130
- skip_special_tokens=True
131
- ).strip()
132
- if temp_answer:
133
- return {
134
- "answer": temp_answer,
135
- "confidence": 0.7, # Slightly lower confidence for fallback
136
- "context_used": sentence
137
- }
138
-
139
- return {
140
- "answer": answer if answer else "No answer found in the given context",
141
- "confidence": confidence
142
- }
143
  except Exception as e:
144
- return {
145
- "answer": f"Error processing answer: {str(e)}",
146
- "confidence": 0.0
147
- }
 
 
 
 
148
  @app.get("/", response_class=HTMLResponse)
149
  def home ():
150
  with open("static/indexAI.html","r") as file :
@@ -192,195 +147,30 @@ async def summarize_document(file: UploadFile = File(...)):
192
  return {"summary": summary}
193
  except Exception as e:
194
  raise HTTPException(500, f"Error processing document: {str(e)}")
195
-
196
- @app.post("/ask")
197
- async def ask_question(
198
- question: str = Form(...),
199
- file: Optional[UploadFile] = File(None),
200
- text: Optional[str] = Form(None)
201
- ):
202
  try:
203
- # 1. Extract and preprocess context
204
- context = await extract_context(file, text)
205
- if not context.strip():
206
- raise HTTPException(400, "No extractable content found")
207
-
208
- # 2. Clean and prepare context
209
- context = clean_context(context)
210
-
211
- # 3. Try primary QA model
212
- qa_result = answer_with_model(question, context)
213
-
214
- # 4. If high confidence but no answer found, try sentence-level analysis
215
- if qa_result["confidence"] > 0.6 and is_no_answer(qa_result["answer"]):
216
- sentence_result = answer_from_sentences(question, context)
217
- if sentence_result:
218
- return format_response(sentence_result, context)
219
-
220
- # 5. If low confidence, try semantic similarity
221
- if qa_result["confidence"] < 0.4:
222
- semantic_result = answer_with_semantic_search(question, context)
223
- if semantic_result["confidence"] > qa_result["confidence"]:
224
- return format_response(semantic_result, context)
225
-
226
- # 6. Final fallback to keyword matching
227
- if is_no_answer(qa_result["answer"]):
228
- keyword_result = answer_with_keywords(question, context)
229
- if keyword_result:
230
- return format_response(keyword_result, context)
231
-
232
- # 7. Return whatever answer we have
233
- return format_response(qa_result, context)
234
 
235
- except Exception as e:
236
- raise HTTPException(500, f"Error processing question: {str(e)}")
 
 
 
 
237
 
 
 
 
238
 
239
- # Helper functions
240
- async def extract_context(file: Optional[UploadFile], text: Optional[str]) -> str:
241
- """Extract text from file or use provided text"""
242
- if file:
243
- content = await file.read()
244
- file_ext = file.filename.split(".")[-1].lower()
245
-
246
- if file_ext == "pdf":
247
- pdf = fitz.open(stream=content, filetype="pdf")
248
- return " ".join([page.get_text("text") for page in pdf])
249
- elif file_ext == "docx":
250
- doc = Document(io.BytesIO(content))
251
- return " ".join([p.text for p in doc.paragraphs if p.text.strip()])
252
- elif file_ext in ["xls", "xlsx"]:
253
- df = pd.read_excel(io.BytesIO(content))
254
- return " ".join(df.iloc[:, 0].dropna().astype(str).tolist())
255
- elif file_ext == "pptx":
256
- ppt = Presentation(io.BytesIO(content))
257
- return " ".join([shape.text for slide in ppt.slides for shape in slide.shapes if hasattr(shape, "text")])
258
- elif file_ext in ["jpg", "jpeg", "png"]:
259
- image = Image.open(io.BytesIO(content))
260
- try:
261
- context = pytesseract.image_to_string(image)
262
- return context if context.strip() else get_image_captioning()(image)[0]['generated_text']
263
- except:
264
- return get_image_captioning()(image)[0]['generated_text']
265
- return text or ""
266
-
267
- def clean_context(context: str) -> str:
268
- """Clean and normalize context text"""
269
- context = " ".join(context.split()) # Remove excessive whitespace
270
- return context[:10000] # Limit context size
271
-
272
- def answer_with_model(question: str, context: str) -> dict:
273
- """Use QA model to find answer"""
274
- tokenizer, model = get_qa_model()
275
- inputs = tokenizer(
276
- question,
277
- context,
278
- max_length=512,
279
- truncation="only_second",
280
- padding="max_length",
281
- return_tensors="pt"
282
- )
283
-
284
- with torch.no_grad():
285
- outputs = model(
286
- input_ids=inputs["input_ids"],
287
- attention_mask=inputs["attention_mask"]
288
- )
289
-
290
- answer_start = torch.argmax(outputs.start_logits)
291
- answer_end = torch.argmax(outputs.end_logits) + 1
292
- answer = tokenizer.decode(
293
- inputs["input_ids"][0][answer_start:answer_end],
294
- skip_special_tokens=True
295
- ).strip()
296
-
297
- confidence = (torch.max(torch.nn.functional.softmax(outputs.start_logits, dim=1)).item() +
298
- torch.max(torch.nn.functional.softmax(outputs.end_logits, dim=1)).item()) / 2
299
-
300
- return {
301
- "answer": answer if answer else "No answer found",
302
- "confidence": confidence
303
- }
304
-
305
- def answer_from_sentences(question: str, context: str) -> Optional[dict]:
306
- """Try to find answer by analyzing individual sentences"""
307
- sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 20]
308
- for sentence in sentences:
309
- if any(word in sentence.lower() for word in question.lower().split() if len(word) > 3):
310
- tokenizer, model = get_qa_model()
311
- inputs = tokenizer(
312
- question,
313
- sentence,
314
- max_length=512,
315
- truncation="only_second",
316
- padding="max_length",
317
- return_tensors="pt"
318
- )
319
-
320
- with torch.no_grad():
321
- outputs = model(
322
- input_ids=inputs["input_ids"],
323
- attention_mask=inputs["attention_mask"]
324
- )
325
-
326
- answer_start = torch.argmax(outputs.start_logits)
327
- answer_end = torch.argmax(outputs.end_logits) + 1
328
- answer = tokenizer.decode(
329
- inputs["input_ids"][0][answer_start:answer_end],
330
- skip_special_tokens=True
331
- ).strip()
332
-
333
- if answer and answer.lower() not in ["no answer", "no answer found"]:
334
- return {
335
- "answer": answer,
336
- "confidence": 0.7 # Slightly lower confidence for fallback
337
- }
338
- return None
339
-
340
- def answer_with_semantic_search(question: str, context: str) -> dict:
341
- """Use semantic similarity to find relevant answer"""
342
- model = SentenceTransformer('all-MiniLM-L6-v2')
343
- sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 20]
344
-
345
- if not sentences:
346
- return {"answer": "No answer found", "confidence": 0.0}
347
-
348
- question_embedding = model.encode(question, convert_to_tensor=True)
349
- sentence_embeddings = model.encode(sentences, convert_to_tensor=True)
350
- cos_scores = util.cos_sim(question_embedding, sentence_embeddings)[0]
351
- best_idx = torch.argmax(cos_scores).item()
352
-
353
- if cos_scores[best_idx] > 0.5:
354
- result = answer_with_model(question, sentences[best_idx])
355
- if not is_no_answer(result["answer"]):
356
- result["confidence"] = min(result["confidence"] + 0.1, 0.9) # Boost confidence slightly
357
- return result
358
-
359
- return {"answer": "No answer found", "confidence": 0.0}
360
-
361
- def answer_with_keywords(question: str, context: str) -> Optional[dict]:
362
- """Simple keyword matching fallback"""
363
- keywords = [word for word in question.lower().split() if len(word) > 3]
364
- sentences = [s.strip() for s in context.split('.') if any(kw in s.lower() for kw in keywords)]
365
-
366
- if sentences:
367
- return {
368
- "answer": sentences[0],
369
- "confidence": 0.6
370
- }
371
- return None
372
-
373
- def is_no_answer(answer: str) -> bool:
374
- """Check if answer indicates no answer found"""
375
- return answer.lower() in ["no answer", "no answer found", "no answer found in the given context"]
376
-
377
- def format_response(result: dict, context: str) -> dict:
378
- """Format final response"""
379
- return {
380
- "answer": result["answer"],
381
- "confidence": result["confidence"],
382
- "context_used": context[:500] + "..." if len(context) > 500 else context
383
- }
384
 
385
  @app.post("/api/caption")
386
  async def caption_image(file: UploadFile = File(...)):
 
64
  @lru_cache()
65
  def get_translator():
66
  return pipeline("translation", model="facebook/nllb-200-distilled-600M")
 
67
  @lru_cache()
68
  def get_qa_model():
69
+ return pipeline("question-answering", model="deepset/roberta-base-squad2")
70
+
71
+
72
+ #########################################################
73
+ def extract_text_from_file(file_content: bytes, file_ext: str):
74
+ text = ""
75
+
 
 
76
  try:
77
+ if file_ext == "docx":
78
+ doc = Document(io.BytesIO(file_content))
79
+ text = " ".join([p.text for p in doc.paragraphs if p.text.strip()])
80
+ elif file_ext in ["xls", "xlsx"]:
81
+ df = pd.read_excel(io.BytesIO(file_content))
82
+ text = " ".join(df.iloc[:, 0].dropna().astype(str).tolist()) # Extract first column text
83
+ elif file_ext == "pptx":
84
+ ppt = Presentation(io.BytesIO(file_content))
85
+ text = " ".join([shape.text for slide in ppt.slides for shape in slide.shapes if hasattr(shape, "text")])
86
+ elif file_ext == "pdf":
87
+ pdf = fitz.open(stream=file_content, filetype="pdf")
88
+ text = " ".join([page.get_text("text") for page in pdf])
89
+ elif file_ext in ["jpg", "jpeg", "png"]:
90
+ image = Image.open(io.BytesIO(file_content))
91
+ text = pytesseract.image_to_string(image) # OCR for text extraction
92
+ else:
93
+ raise HTTPException(status_code=400, detail="Unsupported file format.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  except Exception as e:
95
+ raise HTTPException(status_code=500, detail=f"Error extracting text: {str(e)}")
96
+
97
+ if not text.strip():
98
+ raise HTTPException(status_code=400, detail="No extractable text found.")
99
+
100
+ return text
101
+
102
+ ########################################################
103
  @app.get("/", response_class=HTMLResponse)
104
  def home ():
105
  with open("static/indexAI.html","r") as file :
 
147
  return {"summary": summary}
148
  except Exception as e:
149
  raise HTTPException(500, f"Error processing document: {str(e)}")
150
+ #################################################################
151
+ @app.post("/qa")
152
+ async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
 
 
 
 
153
  try:
154
+ content = await file.read()
155
+ file_ext = file.filename.split(".")[-1].lower()
156
+ extracted_text = extract_text_from_file(content, file_ext)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ # 🔥 Step 1: Summarize first (if text is too long)
159
+ if len(extracted_text) > 2000:
160
+ summarizer = get_summarizer()
161
+ summarized_text = summarizer(extracted_text[:2000], max_length=500, min_length=100, do_sample=False)[0]["summary_text"]
162
+ else:
163
+ summarized_text = extracted_text
164
 
165
+ # 🔥 Step 2: Use summarized text for QA
166
+ qa_model = get_qa_model()
167
+ answer = qa_model(question=question, context=summarized_text) # Fixed argument format
168
 
169
+ return {"question": question, "answer": answer["answer"], "context_used": summarized_text}
170
+
171
+ except Exception as e:
172
+ raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}")
173
+ ###############################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  @app.post("/api/caption")
176
  async def caption_image(file: UploadFile = File(...)):