Update main.py
main.py
CHANGED
@@ -64,87 +64,42 @@ def get_image_captioning():
 @lru_cache()
 def get_translator():
     return pipeline("translation", model="facebook/nllb-200-distilled-600M")
-
 @lru_cache()
 def get_qa_model():
-
-
-
-
-
-
-
-    tokenizer, model = get_qa_model()
-
+    return pipeline("question-answering", model="deepset/roberta-base-squad2")
+
+
+#########################################################
+def extract_text_from_file(file_content: bytes, file_ext: str):
+    text = ""
+
     try:
-
-        inputs = tokenizer(
-            question,
-            context,
-            max_length=512,
-            truncation="only_second",
-            padding="max_length",
-            return_tensors="pt"
-        )
-
-        with torch.no_grad():
-            outputs = model(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"]
-            )
-
-        answer_start = torch.argmax(outputs.start_logits)
-        answer_end = torch.argmax(outputs.end_logits) + 1
-        answer = tokenizer.decode(
-            inputs["input_ids"][0][answer_start:answer_end],
-            skip_special_tokens=True
-        ).strip()
-
-        # Calculate confidence
-        start_score = torch.max(torch.nn.functional.softmax(outputs.start_logits, dim=1)).item()
-        end_score = torch.max(torch.nn.functional.softmax(outputs.end_logits, dim=1)).item()
-        confidence = (start_score + end_score) / 2
-
-        # If no answer found, try sentence by sentence
-        if not answer or confidence < 0.5:
-            sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 10]
-            for sentence in sentences:
-                if any(word in sentence.lower() for word in question.lower().split()):
-                    inputs = tokenizer(
-                        question,
-                        sentence,
-                        max_length=512,
-                        truncation="only_second",
-                        padding="max_length",
-                        return_tensors="pt"
-                    )
-                    with torch.no_grad():
-                        outputs = model(
-                            input_ids=inputs["input_ids"],
-                            attention_mask=inputs["attention_mask"]
-                        )
-                    temp_start = torch.argmax(outputs.start_logits)
-                    temp_end = torch.argmax(outputs.end_logits) + 1
-                    temp_answer = tokenizer.decode(
-                        inputs["input_ids"][0][temp_start:temp_end],
-                        skip_special_tokens=True
-                    ).strip()
-                    if temp_answer:
-                        return {
-                            "answer": temp_answer,
-                            "confidence": 0.7,  # Slightly lower confidence for fallback
-                            "context_used": sentence
-                        }
-
-        return {
-            "answer": answer if answer else "No answer found in the given context",
-            "confidence": confidence
-        }
+        if file_ext == "docx":
+            doc = Document(io.BytesIO(file_content))
+            text = " ".join([p.text for p in doc.paragraphs if p.text.strip()])
+        elif file_ext in ["xls", "xlsx"]:
+            df = pd.read_excel(io.BytesIO(file_content))
+            text = " ".join(df.iloc[:, 0].dropna().astype(str).tolist())  # Extract first column text
+        elif file_ext == "pptx":
+            ppt = Presentation(io.BytesIO(file_content))
+            text = " ".join([shape.text for slide in ppt.slides for shape in slide.shapes if hasattr(shape, "text")])
+        elif file_ext == "pdf":
+            pdf = fitz.open(stream=file_content, filetype="pdf")
+            text = " ".join([page.get_text("text") for page in pdf])
+        elif file_ext in ["jpg", "jpeg", "png"]:
+            image = Image.open(io.BytesIO(file_content))
+            text = pytesseract.image_to_string(image)  # OCR for text extraction
+        else:
+            raise HTTPException(status_code=400, detail="Unsupported file format.")
     except Exception as e:
-
-
-
-
+        raise HTTPException(status_code=500, detail=f"Error extracting text: {str(e)}")
+
+    if not text.strip():
+        raise HTTPException(status_code=400, detail="No extractable text found.")
+
+    return text
+
+########################################################
 @app.get("/", response_class=HTMLResponse)
 def home ():
     with open("static/indexAI.html","r") as file :
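This first hunk collapses QA model loading into a single transformers question-answering pipeline and routes every upload format (docx, xls/xlsx, pptx, pdf, and OCR for images) through one extract_text_from_file(file_content, file_ext) helper. Because the helper raises HTTPException itself, route handlers can let its errors pass straight through. A minimal sketch of exercising the helper outside FastAPI; "sample.pdf" is a placeholder path, and this assumes main.py is importable:

# Hypothetical smoke test for the new helper; "sample.pdf" is a placeholder.
from pathlib import Path

from main import extract_text_from_file  # assumes main.py is on the import path

raw = Path("sample.pdf").read_bytes()
ext = "sample.pdf".rsplit(".", 1)[-1].lower()  # same extension parsing the endpoint uses
print(extract_text_from_file(raw, ext)[:200])  # raises HTTPException on unsupported or empty input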
@@ -192,195 +147,30 @@ async def summarize_document(file: UploadFile = File(...)):
         return {"summary": summary}
     except Exception as e:
         raise HTTPException(500, f"Error processing document: {str(e)}")
-
-@app.post("/
-async def
-    question: str = Form(...),
-    file: Optional[UploadFile] = File(None),
-    text: Optional[str] = Form(None)
-):
+#################################################################
+@app.post("/qa")
+async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
     try:
-
-
-
-            raise HTTPException(400, "No extractable content found")
-
-        # 2. Clean and prepare context
-        context = clean_context(context)
-
-        # 3. Try primary QA model
-        qa_result = answer_with_model(question, context)
-
-        # 4. If high confidence but no answer found, try sentence-level analysis
-        if qa_result["confidence"] > 0.6 and is_no_answer(qa_result["answer"]):
-            sentence_result = answer_from_sentences(question, context)
-            if sentence_result:
-                return format_response(sentence_result, context)
-
-        # 5. If low confidence, try semantic similarity
-        if qa_result["confidence"] < 0.4:
-            semantic_result = answer_with_semantic_search(question, context)
-            if semantic_result["confidence"] > qa_result["confidence"]:
-                return format_response(semantic_result, context)
-
-        # 6. Final fallback to keyword matching
-        if is_no_answer(qa_result["answer"]):
-            keyword_result = answer_with_keywords(question, context)
-            if keyword_result:
-                return format_response(keyword_result, context)
-
-        # 7. Return whatever answer we have
-        return format_response(qa_result, context)
+        content = await file.read()
+        file_ext = file.filename.split(".")[-1].lower()
+        extracted_text = extract_text_from_file(content, file_ext)
 
-
-
+        # 🔥 Step 1: Summarize first (if text is too long)
+        if len(extracted_text) > 2000:
+            summarizer = get_summarizer()
+            summarized_text = summarizer(extracted_text[:2000], max_length=500, min_length=100, do_sample=False)[0]["summary_text"]
+        else:
+            summarized_text = extracted_text
 
+        # 🔥 Step 2: Use summarized text for QA
+        qa_model = get_qa_model()
+        answer = qa_model(question=question, context=summarized_text)  # Fixed argument format
 
-
-
-
-
-
-    file_ext = file.filename.split(".")[-1].lower()
-
-    if file_ext == "pdf":
-        pdf = fitz.open(stream=content, filetype="pdf")
-        return " ".join([page.get_text("text") for page in pdf])
-    elif file_ext == "docx":
-        doc = Document(io.BytesIO(content))
-        return " ".join([p.text for p in doc.paragraphs if p.text.strip()])
-    elif file_ext in ["xls", "xlsx"]:
-        df = pd.read_excel(io.BytesIO(content))
-        return " ".join(df.iloc[:, 0].dropna().astype(str).tolist())
-    elif file_ext == "pptx":
-        ppt = Presentation(io.BytesIO(content))
-        return " ".join([shape.text for slide in ppt.slides for shape in slide.shapes if hasattr(shape, "text")])
-    elif file_ext in ["jpg", "jpeg", "png"]:
-        image = Image.open(io.BytesIO(content))
-        try:
-            context = pytesseract.image_to_string(image)
-            return context if context.strip() else get_image_captioning()(image)[0]['generated_text']
-        except:
-            return get_image_captioning()(image)[0]['generated_text']
-    return text or ""
-
-def clean_context(context: str) -> str:
-    """Clean and normalize context text"""
-    context = " ".join(context.split())  # Remove excessive whitespace
-    return context[:10000]  # Limit context size
-
-def answer_with_model(question: str, context: str) -> dict:
-    """Use QA model to find answer"""
-    tokenizer, model = get_qa_model()
-    inputs = tokenizer(
-        question,
-        context,
-        max_length=512,
-        truncation="only_second",
-        padding="max_length",
-        return_tensors="pt"
-    )
-
-    with torch.no_grad():
-        outputs = model(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"]
-        )
-
-    answer_start = torch.argmax(outputs.start_logits)
-    answer_end = torch.argmax(outputs.end_logits) + 1
-    answer = tokenizer.decode(
-        inputs["input_ids"][0][answer_start:answer_end],
-        skip_special_tokens=True
-    ).strip()
-
-    confidence = (torch.max(torch.nn.functional.softmax(outputs.start_logits, dim=1)).item() +
-                  torch.max(torch.nn.functional.softmax(outputs.end_logits, dim=1)).item()) / 2
-
-    return {
-        "answer": answer if answer else "No answer found",
-        "confidence": confidence
-    }
-
-def answer_from_sentences(question: str, context: str) -> Optional[dict]:
-    """Try to find answer by analyzing individual sentences"""
-    sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 20]
-    for sentence in sentences:
-        if any(word in sentence.lower() for word in question.lower().split() if len(word) > 3):
-            tokenizer, model = get_qa_model()
-            inputs = tokenizer(
-                question,
-                sentence,
-                max_length=512,
-                truncation="only_second",
-                padding="max_length",
-                return_tensors="pt"
-            )
-
-            with torch.no_grad():
-                outputs = model(
-                    input_ids=inputs["input_ids"],
-                    attention_mask=inputs["attention_mask"]
-                )
-
-            answer_start = torch.argmax(outputs.start_logits)
-            answer_end = torch.argmax(outputs.end_logits) + 1
-            answer = tokenizer.decode(
-                inputs["input_ids"][0][answer_start:answer_end],
-                skip_special_tokens=True
-            ).strip()
-
-            if answer and answer.lower() not in ["no answer", "no answer found"]:
-                return {
-                    "answer": answer,
-                    "confidence": 0.7  # Slightly lower confidence for fallback
-                }
-    return None
-
-def answer_with_semantic_search(question: str, context: str) -> dict:
-    """Use semantic similarity to find relevant answer"""
-    model = SentenceTransformer('all-MiniLM-L6-v2')
-    sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 20]
-
-    if not sentences:
-        return {"answer": "No answer found", "confidence": 0.0}
-
-    question_embedding = model.encode(question, convert_to_tensor=True)
-    sentence_embeddings = model.encode(sentences, convert_to_tensor=True)
-    cos_scores = util.cos_sim(question_embedding, sentence_embeddings)[0]
-    best_idx = torch.argmax(cos_scores).item()
-
-    if cos_scores[best_idx] > 0.5:
-        result = answer_with_model(question, sentences[best_idx])
-        if not is_no_answer(result["answer"]):
-            result["confidence"] = min(result["confidence"] + 0.1, 0.9)  # Boost confidence slightly
-        return result
-
-    return {"answer": "No answer found", "confidence": 0.0}
-
-def answer_with_keywords(question: str, context: str) -> Optional[dict]:
-    """Simple keyword matching fallback"""
-    keywords = [word for word in question.lower().split() if len(word) > 3]
-    sentences = [s.strip() for s in context.split('.') if any(kw in s.lower() for kw in keywords)]
-
-    if sentences:
-        return {
-            "answer": sentences[0],
-            "confidence": 0.6
-        }
-    return None
-
-def is_no_answer(answer: str) -> bool:
-    """Check if answer indicates no answer found"""
-    return answer.lower() in ["no answer", "no answer found", "no answer found in the given context"]
-
-def format_response(result: dict, context: str) -> dict:
-    """Format final response"""
-    return {
-        "answer": result["answer"],
-        "confidence": result["confidence"],
-        "context_used": context[:500] + "..." if len(context) > 500 else context
-    }
+        return {"question": question, "answer": answer["answer"], "context_used": summarized_text}
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}")
+###############################################
 
 @app.post("/api/caption")
 async def caption_image(file: UploadFile = File(...)):
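The new /qa route takes a required file upload plus a question form field, summarizes long documents first, then runs the question-answering pipeline on that condensed context. Note that only the first 2000 characters are summarized, so answers buried later in a long document can be missed. A minimal client sketch, assuming the app is served locally at http://localhost:7860; the URL, file name, and question are placeholders:

# Hypothetical client for the new /qa route; URL and file name are placeholders.
import requests

with open("report.pdf", "rb") as fh:
    resp = requests.post(
        "http://localhost:7860/qa",
        files={"file": ("report.pdf", fh)},   # multipart file field expected by File(...)
        data={"question": "What is the main conclusion?"},  # form field expected by Form(...)
    )
resp.raise_for_status()
body = resp.json()
print(body["answer"])        # answer span from the QA pipeline
print(body["context_used"])  # the (possibly summarized) context the model saw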