# NOTE: removed 363 lines of web-scrape residue (git short hashes and a
# copied line-number gutter) that preceded the code and made this file
# invalid Python.
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import pipeline
from typing import Optional
import io
import fitz  # PyMuPDF
from PIL import Image
import pandas as pd
import uvicorn
from docx import Document
from pptx import Presentation
import pytesseract
import logging
import re

# Configure logging for the whole service; endpoints log via this module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# CORS Configuration
# NOTE(review): wildcard origins/methods/headers let any web page call this
# API — acceptable for a demo, but tighten allow_origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Constants
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB upload cap, enforced in process_uploaded_file
SUPPORTED_FILE_TYPES = {
    "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png"
}

# Model caching: pipelines start as None and are loaded lazily on first use
# (see the get_* helpers below) so server startup stays fast and memory is
# only spent on models that are actually requested.
summarizer = None
qa_model = None
image_captioner = None

def get_summarizer():
    """Return the shared summarization pipeline, loading it on first call."""
    global summarizer
    if summarizer is not None:
        return summarizer
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    return summarizer

def get_qa_model():
    """Return the shared question-answering pipeline, loading it on first call."""
    global qa_model
    if qa_model is not None:
        return qa_model
    qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
    return qa_model

def get_image_captioner():
    """Return the shared image-captioning pipeline, loading it on first call.

    NOTE(review): no visible endpoint in this file calls this helper —
    confirm it is used elsewhere or remove it.
    """
    global image_captioner
    if image_captioner is not None:
        return image_captioner
    image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
    return image_captioner

async def process_uploaded_file(file: UploadFile):
    """Validate an upload and return ``(extension, raw_bytes)``.

    Raises HTTPException(400) for a missing filename or unsupported
    extension, and HTTPException(413) when the payload exceeds
    MAX_FILE_SIZE.
    """
    if not file.filename:
        raise HTTPException(400, "No file provided")

    # Extension is everything after the last dot, lower-cased.
    ext = file.filename.rsplit('.', 1)[-1].lower()
    if ext not in SUPPORTED_FILE_TYPES:
        raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES)}")

    data = await file.read()
    if len(data) > MAX_FILE_SIZE:
        raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")

    return ext, data

def extract_text(content: bytes, file_ext: str) -> str:
    """Extract plain text from an uploaded document.

    Args:
        content: raw file bytes.
        file_ext: lower-cased extension (already validated by the caller).

    Returns:
        The extracted text; an empty string when no handler matches the
        extension (previously this fell through returning None, and callers
        immediately call .strip() on the result).

    Raises:
        HTTPException(422): when the underlying parser fails.
    """
    try:
        if file_ext == "docx":
            doc = Document(io.BytesIO(content))
            return " ".join(p.text for p in doc.paragraphs if p.text.strip())

        elif file_ext in {"xls", "xlsx"}:
            # NOTE(review): only the first column is read — presumably
            # intentional, but confirm if full-sheet extraction is wanted.
            df = pd.read_excel(io.BytesIO(content))
            return " ".join(df.iloc[:, 0].dropna().astype(str).tolist())

        elif file_ext == "pptx":
            ppt = Presentation(io.BytesIO(content))
            return " ".join(shape.text for slide in ppt.slides
                          for shape in slide.shapes if hasattr(shape, "text"))

        elif file_ext == "pdf":
            # Close the fitz document deterministically — the original
            # leaked it on every request.
            pdf = fitz.open(stream=content, filetype="pdf")
            try:
                text = []
                for page in pdf:
                    page_text = page.get_text("text")
                    if page_text.strip():
                        text.append(page_text)
                return " ".join(text)
            finally:
                pdf.close()

        elif file_ext in {"jpg", "jpeg", "png"}:
            image = Image.open(io.BytesIO(content))
            # --psm 6: treat the image as a single uniform block of text.
            return pytesseract.image_to_string(image, config='--psm 6')

        # Defensive fallback for an unhandled (but validated) extension.
        return ""

    except Exception as e:
        logger.error(f"Text extraction failed: {str(e)}")
        # Chain the original error so the traceback keeps the root cause.
        raise HTTPException(422, f"Failed to extract text from {file_ext} file") from e

def _chunk_text(text: str, size: int) -> list:
    """Split *text* into chunks of at most *size* characters, breaking on a
    space where possible so words are never cut in half."""
    chunks = []
    start = 0
    n = len(text)
    while start < n:
        end = min(start + size, n)
        if end < n:
            # Back up to the last space inside the window, if any.
            space = text.rfind(' ', start, end)
            if space > start:
                end = space
        chunks.append(text[start:end])
        # Skip the separating space so chunks don't start with one.
        start = end + 1 if end < n and text[end] == ' ' else end
    return chunks


@app.post("/summarize")
async def summarize_document(file: UploadFile = File(...)):
    """Summarize an uploaded document.

    Extracts text, normalizes whitespace, splits it into ~1000-character
    chunks on word boundaries (the original sliced mid-word, which degrades
    BART output), summarizes each chunk, and joins the partial summaries.

    Raises:
        HTTPException(400): no extractable text.
        HTTPException(500): any unexpected failure.
    """
    try:
        file_ext, content = await process_uploaded_file(file)
        text = extract_text(content, file_ext)

        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        # Clean and chunk text
        text = re.sub(r'\s+', ' ', text).strip()
        chunks = _chunk_text(text, 1000)

        # Renamed local: the original shadowed the module-level `summarizer`.
        model = get_summarizer()
        summaries = [
            model(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
            for chunk in chunks
        ]

        return {"summary": " ".join(summaries)}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)}")
        raise HTTPException(500, "Document summarization failed")

@app.post("/qa")
async def question_answering(
    file: UploadFile = File(...),
    question: str = Form(...),
    language: str = Form("fr")
):
    """Answer a question about an uploaded document.

    Theme-style questions (French/English keywords) are routed to the
    summarizer for a topic description; everything else goes to the
    extractive QA model.

    Raises:
        HTTPException(400): no extractable text.
        HTTPException(500): any unexpected failure.
    """
    try:
        file_ext, content = await process_uploaded_file(file)
        text = extract_text(content, file_ext)

        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        # Clean text
        text = re.sub(r'\s+', ' ', text).strip()

        # Handle theme questions
        theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
        if any(kw in question.lower() for kw in theme_keywords):
            # Use summarization for theme detection. Truncate the input: the
            # original passed the whole document, but BART cannot accept
            # arbitrarily long inputs (the /summarize endpoint chunks for
            # exactly this reason), and the opening text is usually enough
            # to identify the theme.
            model = get_summarizer()  # renamed: no longer shadows the global
            theme = model(text[:3000], max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
            return {
                "question": question,
                "answer": f"Le document traite principalement de : {theme}",
                "confidence": 0.95,  # NOTE(review): hard-coded, not a real model score
                "language": language
            }

        # Standard QA processing
        qa = get_qa_model()
        result = qa(question=question, context=text)

        return {
            "question": question,
            "answer": result["answer"],
            "confidence": result["score"],
            "language": language
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"QA processing failed: {str(e)}")
        raise HTTPException(500, "Document analysis failed")

# Development entry point: serve on all interfaces.
# NOTE(review): port 7860 suggests a Hugging Face Spaces deployment — confirm.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)