Update main.py
main.py (CHANGED)
@@ -47,70 +47,100 @@ def get_translator():

Old (main.py lines 47-116):

 # Optimized QA Model Loading
 @lru_cache()
 def get_qa_model():
     model_name = "deepset/roberta-base-squad2"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForQuestionAnswering.from_pretrained(model_name)
     return tokenizer, model

 def answer_question(question: str, context: str) -> dict:
     """
-    ...
     """
     tokenizer, model = get_qa_model()

-    ...
-    answer_end = torch.argmax(end_logits) + 1
-
-    # Skip invalid answers
-    if answer_start >= answer_end:
-        continue
-
-    # Get the answer text
     answer = tokenizer.decode(
-        inputs["input_ids"][
         skip_special_tokens=True
     ).strip()

-    # Calculate confidence
-    start_score = torch.nn.functional.softmax(start_logits, dim=
-    end_score = torch.nn.functional.softmax(end_logits, dim=
     confidence = float((start_score + end_score) / 2)

-    ...

 # Home Route

New (main.py lines 47-146):

 # Optimized QA Model Loading
 @lru_cache()
 def get_qa_model():
+    """Simplified model loading without unexpected parameters"""
     model_name = "deepset/roberta-base-squad2"
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForQuestionAnswering.from_pretrained(model_name)
     return tokenizer, model
+
 def answer_question(question: str, context: str) -> dict:
     """
+    Robust QA function with minimal parameters
     """
     tokenizer, model = get_qa_model()

+    try:
+        # Simple tokenization without problematic parameters
+        inputs = tokenizer(
+            question,
+            context,
+            max_length=512,
+            truncation="only_second",
+            padding="max_length",
+            return_tensors="pt"
+        )
+
+        # Filter inputs to only include what the model expects
+        model_inputs = {
+            "input_ids": inputs["input_ids"],
+            "attention_mask": inputs["attention_mask"]
+        }
+
+        with torch.no_grad():
+            outputs = model(**model_inputs)
+
+        # Get the most probable answer
+        answer_start = torch.argmax(outputs.start_logits)
+        answer_end = torch.argmax(outputs.end_logits) + 1
+
         answer = tokenizer.decode(
+            inputs["input_ids"][0][answer_start:answer_end],
             skip_special_tokens=True
         ).strip()

+        # Calculate confidence
+        start_score = torch.nn.functional.softmax(outputs.start_logits, dim=1)[0][answer_start]
+        end_score = torch.nn.functional.softmax(outputs.end_logits, dim=1)[0][answer_end-1]
         confidence = float((start_score + end_score) / 2)

+        return {
+            "answer": answer if answer else "No answer found",
+            "confidence": confidence
+        }
+
+    except Exception as e:
+        return {
+            "answer": f"Error processing answer: {str(e)}",
+            "confidence": 0.0
+        }
+
+@app.post("/ask")
+async def ask_question(
+    question: str = Form(...),
+    file: Optional[UploadFile] = File(None),
+    text: Optional[str] = Form(None)
+):
+    """
+    Final robust QA endpoint
+    """
+    try:
+        # [Keep your existing context extraction code here]
+        # ...
+
+        if not context.strip():
+            raise HTTPException(status_code=400, detail="No extractable content found.")
+
+        # Clean context
+        context = " ".join(context.split())
+
+        # Get answer with error handling
+        result = answer_question(question, context)
+
+        if result["confidence"] < 0.1:
+            # [Keep your fallback semantic search if you want]
+            pass
+
+        return {
+            "answer": result["answer"],
+            "confidence": result["confidence"],
+            "context_used": context[:500] + "..." if len(context) > 500 else context
+        }
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error processing question: {str(e)}"
+        )

# Home Route
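
For a quick local check of the updated answer_question() helper, it can be called directly from a Python shell. This is a minimal sketch, assuming main.py is importable as a module named main and that its dependencies (torch, transformers) are installed; the question and context strings are made-up illustration data, not part of the commit.

    # Sketch only: exercising the updated answer_question() helper directly.
    # Assumes main.py is importable as `main` and its dependencies are installed.
    from main import answer_question

    context = (
        "The Eiffel Tower is a wrought-iron lattice tower in Paris. "
        "It was completed in 1889 as the entrance arch to the World's Fair."
    )
    result = answer_question("When was the Eiffel Tower completed?", context)

    # The function returns a dict with "answer" and "confidence" keys.
    print(result["answer"], result["confidence"])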
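
The new /ask route takes the question as a form field plus either an uploaded file or raw text. Below is a minimal sketch of exercising it with the requests library; the base URL is a placeholder for wherever the FastAPI app is actually served (for example a local uvicorn run) and is not part of the commit.

    # Sketch only: posting to the /ask endpoint added in this change.
    # BASE_URL is a placeholder; point it at your running FastAPI app.
    import requests

    BASE_URL = "http://localhost:8000"

    resp = requests.post(
        f"{BASE_URL}/ask",
        data={
            "question": "When was the Eiffel Tower completed?",
            "text": "The Eiffel Tower was completed in 1889 in Paris.",
        },
    )
    resp.raise_for_status()
    print(resp.json())  # {"answer": ..., "confidence": ..., "context_used": ...}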