"""FastAPI service: document summarization, question answering and charts.

This first section holds the imports, application/middleware setup, shared
constants and the two small data helpers used by the visualization endpoint.
"""

# Standard library imports
import ast
import asyncio
import base64
import io
import logging
import os
import re
import tempfile
import traceback
import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from pathlib import Path
from typing import Optional, Tuple

# Third-party imports
import fitz  # PyMuPDF
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytesseract
import seaborn as sns
import torch
import uvicorn
from docx import Document
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from openpyxl import Workbook
from PIL import Image
from pptx import Presentation
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
from transformers import pipeline

# Suppress openpyxl warnings
warnings.filterwarnings("ignore", category=UserWarning, module="openpyxl")

# Initialize rate limiter (keyed on the client's remote address)
limiter = Limiter(key_func=get_remote_address)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

UPLOAD_FOLDER = "uploads"
OUTPUT_FOLDER = "static"
# Create the folders BEFORE mounting StaticFiles: StaticFiles validates that
# its directory exists when it is constructed, so mounting first would crash
# on a fresh checkout.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

app = FastAPI()

# Serve static files (frontend)
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/", response_class=HTMLResponse)
def home():
    """Return the contents of the front-end page ``static/indexAI.html``."""
    # Explicit encoding: the platform default is locale-dependent and can
    # garble non-ASCII characters in the page.
    with open("static/indexAI.html", "r", encoding="utf-8") as file:
        return file.read()


# Apply rate limiting middleware. The exception handler is what turns a
# RateLimitExceeded raised by a ``@limiter.limit`` route into an HTTP 429.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)

# CORS Configuration
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins to the real front-end origin(s) before
# deploying publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Lightweight model configuration
MODEL_NAME = "distilgpt2"
MAX_FILE_SIZE = 2 * 1024 * 1024  # 2MB
TIMEOUT = 10  # seconds
MAX_ROWS = 100
MAX_COLUMNS = 5

try:
    visualization_model = pipeline(
        "text-generation",
        model=MODEL_NAME,
        device=-1,  # CPU
        framework="pt",
    )
except Exception as e:
    # Best effort: the visualization endpoint falls back to a plain pandas
    # plot when this is None.
    logger.error("Model loading failed: %s", e)
    visualization_model = None

executor = ThreadPoolExecutor(max_workers=2)


def safe_read_file(file_content: bytes, file_ext: str) -> pd.DataFrame:
    """Parse uploaded bytes into a DataFrame, reading at most MAX_ROWS rows.

    Args:
        file_content: Raw bytes of the uploaded file.
        file_ext: Lower-case extension; ``'csv'`` is read as CSV, anything
            else is handed to ``pandas.read_excel``.

    Returns:
        The parsed DataFrame (first MAX_ROWS rows only).
    """
    file_like = io.BytesIO(file_content)
    if file_ext == 'csv':
        return pd.read_csv(file_like, nrows=MAX_ROWS)
    return pd.read_excel(file_like, nrows=MAX_ROWS)


def generate_simple_plot(df: pd.DataFrame, chart_type: Optional[str]) -> None:
    """Draw a fallback chart of ``df`` on a new current matplotlib figure.

    With two or more numeric columns the first two are plotted using
    ``chart_type`` (``'bar'``, ``'line'`` or ``'scatter'``; anything else
    becomes ``'bar'``). With one numeric column it is drawn as a bar chart.
    With none, the value counts of the first column are drawn as a bar chart.

    Args:
        df: Data to plot; must have at least one column.
        chart_type: Requested chart kind.
    """
    # Draw on one explicit Axes. Without ``ax=``, DataFrame.plot opens its
    # own figure and the one created here would be leaked, empty.
    fig, ax = plt.subplots(figsize=(8, 5))
    numeric_cols = df.select_dtypes(include='number').columns
    kind = chart_type if chart_type in ('bar', 'line', 'scatter') else 'bar'

    if len(numeric_cols) >= 2:
        first, second = numeric_cols[:2]
        if kind == 'scatter':
            # pandas requires explicit x and y columns for scatter plots.
            df.plot(kind='scatter', x=first, y=second, ax=ax)
        else:
            df[[first, second]].plot(kind=kind, ax=ax)
    elif len(numeric_cols) == 1:
        df[numeric_cols[0]].plot(kind='bar', ax=ax)
    else:
        df.iloc[:, 0].value_counts().plot(kind='bar', ax=ax)
    fig.tight_layout()
