chenguittiMaroua committed on
Commit 4b33aa9 · verified · 1 Parent(s): 3c52e24

Update main.py

Files changed (1)
  1. main.py +79 -49
main.py CHANGED
@@ -47,70 +47,100 @@ def get_translator():
# Optimized QA Model Loading
@lru_cache()
def get_qa_model():
+    """Simplified model loading without unexpected parameters"""
    model_name = "deepset/roberta-base-squad2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    return tokenizer, model
+
def answer_question(question: str, context: str) -> dict:
    """
-    Fixed QA function without offset_mapping issue
+    Robust QA function with minimal parameters
    """
    tokenizer, model = get_qa_model()

-    # Tokenize inputs without offset mapping
-    inputs = tokenizer(
-        question,
-        context,
-        max_length=512,
-        truncation="only_second",
-        stride=128,
-        return_overflowing_tokens=True,
-        padding="max_length",
-        return_tensors="pt"
-    )
-
-    # Get model predictions
-    with torch.no_grad():
-        outputs = model(**{k: v for k, v in inputs.items() if k != "offset_mapping"})
-
-    # Process each possible answer
-    answers = []
-    for i in range(inputs["input_ids"].shape[0]):
-        start_logits = outputs.start_logits[i]
-        end_logits = outputs.end_logits[i]
-
-        # Get the most probable start and end positions
-        answer_start = torch.argmax(start_logits)
-        answer_end = torch.argmax(end_logits) + 1
-
-        # Skip invalid answers
-        if answer_start >= answer_end:
-            continue
-
-        # Get the answer text
+    try:
+        # Simple tokenization without problematic parameters
+        inputs = tokenizer(
+            question,
+            context,
+            max_length=512,
+            truncation="only_second",
+            padding="max_length",
+            return_tensors="pt"
+        )
+
+        # Filter inputs to only include what the model expects
+        model_inputs = {
+            "input_ids": inputs["input_ids"],
+            "attention_mask": inputs["attention_mask"]
+        }
+
+        with torch.no_grad():
+            outputs = model(**model_inputs)
+
+        # Get the most probable answer
+        answer_start = torch.argmax(outputs.start_logits)
+        answer_end = torch.argmax(outputs.end_logits) + 1
+
        answer = tokenizer.decode(
-            inputs["input_ids"][i][answer_start:answer_end],
+            inputs["input_ids"][0][answer_start:answer_end],
            skip_special_tokens=True
        ).strip()

-        # Calculate confidence score
-        start_score = torch.nn.functional.softmax(start_logits, dim=0)[answer_start]
-        end_score = torch.nn.functional.softmax(end_logits, dim=0)[answer_end-1]
+        # Calculate confidence
+        start_score = torch.nn.functional.softmax(outputs.start_logits, dim=1)[0][answer_start]
+        end_score = torch.nn.functional.softmax(outputs.end_logits, dim=1)[0][answer_end-1]
        confidence = float((start_score + end_score) / 2)

-        if answer: # Only keep valid answers
-            answers.append({
-                "answer": answer,
-                "confidence": confidence,
-                "start": answer_start.item(),
-                "end": answer_end.item()
-            })
-
-    if not answers:
-        return {"answer": "No answer found", "confidence": 0.0}
-
-    # Return the answer with highest confidence
-    return max(answers, key=lambda x: x["confidence"])
+        return {
+            "answer": answer if answer else "No answer found",
+            "confidence": confidence
+        }
+
+    except Exception as e:
+        return {
+            "answer": f"Error processing answer: {str(e)}",
+            "confidence": 0.0
+        }
+
+@app.post("/ask")
+async def ask_question(
+    question: str = Form(...),
+    file: Optional[UploadFile] = File(None),
+    text: Optional[str] = Form(None)
+):
+    """
+    Final robust QA endpoint
+    """
+    try:
+        # [Keep your existing context extraction code here]
+        # ...
+
+        if not context.strip():
+            raise HTTPException(status_code=400, detail="No extractable content found.")
+
+        # Clean context
+        context = " ".join(context.split())
+
+        # Get answer with error handling
+        result = answer_question(question, context)
+
+        if result["confidence"] < 0.1:
+            # [Keep your fallback semantic search if you want]
+            pass
+
+        return {
+            "answer": result["answer"],
+            "confidence": result["confidence"],
+            "context_used": context[:500] + "..." if len(context) > 500 else context
+        }
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error processing question: {str(e)}"
+        )


  # Home Route
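
For reference, a minimal client sketch for exercising the `/ask` endpoint added in this commit. The base URL, file name, and sample inputs are illustrative assumptions, not part of the commit; the response keys (`answer`, `confidence`, `context_used`) follow the return dict in the diff above.

```python
# Hypothetical client for the /ask endpoint; URL and inputs are
# illustrative assumptions, not part of this commit.
import requests

BASE_URL = "http://localhost:8000"  # assumed local FastAPI server

# Ask against raw text passed via the `text` form field
resp = requests.post(
    f"{BASE_URL}/ask",
    data={
        "question": "What model does the service use?",
        "text": "The service answers questions with deepset/roberta-base-squad2.",
    },
)
print(resp.json())  # expected keys: answer, confidence, context_used

# Ask against an uploaded document via the `file` field
with open("notes.txt", "rb") as fh:  # hypothetical input file
    resp = requests.post(
        f"{BASE_URL}/ask",
        data={"question": "What does the document cover?"},
        files={"file": ("notes.txt", fh, "text/plain")},
    )
print(resp.json())
```

Note that the context extraction for uploaded files is elided in the diff (`# [Keep your existing context extraction code here]`), so the file-upload path depends on that existing code in main.py.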