chenguittiMaroua commited on
Commit
1458d30
·
verified ·
1 Parent(s): b3e453a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +65 -164
main.py CHANGED
@@ -804,49 +804,6 @@ async def summarize_document(request: Request, file: UploadFile = File(...)):
804
 
805
 
806
 
807
- from typing import Optional
808
- from fastapi import HTTPException, UploadFile, Form, Request
809
- from transformers import pipeline
810
- import re
811
- import logging
812
-
813
- logger = logging.getLogger(__name__)
814
-
815
- # Global model caches
816
- QA_MODEL = None
817
- SUMMARIZER = None
818
- GENERATIVE_MODEL = None
819
-
820
- def get_qa_model():
821
- global QA_MODEL
822
- if QA_MODEL is None:
823
- QA_MODEL = pipeline(
824
- "question-answering",
825
- model="deepset/roberta-base-squad2", # Better than BERT for QA
826
- device=0 if torch.cuda.is_available() else -1 # GPU if available
827
- )
828
- return QA_MODEL
829
-
830
- def get_summarizer():
831
- global SUMMARIZER
832
- if SUMMARIZER is None:
833
- SUMMARIZER = pipeline(
834
- "summarization",
835
- model="facebook/bart-large-cnn",
836
- device=0 if torch.cuda.is_available() else -1
837
- )
838
- return SUMMARIZER
839
-
840
- def get_generative_model():
841
- global GENERATIVE_MODEL
842
- if GENERATIVE_MODEL is None:
843
- GENERATIVE_MODEL = pipeline(
844
- "text-generation",
845
- model="google/flan-t5-large", # Good balance of speed/accuracy
846
- device=0 if torch.cuda.is_available() else -1
847
- )
848
- return GENERATIVE_MODEL
849
-
850
  @app.post("/qa")
851
  @limiter.limit("5/minute")