# Imports needed by the routes below (kept local to this section so it does
# not depend on the ordering of the import block at the top of the file).
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional

SUPPORTED_FILE_TYPES = {
    "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png", "txt"
}

# Model caching: pipelines are created lazily on first use.
summarizer = None
qa_model = None
image_captioner = None


def get_summarizer():
    """Return the shared summarization pipeline, creating it on first call."""
    global summarizer
    if summarizer is None:
        summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    return summarizer


MODEL_CHOICES = [
    "cmarkea/flan-t5-base-fr",  # Best for French
    "moussaKam/barthez-orangesum-abstract",  # French summarization
    "google/flan-t5-xl",  # Higher quality fallback
]

qa_pipeline = None
current_model = None


def initialize_qa() -> bool:
    """Load the first model from MODEL_CHOICES that can be instantiated.

    Sets the module globals ``qa_pipeline`` and ``current_model``.

    Returns:
        True if a model was loaded, False if every candidate failed.
    """
    global qa_pipeline, current_model
    use_gpu = torch.cuda.is_available()
    # Try each model in order
    for model_name in MODEL_CHOICES:
        try:
            logger.info(f"Attempting to load {model_name}")
            qa_pipeline = pipeline(
                "text2text-generation",
                model=model_name,
                device=0 if use_gpu else -1,
                torch_dtype=torch.float16 if use_gpu else torch.float32,
            )
        except Exception as e:
            logger.warning(f"Failed to load {model_name}: {str(e)}")
            continue
        current_model = model_name
        logger.info(f"Successfully loaded {model_name}")
        return True
    logger.error("All model loading attempts failed")
    return False


@app.on_event("startup")
async def startup_event():
    """Load the QA model at application start; log if none could be loaded."""
    if not initialize_qa():
        logger.error("QA system failed to initialize")


def get_image_captioner():
    """Return the shared image-captioning pipeline, creating it on first call."""
    global image_captioner
    if image_captioner is None:
        image_captioner = pipeline(
            "image-to-text", model="Salesforce/blip-image-captioning-large"
        )
    return image_captioner


async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
    """Validate an upload and return its lower-cased extension and bytes.

    Args:
        file: The uploaded file.

    Returns:
        ``(file_ext, content)``.

    Raises:
        HTTPException: 400 for a missing filename or unsupported extension,
            413 when the payload exceeds MAX_FILE_SIZE, 422 for a PDF that
            cannot be opened, is encrypted, or has more than 50 pages.
    """
    if not file.filename:
        raise HTTPException(400, "No filename provided")

    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in SUPPORTED_FILE_TYPES:
        # sorted() keeps the message stable; set iteration order is not.
        raise HTTPException(
            400,
            f"Unsupported file type. Supported: {', '.join(sorted(SUPPORTED_FILE_TYPES))}",
        )

    content = await file.read()
    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(
            413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB"
        )

    if file_ext == "pdf":
        try:
            with fitz.open(stream=content, filetype="pdf") as doc:
                # An empty password unlocks PDFs that are only owner-protected.
                if doc.is_encrypted and not doc.authenticate(""):
                    raise ValueError("Encrypted PDF - cannot extract text")
                if len(doc) > 50:
                    raise ValueError("PDF too large (max 50 pages)")
        except Exception as e:
            logger.error(f"PDF validation failed: {str(e)}")
            raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}") from e

    await file.seek(0)
    return file_ext, content


