# asm-app/main.py
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import pipeline
from typing import Optional
import io
import fitz # PyMuPDF
from PIL import Image
import pandas as pd
import uvicorn
from docx import Document
from pptx import Presentation
import pytesseract
import logging
import re
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
# CORS Configuration
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
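# NOTE: "*" origins/methods/headers keep the demo easy to call from any
# frontend; restrict these before exposing the API more widely.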
# Constants
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
# Kept in sync with the formats extract_text can handle below
SUPPORTED_FILE_TYPES = {
    "docx", "xls", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png"
}
# Model caching
summarizer = None
qa_model = None
image_captioner = None
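
# Each pipeline is loaded lazily on first use, so startup stays fast and
# model download/memory costs are only paid when an endpoint needs them.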
def get_summarizer():
global summarizer
if summarizer is None:
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
return summarizer
def get_qa_model():
global qa_model
if qa_model is None:
qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
return qa_model
def get_image_captioner():
global image_captioner
if image_captioner is None:
image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
return image_captioner
async def process_uploaded_file(file: UploadFile):
if not file.filename:
raise HTTPException(400, "No file provided")
file_ext = file.filename.split('.')[-1].lower()
if file_ext not in SUPPORTED_FILE_TYPES:
raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES)}")
content = await file.read()
if len(content) > MAX_FILE_SIZE:
raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")
return file_ext, content
def extract_text(content: bytes, file_ext: str) -> str:
try:
if file_ext == "docx":
doc = Document(io.BytesIO(content))
return " ".join(p.text for p in doc.paragraphs if p.text.strip())
elif file_ext in {"xls", "xlsx"}:
df = pd.read_excel(io.BytesIO(content))
return " ".join(df.iloc[:, 0].dropna().astype(str).tolist())
elif file_ext == "pptx":
ppt = Presentation(io.BytesIO(content))
return " ".join(shape.text for slide in ppt.slides
for shape in slide.shapes if hasattr(shape, "text"))
elif file_ext == "pdf":
pdf = fitz.open(stream=content, filetype="pdf")
text = []
for page in pdf:
page_text = page.get_text("text")
if page_text.strip():
text.append(page_text)
return " ".join(text)
elif file_ext in {"jpg", "jpeg", "png"}:
image = Image.open(io.BytesIO(content))
return pytesseract.image_to_string(image, config='--psm 6')
except Exception as e:
logger.error(f"Text extraction failed: {str(e)}")
raise HTTPException(422, f"Failed to extract text from {file_ext} file")
@app.post("/summarize")
async def summarize_document(file: UploadFile = File(...)):
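    """Extract text from the uploaded document and return a chunked BART summary."""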
try:
file_ext, content = await process_uploaded_file(file)
text = extract_text(content, file_ext)
if not text.strip():
raise HTTPException(400, "No extractable text found")
        # Normalize whitespace, then split into ~1000-character chunks as a
        # rough proxy for the summarizer's input token limit
        text = re.sub(r'\s+', ' ', text).strip()
        chunks = [text[i:i + 1000] for i in range(0, len(text), 1000)]
# Summarize each chunk
summarizer = get_summarizer()
summaries = []
for chunk in chunks:
summary = summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
summaries.append(summary)
return {"summary": " ".join(summaries)}
except HTTPException:
raise
except Exception as e:
logger.error(f"Summarization failed: {str(e)}")
raise HTTPException(500, "Document summarization failed")
@app.post("/qa")
async def question_answering(
file: UploadFile = File(...),
question: str = Form(...),
language: str = Form("fr")
):
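    """Answer a question about the uploaded document; theme questions fall back to summarization."""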
try:
file_ext, content = await process_uploaded_file(file)
text = extract_text(content, file_ext)
if not text.strip():
raise HTTPException(400, "No extractable text found")
# Clean text
text = re.sub(r'\s+', ' ', text).strip()
# Handle theme questions
theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
if any(kw in question.lower() for kw in theme_keywords):
            # Use summarization for theme detection; cap the input at 1000
            # characters (matching the /summarize chunk size) so long
            # documents stay within the model's context window
            summarizer = get_summarizer()
            theme = summarizer(text[:1000], max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
            return {
                "question": question,
                "answer": f"Le document traite principalement de : {theme}",
                "confidence": 0.95,  # heuristic: the summarizer reports no score
"language": language
}
# Standard QA processing
qa = get_qa_model()
result = qa(question=question, context=text)
return {
"question": question,
"answer": result["answer"],
"confidence": result["score"],
"language": language
}
except HTTPException:
raise
except Exception as e:
logger.error(f"QA processing failed: {str(e)}")
raise HTTPException(500, "Document analysis failed")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)