852
  async def question_answering(
@@ -856,13 +813,13 @@ async def question_answering(
856
  language: str = Form("fr")
857
  ):
858
  """
859
- Enhanced QA endpoint that handles:
860
- - Any type of question (factual, thematic, analytical)
861
- - Multiple file formats
862
- - Comprehensive error handling
863
- - Language consideration
864
  """
865
- # Validate inputs
866
  if not file.filename:
867
  raise HTTPException(400, "No filename provided")
868
 
@@ -870,128 +827,72 @@ async def question_answering(
870
  raise HTTPException(400, "Question cannot be empty")
871
 
872
  try:
873
- # 1. Process file and extract text
874
  file_ext, content = await process_uploaded_file(file)
875
- text = extract_text(content, file_ext)
876
-
877
- if not text.strip():
878
- raise HTTPException(400, "No extractable text found")
879
-
880
- # Clean and normalize text
881
- text = re.sub(r'\s+', ' ', text).strip()
882
 
883
- # 2. Determine question type and process accordingly
884
- question_lower = question.lower()
885
-
886
- # Theme detection questions
887
- theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic", "de quoi parle", "what is this about"]
888
- if any(kw in question_lower for kw in theme_keywords):
889
- return handle_theme_question(text, question, language)
890
-
891
- # Summary questions
892
- summary_keywords = ["résumé", "résume", "summarize", "summary", "synthèse"]
893
- if any(kw in question_lower for kw in summary_keywords):
894
- return handle_summary_question(text, question, language)
895
 
896
- # List/Enumeration questions
897
- list_keywords = ["liste", "list", "énumère", "quels sont", "what are the"]
898
- if any(kw in question_lower for kw in list_keywords):
899
- return handle_list_question(text, question, language)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
900
 
901
- # Default QA processing
902
- return handle_general_question(text, question, language)
 
 
 
 
 
 
 
903
 
904
  except HTTPException:
905
- raise
906
  except Exception as e:
907
- logger.error(f"QA processing failed: {str(e)}", exc_info=True)
908
- raise HTTPException(500, detail=f"Analysis failed: {str(e)}")
909
-
910
- # Helper functions for different question types
911
- def handle_theme_question(text: str, question: str, language: str):
912
- try:
913
- summarizer = get_summarizer()
914
- summary_output = summarizer(
915
- text,
916
- max_length=min(100, len(text)//4),
917
- min_length=30,
918
- do_sample=False,
919
- truncation=True
920
- )
921
- theme = summary_output[0].get("summary_text", text[:200] + "...")
922
- return {
923
- "question": question,
924
- "answer": f"Le document traite principalement de : {theme}",
925
- "confidence": 0.95,
926
- "language": language
927
- }
928
- except Exception:
929
- theme = text[:200] + ("..." if len(text) > 200 else "")
930
- return {
931
- "question": question,
932
- "answer": f"D'après le document : {theme}",
933
- "confidence": 0.7,
934
- "language": language,
935
- "warning": "theme_summary_fallback"
936
- }
937
-
938
- def handle_summary_question(text: str, question: str, language: str):
939
- summarizer = get_summarizer()
940
- summary = summarizer(
941
- text,
942
- max_length=150,
943
- min_length=50,
944
- do_sample=False,
945
- truncation=True
946
- )[0]["summary_text"]
947
- return {
948
- "question": question,
949
- "answer": f"Résumé du document : {summary}",
950
- "confidence": 0.9,
951
- "language": language
952
- }
953
-
954
- def handle_list_question(text: str, question: str, language: str):
955
- # Use QA model to find relevant parts, then extract list items
956
- qa = get_qa_model()
957
- result = qa(question=question, context=text[:3000])
958
-
959
- # Post-process to extract list items
960
- answer = result["answer"]
961
- if "\n" not in answer and "," in answer:
962
- items = [x.strip() for x in answer.split(",")]
963
- if len(items) > 1:
964
- answer = "\n- " + "\n- ".join(items)
965
-
966
- return {
967
- "question": question,
968
- "answer": answer,
969
- "confidence": result["score"],
970
- "language": language
971
- }
972
-
973
- def handle_general_question(text: str, question: str, language: str):
974
- qa = get_qa_model()
975
-
976
- # First try with full context
977
- result = qa(question=question, context=text[:3000])
978
-
979
- # If low confidence, try with different context windows
980
- if result["score"] < 0.3:
981
- alternative_results = [
982
- qa(question=question, context=text[1000:4000]),
983
- qa(question=question, context=text[2000:5000])
984
- ]
985
- best_result = max(alternative_results + [result], key=lambda x: x["score"])
986
- if best_result["score"] > result["score"] + 0.1:
987
- result = best_result
988
-
989
- return {
990
- "question": question,
991
- "answer": result["answer"],
992
- "confidence": result["score"],
993
- "language": language
994
- }
995
  @app.post("/visualize/natural")
996
  async def natural_language_visualization(
997
  file: UploadFile = File(...),
 
804
 
805
 
806
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807
  @app.post("/qa")
808
  @limiter.limit("5/minute")
809
  async def question_answering(
 
813
  language: str = Form("fr")
814
  ):
815
  """
816
+ Enhanced QA endpoint with:
817
+ - Better error handling
818
+ - Model validation
819
+ - Detailed logging
820
+ - Original functionality preserved
821
  """
822
+ # Validate input immediately
823
  if not file.filename:
824
  raise HTTPException(400, "No filename provided")
825
 
 
827
  raise HTTPException(400, "Question cannot be empty")
828
 
829
  try:
830
+ # 1. File Processing
831
  file_ext, content = await process_uploaded_file(file)
 
 
 
 
 
 
 
832
 
833
+ # 2. Text Extraction with enhanced error context
834
+ try:
835
+ text = extract_text(content, file_ext)
836
+ if not text.strip():
837
+ raise HTTPException(400, "No extractable text found")
 
 
 
 
 
 
 
838
 
839
+ # Clean and truncate text (preserve original logic)
840
+ text = re.sub(r'\s+', ' ', text).strip()[:5000]
841
+ except Exception as e:
842
+ logger.error(f"Text extraction failed for {file.filename}: {str(e)}", exc_info=True)
843
+ raise HTTPException(422, f"Failed to process {file_ext} file: {str(e)}")
844
+
845
+ # 3. Theme Detection (original logic preserved)
846
+ theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
847
+ if any(kw in question.lower() for kw in theme_keywords):
848
+ try:
849
+ summarizer = get_summarizer()
850
+ summary_output = summarizer(
851
+ text,
852
+ max_length=min(100, len(text)//4),
853
+ min_length=30,
854
+ do_sample=False,
855
+ truncation=True
856
+ )
857
+
858
+ theme = summary_output[0].get("summary_text", text[:200] + "...")
859
+ return {
860
+ "question": question,
861
+ "answer": f"Le document traite principalement de : {theme}",
862
+ "confidence": 0.95,
863
+ "language": language
864
+ }
865
+ except Exception as e:
866
+ logger.warning(f"Theme detection fallback for '{question}': {str(e)}")
867
+ theme = text[:200] + ("..." if len(text) > 200 else "")
868
+ return {
869
+ "question": question,
870
+ "answer": f"D'après le document : {theme}",
871
+ "confidence": 0.7,
872
+ "language": language,
873
+ "warning": "theme_summary_fallback"
874
+ }
875
+
876
+ # 4. Standard QA (original logic preserved)
877
+ try:
878
+ qa = get_qa_model()
879
+ result = qa(question=question, context=text[:3000])
880
 
881
+ return {
882
+ "question": question,
883
+ "answer": result["answer"],
884
+ "confidence": result["score"],
885
+ "language": language
886
+ }
887
+ except Exception as e:
888
+ logger.error(f"QA failed for question '{question}': {str(e)}", exc_info=True)
889
+ raise HTTPException(500, "Failed to generate answer")
890
 
891
  except HTTPException:
892
+ raise # Re-raise existing HTTP exceptions
893
  except Exception as e:
894
+ logger.critical(f"Unexpected error processing request: {str(e)}", exc_info=True)
895
+ raise HTTPException(500, "Internal server error")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
896
  @app.post("/visualize/natural")
897
  async def natural_language_visualization(
898
  file: UploadFile = File(...),