def extract_text(content: bytes, file_ext: str) -> str:
    """Extract plain text from the bytes of a supported document type.

    Args:
        content: Raw file bytes.
        file_ext: Lower-case extension selecting the parser.

    Returns:
        The extracted text. For images this is the OCR output, or a generated
        caption when OCR finds no text.

    Raises:
        HTTPException: 422 when extraction fails or the extension is not
            handled here.
    """
    try:
        if file_ext == "txt":
            return content.decode("utf-8", errors="replace").strip()

        if file_ext == "docx":
            doc = Document(io.BytesIO(content))
            return "\n".join(
                para.text for para in doc.paragraphs if para.text.strip()
            )

        if file_ext in {"xlsx", "xls"}:
            sheets = pd.read_excel(
                io.BytesIO(content),
                sheet_name=None,
                engine='openpyxl',
                na_values=['', 'NA', 'N/A', 'NaN', 'null'],
                keep_default_na=False,
                parse_dates=True,
            )
            all_text = []
            for sheet_name, sheet_data in sheets.items():
                sheet_text = []
                for column in sheet_data.columns:
                    if pd.api.types.is_datetime64_any_dtype(sheet_data[column]):
                        sheet_data[column] = sheet_data[column].dt.strftime(
                            '%Y-%m-%d %H:%M:%S'
                        )
                    # Blank out the string forms of missing values.
                    col_text = (
                        sheet_data[column]
                        .astype(str)
                        .replace(['nan', 'None', 'NaT'], '')
                        .tolist()
                    )
                    sheet_text.extend(x for x in col_text if x.strip())
                all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
            return "\n\n".join(all_text)

        if file_ext == "pptx":
            ppt = Presentation(io.BytesIO(content))
            return "\n".join(
                shape.text
                for slide in ppt.slides
                for shape in slide.shapes
                if hasattr(shape, "text") and shape.text.strip()
            )

        if file_ext == "pdf":
            # Context manager closes the document (the original leaked it).
            with fitz.open(stream=content, filetype="pdf") as pdf:
                return "\n".join(page.get_text("text") for page in pdf)

        if file_ext in {"jpg", "jpeg", "png"}:
            try:
                image = Image.open(io.BytesIO(content))
                text = pytesseract.image_to_string(image, config='--psm 6')
                if text.strip():
                    return text
                # No OCR text: fall back to a generated caption.
                result = get_image_captioner()(image)
                return result[0]['generated_text']
            except Exception as img_e:
                logger.error(f"Image processing failed: {str(img_e)}")
                raise ValueError("Could not extract text or caption from image")

        # Previously fell through and returned None, which crashed callers
        # that immediately call str methods on the result.
        raise ValueError(f"Unsupported file type: {file_ext}")
    except Exception as e:
        logger.error(f"Text extraction failed for {file_ext}: {str(e)}", exc_info=True)
        raise HTTPException(
            422, f"Failed to extract text from {file_ext} file: {str(e)}"
        )


# Global thread pool for CPU-bound tasks
executor = ThreadPoolExecutor(max_workers=4)


def _chunk_sentences(text: str, max_chars: int = 800) -> List[str]:
    """Group sentences of ``text`` into chunks of at most ``max_chars`` chars.

    A single sentence longer than ``max_chars`` becomes its own chunk. No
    empty chunk is ever produced.
    """
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
    chunks: List[str] = []
    current: List[str] = []
    current_length = 0
    for sentence in sentences:
        # Flush only a non-empty chunk; the original appended '' when the
        # very first sentence was already over the limit.
        if current and current_length + len(sentence) > max_chars:
            chunks.append(' '.join(current))
            current, current_length = [], 0
        current.append(sentence)
        current_length += len(sentence)
    if current:
        chunks.append(' '.join(current))
    return chunks


def _summarize_chunks(chunks: List[str]) -> str:
    """Summarize each chunk in a thread pool and merge the results."""
    summarize_pipeline = get_summarizer()

    def summarize_chunk(chunk):
        return summarize_pipeline(
            chunk,
            max_length=120,  # Smaller output for faster processing
            min_length=40,
            do_sample=False,
            truncation=True,
        )[0]["summary_text"]

    with ThreadPoolExecutor(max_workers=min(4, len(chunks))) as pool:
        summaries = list(pool.map(summarize_chunk, chunks))

    combined = ' '.join(summaries)
    # Single refinement pass when the merged summary is still long.
    if len(combined.split()) > 200:
        combined = summarize_chunk(combined[:3000])  # Limit input size
    return combined


