Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -125,7 +125,7 @@ def get_summarizer():
|
|
| 125 |
def get_qa_model():
|
| 126 |
global qa_model
|
| 127 |
if qa_model is None:
|
| 128 |
-
qa_model= pipeline("question-answering", model="
|
| 129 |
return qa_model
|
| 130 |
|
| 131 |
def get_image_captioner():
|
|
@@ -804,77 +804,33 @@ async def summarize_document(request: Request, file: UploadFile = File(...)):
|
|
| 804 |
|
| 805 |
|
| 806 |
|
|
|
|
|
|
|
| 807 |
@app.post("/qa")
|
| 808 |
@limiter.limit("5/minute")
|
| 809 |
async def question_answering(
|
| 810 |
request: Request,
|
| 811 |
-
file: UploadFile = File(
|
| 812 |
question: str = Form(...),
|
| 813 |
language: str = Form("fr")
|
| 814 |
):
|
| 815 |
-
|
| 816 |
-
Enhanced QA endpoint with:
|
| 817 |
-
- Better error handling
|
| 818 |
-
- Model validation
|
| 819 |
-
- Detailed logging
|
| 820 |
-
- Original functionality preserved
|
| 821 |
-
"""
|
| 822 |
-
# Validate input immediately
|
| 823 |
-
if not file.filename:
|
| 824 |
-
raise HTTPException(400, "No filename provided")
|
| 825 |
-
|
| 826 |
if not question.strip():
|
| 827 |
raise HTTPException(400, "Question cannot be empty")
|
| 828 |
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
| 833 |
-
|
|
|
|
|
|
|
|
|
|
| 834 |
try:
|
|
|
|
| 835 |
text = extract_text(content, file_ext)
|
| 836 |
-
if not text.strip():
|
| 837 |
-
raise HTTPException(400, "No extractable text found")
|
| 838 |
-
|
| 839 |
-
# Clean and truncate text (preserve original logic)
|
| 840 |
text = re.sub(r'\s+', ' ', text).strip()[:5000]
|
| 841 |
-
except Exception as e:
|
| 842 |
-
logger.error(f"Text extraction failed for {file.filename}: {str(e)}", exc_info=True)
|
| 843 |
-
raise HTTPException(422, f"Failed to process {file_ext} file: {str(e)}")
|
| 844 |
|
| 845 |
-
# 3. Theme Detection (original logic preserved)
|
| 846 |
-
theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
|
| 847 |
-
if any(kw in question.lower() for kw in theme_keywords):
|
| 848 |
-
try:
|
| 849 |
-
summarizer = get_summarizer()
|
| 850 |
-
summary_output = summarizer(
|
| 851 |
-
text,
|
| 852 |
-
max_length=min(100, len(text)//4),
|
| 853 |
-
min_length=30,
|
| 854 |
-
do_sample=False,
|
| 855 |
-
truncation=True
|
| 856 |
-
)
|
| 857 |
-
|
| 858 |
-
theme = summary_output[0].get("summary_text", text[:200] + "...")
|
| 859 |
-
return {
|
| 860 |
-
"question": question,
|
| 861 |
-
"answer": f"Le document traite principalement de : {theme}",
|
| 862 |
-
"confidence": 0.95,
|
| 863 |
-
"language": language
|
| 864 |
-
}
|
| 865 |
-
except Exception as e:
|
| 866 |
-
logger.warning(f"Theme detection fallback for '{question}': {str(e)}")
|
| 867 |
-
theme = text[:200] + ("..." if len(text) > 200 else "")
|
| 868 |
-
return {
|
| 869 |
-
"question": question,
|
| 870 |
-
"answer": f"D'après le document : {theme}",
|
| 871 |
-
"confidence": 0.7,
|
| 872 |
-
"language": language,
|
| 873 |
-
"warning": "theme_summary_fallback"
|
| 874 |
-
}
|
| 875 |
-
|
| 876 |
-
# 4. Standard QA (original logic preserved)
|
| 877 |
-
try:
|
| 878 |
qa = get_qa_model()
|
| 879 |
result = qa(question=question, context=text[:3000])
|
| 880 |
|
|
@@ -882,17 +838,41 @@ async def question_answering(
|
|
| 882 |
"question": question,
|
| 883 |
"answer": result["answer"],
|
| 884 |
"confidence": result["score"],
|
|
|
|
| 885 |
"language": language
|
| 886 |
}
|
| 887 |
except Exception as e:
|
| 888 |
-
logger.error(f"QA failed
|
| 889 |
-
raise HTTPException(500, "Failed to
|
| 890 |
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 896 |
@app.post("/visualize/natural")
|
| 897 |
async def natural_language_visualization(
|
| 898 |
file: UploadFile = File(...),
|
|
|
|
| 125 |
def get_qa_model():
|
| 126 |
global qa_model
|
| 127 |
if qa_model is None:
|
| 128 |
+
qa_model= pipe = pipeline("question-answering", model="deepset/roberta-base-squad2")
|
| 129 |
return qa_model
|
| 130 |
|
| 131 |
def get_image_captioner():
|
|
|
|
| 804 |
|
| 805 |
|
| 806 |
|
| 807 |
+
from typing import Optional
|
| 808 |
+
|
| 809 |
@app.post("/qa")
|
| 810 |
@limiter.limit("5/minute")
|
| 811 |
async def question_answering(
|
| 812 |
request: Request,
|
| 813 |
+
file: Optional[UploadFile] = File(None), # Make file optional
|
| 814 |
question: str = Form(...),
|
| 815 |
language: str = Form("fr")
|
| 816 |
):
|
| 817 |
+
# Validate question
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 818 |
if not question.strip():
|
| 819 |
raise HTTPException(400, "Question cannot be empty")
|
| 820 |
|
| 821 |
+
# Check if the question is about the document
|
| 822 |
+
is_doc_question = any(
|
| 823 |
+
kw in question.lower()
|
| 824 |
+
for kw in ["document", "file", "text", "this pdf", "this doc"]
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
# (A) If file is provided and question is about it → Document QA
|
| 828 |
+
if file and is_doc_question:
|
| 829 |
try:
|
| 830 |
+
file_ext, content = await process_uploaded_file(file)
|
| 831 |
text = extract_text(content, file_ext)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 832 |
text = re.sub(r'\s+', ' ', text).strip()[:5000]
|
|
|
|
|
|
|
|
|
|
| 833 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 834 |
qa = get_qa_model()
|
| 835 |
result = qa(question=question, context=text[:3000])
|
| 836 |
|
|
|
|
| 838 |
"question": question,
|
| 839 |
"answer": result["answer"],
|
| 840 |
"confidence": result["score"],
|
| 841 |
+
"source": "document",
|
| 842 |
"language": language
|
| 843 |
}
|
| 844 |
except Exception as e:
|
| 845 |
+
logger.error(f"Doc QA failed: {str(e)}")
|
| 846 |
+
raise HTTPException(500, "Failed to analyze document")
|
| 847 |
|
| 848 |
+
# (B) If no file or general question → Open-domain QA (RAG)
|
| 849 |
+
else:
|
| 850 |
+
try:
|
| 851 |
+
rag = get_rag_model()
|
| 852 |
+
answer = rag(question)[0]["generated_text"]
|
| 853 |
+
|
| 854 |
+
return {
|
| 855 |
+
"question": question,
|
| 856 |
+
"answer": answer,
|
| 857 |
+
"confidence": 0.8, # RAG doesn't return scores
|
| 858 |
+
"source": "general knowledge",
|
| 859 |
+
"language": language
|
| 860 |
+
}
|
| 861 |
+
except Exception as e:
|
| 862 |
+
logger.error(f"RAG failed: {str(e)}")
|
| 863 |
+
raise HTTPException(500, "Failed to fetch general answer")
|
| 864 |
+
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
|
| 876 |
@app.post("/visualize/natural")
|
| 877 |
async def natural_language_visualization(
|
| 878 |
file: UploadFile = File(...),
|