from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import pipeline
from typing import Optional, Tuple
import io
import os
import fitz  # PyMuPDF
from PIL import Image
import pandas as pd
import uvicorn
from docx import Document
from pptx import Presentation
import pytesseract
import logging
import re
import traceback
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
import matplotlib
matplotlib.use("Agg")  # headless backend: render figures without a display server
import matplotlib.pyplot as plt
import seaborn as sns
import tempfile
import base64
from pydantic import BaseModel

# Initialize rate limiter
limiter = Limiter(key_func=get_remote_address)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Apply rate limiting middleware
app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Constants
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
SUPPORTED_FILE_TYPES = {
    "docx", "xlsx", "xls",  # "xls" added: extract_text and the Excel endpoints already handle it
    "pptx", "pdf", "jpg", "jpeg", "png"
}

# Model caching: pipelines are loaded lazily on first use and reused afterwards
summarizer = None
qa_model = None
image_captioner = None


def get_summarizer():
    global summarizer
    if summarizer is None:
        summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    return summarizer


def get_qa_model():
    global qa_model
    if qa_model is None:
        qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
    return qa_model


def get_image_captioner():
    global image_captioner
    if image_captioner is None:
        image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
    return image_captioner
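
# Usage sketch for the lazy getters: the first call loads the model (a slow
# download on a cold start); later calls reuse the cached pipeline. The sample
# text and generation parameters below are illustrative only.
#     s = get_summarizer()
#     s("Some long passage of text ...", max_length=60, min_length=20, do_sample=False)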

async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
    """Validate an uploaded file and return its extension and raw bytes."""
    if not file.filename:
        raise HTTPException(400, "No filename provided")
    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in SUPPORTED_FILE_TYPES:
        raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES)}")
    content = await file.read()
    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE // 1024 // 1024}MB")

    # Special validation for PDFs
    if file_ext == "pdf":
        try:
            with fitz.open(stream=content, filetype="pdf") as doc:
                if doc.is_encrypted:
                    if not doc.authenticate(""):
                        raise ValueError("Encrypted PDF - cannot extract text")
                if len(doc) > 50:
                    raise ValueError("PDF too large (max 50 pages)")
        except Exception as e:
            logger.error(f"PDF validation failed: {str(e)}")
            raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}")

    await file.seek(0)  # Reset file pointer in case the caller reads the file again
    return file_ext, content

def extract_text(content: bytes, file_ext: str) -> str:
    """Extract text from the supported file formats."""
    try:
        if file_ext == "docx":
            doc = Document(io.BytesIO(content))
            return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
        elif file_ext in {"xlsx", "xls"}:
            df = pd.read_excel(io.BytesIO(content), sheet_name=None)
            all_text = []
            for sheet_name, sheet_data in df.items():
                sheet_text = []
                for column in sheet_data.columns:
                    sheet_text.extend(sheet_data[column].dropna().astype(str).tolist())
                all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
            return "\n\n".join(all_text)
        elif file_ext == "pptx":
            ppt = Presentation(io.BytesIO(content))
            text = []
            for slide in ppt.slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text") and shape.text.strip():
                        text.append(shape.text)
            return "\n".join(text)
        elif file_ext == "pdf":
            with fitz.open(stream=content, filetype="pdf") as pdf:
                return "\n".join(page.get_text("text") for page in pdf)
        elif file_ext in {"jpg", "jpeg", "png"}:
            try:
                # First try OCR
                image = Image.open(io.BytesIO(content))
                text = pytesseract.image_to_string(image, config='--psm 6')
                if text.strip():
                    return text
                # If OCR finds no text, fall back to image captioning
                captioner = get_image_captioner()
                result = captioner(image)
                return result[0]['generated_text']
            except Exception as img_e:
                logger.error(f"Image processing failed: {str(img_e)}")
                raise ValueError("Could not extract text or caption from image")
        # Defensive: extensions are validated upstream, so this should not be reached
        raise ValueError(f"No extractor implemented for {file_ext}")
    except Exception as e:
        logger.error(f"Text extraction failed for {file_ext}: {str(e)}")
        raise HTTPException(422, f"Failed to extract text from {file_ext} file")

# NOTE: route paths on the endpoints below are assumptions; the original
# @app.post decorators appear to have been lost in extraction.
@app.post("/summarize")
async def summarize_document(request: Request, file: UploadFile = File(...)):
    try:
        file_ext, content = await process_uploaded_file(file)
        text = extract_text(content, file_ext)
        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        # Clean and chunk text
        text = re.sub(r'\s+', ' ', text).strip()
        chunks = [text[i:i + 1000] for i in range(0, len(text), 1000)]

        # Summarize each chunk
        summarizer = get_summarizer()
        summaries = []
        for chunk in chunks:
            summary = summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
            summaries.append(summary)

        return {"summary": " ".join(summaries)}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)}")
        raise HTTPException(500, "Document summarization failed")

@app.post("/ask")  # route path assumed; original decorator lost in extraction
async def question_answering(
    request: Request,
    file: UploadFile = File(...),
    question: str = Form(...),
    language: str = Form("fr")
):
    try:
        file_ext, content = await process_uploaded_file(file)
        text = extract_text(content, file_ext)
        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        # Clean and truncate text
        text = re.sub(r'\s+', ' ', text).strip()[:5000]

        # Theme detection: answer "what is this about?"-style questions with a summary
        theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
        if any(kw in question.lower() for kw in theme_keywords):
            try:
                summarizer = get_summarizer()
                summary_output = summarizer(
                    text,
                    max_length=min(100, len(text) // 4),
                    min_length=30,
                    do_sample=False,
                    truncation=True
                )
                theme = summary_output[0].get("summary_text", text[:200] + "...")
                return {
                    "question": question,
                    "answer": f"Le document traite principalement de : {theme}",
                    "confidence": 0.95,
                    "language": language
                }
            except Exception:
                theme = text[:200] + ("..." if len(text) > 200 else "")
                return {
                    "question": question,
                    "answer": f"D'après le document : {theme}",
                    "confidence": 0.7,
                    "language": language,
                    "warning": "theme_summary_fallback"
                }

        # Standard QA
        qa = get_qa_model()
        result = qa(question=question, context=text[:3000])
        return {
            "question": question,
            "answer": result["answer"],
            "confidence": result["score"],
            "language": language
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"QA processing failed: {str(e)}")
        raise HTTPException(500, detail=f"Analysis failed: {str(e)}")

async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
    return JSONResponse(
        status_code=429,
        content={"detail": "Too many requests. Please try again later."}
    )

# Register the handler so RateLimitExceeded errors actually return a 429 response
app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler)

# Pydantic model for visualization requests
class VisualizationRequest(BaseModel):
    chart_type: str
    x_column: Optional[str] = None
    y_column: Optional[str] = None
    hue_column: Optional[str] = None
    title: Optional[str] = None
    x_label: Optional[str] = None
    y_label: Optional[str] = None
    style: str = "seaborn"  # "seaborn" or "matplotlib"

# Build visualization code from a VisualizationRequest
def generate_visualization(df: pd.DataFrame, request: VisualizationRequest) -> str:
    """Generate matplotlib/seaborn code for the requested chart; the caller
    decides whether to execute the returned source or just display it."""
    # Newer matplotlib releases renamed the "seaborn" style to "seaborn-v0_8",
    # and "matplotlib" is not a style name; map both so plt.style.use() succeeds.
    style = request.style
    if style == "seaborn" and "seaborn" not in plt.style.available:
        style = "seaborn-v0_8"
    elif style == "matplotlib":
        style = "default"
    plt.style.use(style)

    code_lines = [
        "import matplotlib.pyplot as plt",
        "import seaborn as sns",
        "import pandas as pd",
        "",
        "# Data preparation",
        f"df = pd.DataFrame({df.head().to_dict()})",  # Simplified for demo: only the first rows are embedded
        "",
        "# Visualization code",
        "plt.figure(figsize=(10, 6))",
    ]
    if request.chart_type == "line":
        if request.hue_column:
            code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
        else:
            code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
    elif request.chart_type == "bar":
        if request.hue_column:
            code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
        else:
            code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
    elif request.chart_type == "scatter":
        if request.hue_column:
            code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
        else:
            code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
    elif request.chart_type == "histogram":
        code_lines.append(f"plt.hist(df['{request.x_column}'], bins=20)")
    else:
        raise ValueError("Unsupported chart type")

    # Add labels and title
    if request.title:
        code_lines.append(f"plt.title('{request.title}')")
    if request.x_label:
        code_lines.append(f"plt.xlabel('{request.x_label}')")
    if request.y_label:
        code_lines.append(f"plt.ylabel('{request.y_label}')")
    code_lines.append("plt.tight_layout()")
    code_lines.append("plt.show()")
    return "\n".join(code_lines)

# Endpoint for generating a chart from an uploaded Excel file
@app.post("/visualize")  # route path assumed; original decorator lost in extraction
async def generate_visualization_from_excel(
    request: Request,
    file: UploadFile = File(...),
    chart_type: str = Form(...),
    x_column: Optional[str] = Form(None),
    y_column: Optional[str] = Form(None),
    hue_column: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    x_label: Optional[str] = Form(None),
    y_label: Optional[str] = Form(None),
    style: str = Form("seaborn")
):
    try:
        # Validate file
        file_ext, content = await process_uploaded_file(file)
        if file_ext not in {"xlsx", "xls"}:
            raise HTTPException(400, "Only Excel files are supported for visualization")

        # Read Excel file
        df = pd.read_excel(io.BytesIO(content))

        # Build the visualization request
        vis_request = VisualizationRequest(
            chart_type=chart_type,
            x_column=x_column,
            y_column=y_column,
            hue_column=hue_column,
            title=title,
            x_label=x_label,
            y_label=y_label,
            style=style
        )

        # Generate the code once, then execute it in an isolated namespace to
        # render the figure (pyplot state is module-global, so the figure drawn
        # inside exec is the current figure afterwards)
        code = generate_visualization(df, vis_request)
        exec(code, {})

        # Save the plot to a temporary file
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            plt.savefig(tmpfile.name, format='png', dpi=300)
            plt.close('all')

        # Read the image back as bytes, then remove the temporary file
        try:
            with open(tmpfile.name, "rb") as f:
                image_bytes = f.read()
        finally:
            os.unlink(tmpfile.name)

        # Encode image as base64
        image_base64 = base64.b64encode(image_bytes).decode('utf-8')
        return {
            "status": "success",
            "image": f"data:image/png;base64,{image_base64}",
            "code": code
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Visualization failed: {str(e)}\n{traceback.format_exc()}")
        raise HTTPException(500, detail=f"Visualization failed: {str(e)}")

# Endpoint for listing the columns of an uploaded Excel file
@app.post("/columns")  # route path assumed; original decorator lost in extraction
async def get_excel_columns(
    request: Request,
    file: UploadFile = File(...)
):
    try:
        file_ext, content = await process_uploaded_file(file)
        if file_ext not in {"xlsx", "xls"}:
            raise HTTPException(400, "Only Excel files are supported")
        df = pd.read_excel(io.BytesIO(content))
        return {
            "columns": list(df.columns),
            "sample_data": df.head().to_dict(orient='records')
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Column extraction failed: {str(e)}")
        raise HTTPException(500, detail="Failed to extract columns from Excel file")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
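
# Example client call (a sketch: the /summarize route path is the assumption
# made above, and the third-party "requests" library is used for illustration):
#     import requests
#     with open("report.pdf", "rb") as f:
#         resp = requests.post(
#             "http://localhost:7860/summarize",
#             files={"file": ("report.pdf", f, "application/pdf")},
#         )
#     print(resp.json()["summary"])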