@app.post("/summarize")
@limiter.limit("5/minute")
async def summarize_document(request: Request, file: UploadFile = File(...)):
    """Summarize an uploaded document.

    Args:
        request: Incoming request (required by the rate limiter).
        file: Document to summarize.

    Returns:
        ``{"summary": <text>}``.

    Raises:
        HTTPException: validation/extraction errors from the helpers, 400
            when no text could be extracted, 500 on any other failure.
    """
    try:
        file_ext, content = await process_uploaded_file(file)

        # Run blocking work off the event loop.
        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(executor, extract_text, content, file_ext)
        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        text = re.sub(r'\s+', ' ', text).strip()
        chunks = _chunk_sentences(text)
        combined = await loop.run_in_executor(executor, _summarize_chunks, chunks)
        return {"summary": combined}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)}", exc_info=True)
        raise HTTPException(500, "Document summarization failed")


@app.post("/qa")
async def question_answering(
    question: str = Form(...),
    file: Optional[UploadFile] = File(None)
):
    """Answer a question in French, optionally using an uploaded document.

    Args:
        question: The user's question.
        file: Optional document providing context.

    Returns:
        A dict with the question, the answer, and whether context was used
        (plus the model name when a model produced the answer).

    Raises:
        HTTPException: 503 when no QA model is loaded, upload validation
            errors from ``process_uploaded_file``/``extract_text``, 500 on
            any other failure.
    """
    if qa_pipeline is None:
        raise HTTPException(503, detail="Service temporairement indisponible")

    try:
        context = None
        if file:
            # Use the validated, lower-cased extension. Re-deriving it from
            # the filename without lower() broke uploads such as "DOC.PDF".
            file_ext, content = await process_uploaded_file(file)
            raw_text = extract_text(content, file_ext)
            context = clean_and_translate_to_french(raw_text)[:1200]

        # Theme questions get a dedicated pipeline.
        if "thème" in question.lower() or "theme" in question.lower():
            if not context:
                return {
                    "question": question,
                    "answer": "Aucun document fourni pour identifier le thème",
                    "context_used": False
                }
            response = generate_theme_answer(context)
            return {
                "question": question,
                "answer": response,
                "model": current_model,
                "context_used": True
            }

        # Standard QA with language enforcement
        input_text = (
            f"En tant qu'expert, réponds en français à la question suivante "
            f"en utilisant exclusivement le contexte fourni:\n"
            f"Question: {question}\n"
            f"Contexte: {context[:1000] if context else 'Aucun contexte disponible'}\n"
            f"Réponse concise:"
        )
        result = qa_pipeline(
            input_text,
            max_length=80,
            num_beams=2,
            temperature=0.2,
            repetition_penalty=3.0
        )
        final_answer = validate_french_response(result[0]["generated_text"])
        return {
            "question": question,
            "answer": final_answer,
            "model": current_model,
            "context_used": context is not None
        }
    except HTTPException:
        # Keep 400/413/422 from validation instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"Erreur: {str(e)}")
        raise HTTPException(500, "Erreur lors du traitement")


def clean_and_translate_to_french(text: str) -> str:
    """Strip page-number lines and swap a few English terms for French.

    Returns at most the first 2000 characters of the cleaned text.
    """
    # Remove lines that contain only a number (typical page headers/footers).
    text = re.sub(r'^\s*\d+\s*$', '', text, flags=re.MULTILINE)
    # Convert common English terms to French
    replacements = {
        "welcome": "bienvenue",
        "introduction": "introduction",
        "chapter": "chapitre",
        "section": "section"
    }
    for eng, fr in replacements.items():
        text = text.replace(eng, fr)
    return text[:2000]


def generate_theme_answer(context: str) -> str:
    """Produce a one-sentence French theme description for ``context``.

    Asks the QA model for keywords, then for a one-sentence summary built
    from those keywords, and returns the first sentence of that summary.
    """
    # Step 1: Identify key topics
    topics_prompt = (
        "Liste 3-5 mots-clés en français représentant les sujets principaux "
        f"de ce texte:\n{context[:1000]}"
    )
    topics = qa_pipeline(topics_prompt, max_length=50)[0]["generated_text"]

    # Step 2: Generate French summary
    summary_prompt = (
        "Résume en une phrase en français pour un étudiant:\n"
        f"Mots-clés: {topics}\nTexte: {context[:800]}"
    )
    summary = qa_pipeline(summary_prompt, max_length=60)[0]["generated_text"]

    # Step 3: Keep the text after the last ':' up to the first '.'.
    return summary.split(":")[-1].split(".")[0].strip().capitalize() + "."


def validate_french_response(text: str) -> str:
    """Normalize a model answer into a single, properly terminated sentence.

    Collapses whitespace, guarantees terminal punctuation and upper-cases the
    first character. An empty input is returned as an empty string.

    The previous implementation deleted every run of three or more ASCII
    letters to "remove English", which also deleted almost every French word;
    that step has been dropped.
    """
    text = re.sub(r'\s+', ' ', text).strip()
    if not text:
        return text
    # Without terminal punctuation, keep only the first sentence.
    if not text.endswith(('.', '!', '?')):
        text = text.split('.')[0].strip() + '.'
    # Upper-case only the first character; str.capitalize() would lower-case
    # the rest of the sentence, including proper nouns.
    return text[0].upper() + text[1:]


def _fallback_plot_code(df, chart_type) -> str:
    """Return a standalone script reproducing the fallback plot for ``df``."""
    lines = ["import pandas as pd", "import matplotlib.pyplot as plt"]
    numeric_cols = df.select_dtypes(include='number').columns.tolist()
    if len(numeric_cols) >= 2:
        x_col, y_col = numeric_cols[:2]
        kind = chart_type if chart_type in ('bar', 'line', 'scatter') else 'bar'
        lines.append(f"data = {df[[x_col, y_col]].to_dict()}")
        lines.append("df = pd.DataFrame(data)")
        if kind == 'scatter':
            # pandas requires explicit x/y for scatter plots.
            lines.append(f"df.plot(kind='scatter', x={x_col!r}, y={y_col!r})")
        else:
            lines.append(f"df.plot(kind='{kind}')")
    elif len(numeric_cols) == 1:
        # A flat {index: value} dict must become a Series; DataFrame()
        # rejects a dict of scalars.
        lines.append(f"data = {df[numeric_cols[0]].to_dict()}")
        lines.append("series = pd.Series(data)")
        lines.append("series.plot(kind='bar')")
    else:
        lines.append(f"data = {df.iloc[:, 0].value_counts().to_dict()}")
        lines.append(
            "df = pd.DataFrame(list(data.items()), columns=['Category', 'Count'])"
        )
        lines.append("df.plot(x='Category', y='Count', kind='bar')")
    lines.extend(["plt.tight_layout()", "plt.show()"])
    return "\n".join(lines) + "\n"


@app.post("/generate-visualization")
async def generate_visualization(
    file: UploadFile = File(...),
    request: str = Form(...),
    chart_type: Optional[str] = Form("auto")
):
    """Render a chart from an uploaded CSV/Excel file.

    Args:
        file: CSV or Excel upload.
        request: Free-text request from the client (not used when building
            the chart).
        chart_type: Desired chart kind.

    Returns:
        JSONResponse with the image URL, the Python code describing the
        plot, the columns used and a status note.

    Raises:
        HTTPException: 400 for bad input (missing name, wrong type, too
            large, no data), 500 for any other processing error.
    """
    try:
        # 1. Validate input
        if not file.filename:
            raise HTTPException(400, "No filename provided")
        file_ext = file.filename.split('.')[-1].lower()
        if file_ext not in ['csv', 'xlsx', 'xls']:
            raise HTTPException(400, "Only CSV/Excel files accepted")

        file_content = await file.read()
        if len(file_content) > MAX_FILE_SIZE:
            raise HTTPException(400, f"File size exceeds {MAX_FILE_SIZE//1024}KB limit")

        # 2. Process data off the event loop
        df = await asyncio.get_running_loop().run_in_executor(
            executor, safe_read_file, file_content, file_ext
        )
        df = df.iloc[:, :MAX_COLUMNS].dropna(how='all')
        if df.empty:
            raise HTTPException(400, "No plottable data found")

        # 3. Generate visualization
        plt.switch_backend('Agg')
        generated_code = None

        if visualization_model:
            try:
                prompt = f"Create {chart_type} chart for {list(df.columns)}. Python code only:"
                code = visualization_model(
                    prompt,
                    max_length=300,
                    num_return_sequences=1,
                    temperature=0.3
                )[0]['generated_text'].split("```python")[-1].split("```")[0].strip()
                if code:
                    # SECURITY(review): this executes model-generated code
                    # whose prompt contains column names from the uploaded
                    # file. It runs with full process privileges; sandbox or
                    # remove before exposing this endpoint to untrusted users.
                    exec(code, {'df': df, 'plt': plt})
                    generated_code = code
            except Exception as e:
                logger.warning(f"Model failed, using fallback: {e}")
                # Drop any half-drawn figure left by the failed code.
                plt.close('all')

        # Also covers "model returned no code", which used to save a blank image.
        if generated_code is None:
            generate_simple_plot(df, chart_type)
            generated_code = _fallback_plot_code(df, chart_type)

        # 4. Save output
        output_id = uuid.uuid4().hex[:8]
        image_path = f"{OUTPUT_FOLDER}/plot_{output_id}.png"
        plt.savefig(image_path, bbox_inches='tight', dpi=80)
        plt.close('all')

        return JSONResponse({
            "image_url": f"/static/plot_{output_id}.png",
            "python_code": generated_code,
            "columns": list(df.columns),
            "note": "Visualization generated successfully"
        })
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(500, f"Processing error: {str(e)}")


@app.get("/static/{filename}")
async def serve_static(filename: str):
    """Return a file from OUTPUT_FOLDER, or 404 if it does not exist."""
    # Local import: FileResponse was used here without ever being imported.
    from fastapi.responses import FileResponse

    # Accept plain file names only (rejects "..", nested paths, directories).
    safe_name = os.path.basename(filename)
    file_path = f"{OUTPUT_FOLDER}/{safe_name}"
    if safe_name != filename or not os.path.isfile(file_path):
        raise HTTPException(404, "Image not found")
    return FileResponse(file_path)


if __name__ == "__main__":
    # Run the FastAPI application
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        timeout_keep_alive=15
    )

# ===== TESTING CODE (OPTIONAL) =====
# This should be in a separate test file, not in main.py. It is kept as an
# inert string; note it posts to "/visualize/natural", which is not a route
# defined in this module.
"""
def test_visualization():
    from fastapi.testclient import TestClient
    from io import BytesIO
    import base64
    from PIL import Image
    import matplotlib.pyplot as plt

    client = TestClient(app)
    test_file = "test.xlsx"
    test_prompt = "Show me a bar chart of sales by region"

    with open(test_file, "rb") as f:
        response = client.post(
            "/visualize/natural",
            files={"file": ("test.xlsx", f, "application/vnd.ms-excel")},
            data={"prompt": test_prompt}
        )

    if response.status_code == 200:
        result = response.json()
        print("Visualization generated successfully!")
        image_data = result["image"].split(",")[1]
        image_bytes = base64.b64decode(image_data)
        image = Image.open(BytesIO(image_bytes))
        plt.imshow(image)
        plt.axis("off")
        plt.show()
    else:
        print(f"Error: {response.status_code}\\n{response.text}")
"""