Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from transformers import pipeline | |
| from typing import Tuple, Optional | |
| import io | |
| import fitz # PyMuPDF | |
| from PIL import Image | |
| import pandas as pd | |
| import uvicorn | |
| from docx import Document | |
| from pptx import Presentation | |
| import pytesseract | |
| import logging | |
| import re | |
| from slowapi import Limiter | |
| from slowapi.util import get_remote_address | |
| from slowapi.errors import RateLimitExceeded | |
| from slowapi.middleware import SlowAPIMiddleware | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import tempfile | |
| import base64 | |
| from io import BytesIO | |
| from pydantic import BaseModel | |
| import traceback | |
| import ast | |
| from fastapi.responses import HTMLResponse | |
| from fastapi import Request | |
| from pathlib import Path | |
| from fastapi.staticfiles import StaticFiles | |
| import numpy as np # Add this import | |
| import pandas as pd | |
| from io import BytesIO | |
| import os | |
| import torch | |
| # Standard library imports | |
| import io | |
| import re | |
| import logging | |
| import tempfile | |
| import base64 | |
| import warnings | |
| from typing import Tuple, Optional | |
| from pathlib import Path | |
| from docx import Document | |
| from pptx import Presentation | |
| import re | |
| from concurrent.futures import ThreadPoolExecutor | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, HTMLResponse | |
| from transformers import pipeline | |
| import fitz # PyMuPDF | |
| from PIL import Image | |
| import pandas as pd | |
| import uvicorn | |
| from docx import Document | |
| from pptx import Presentation | |
| import pytesseract | |
| from slowapi import Limiter | |
| from slowapi.util import get_remote_address | |
| from slowapi.errors import RateLimitExceeded | |
| from slowapi.middleware import SlowAPIMiddleware | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from pydantic import BaseModel | |
| import traceback | |
| import ast | |
| from openpyxl import Workbook | |
| import uuid | |
# Suppress UserWarning messages originating from the openpyxl module.
warnings.filterwarnings("ignore", category=UserWarning, module="openpyxl")

# Rate limiter keyed on the client's remote address.
limiter = Limiter(key_func=get_remote_address)

# Module-wide logging configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Serve static files (frontend). StaticFiles checks that the directory exists
# when it is constructed, so create it first; otherwise a fresh checkout fails
# at import time, before any later makedirs call has a chance to run.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
def home():
    """Return the contents of the front-end page ``static/indexAI.html``.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default encoding.

    NOTE(review): no route decorator is visible on this function here;
    confirm it is registered with the application.
    """
    with open("static/indexAI.html", "r", encoding="utf-8") as file:
        return file.read()
# Attach the limiter to application state, then install the slowapi middleware.
app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)

# CORS: any origin, method and header is accepted, with credentials enabled.
# NOTE(review): wildcard origins combined with credentials is very
# permissive; confirm this is intended before exposing the service publicly.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Working directories, created at import time if missing.
UPLOAD_FOLDER = "uploads"
OUTPUT_FOLDER = "static"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Lightweight model configuration
MODEL_NAME = "distilgpt2"
MAX_FILE_SIZE = 2 * 1024 * 1024  # 2MB
TIMEOUT = 10  # seconds
MAX_ROWS = 100
MAX_COLUMNS = 5

# Load the text-generation model on CPU. A failure is not fatal: the name is
# bound to None and callers are expected to check for that.
try:
    visualization_model = pipeline(
        "text-generation",
        model=MODEL_NAME,
        device=-1,  # CPU
        framework="pt",
    )
except Exception as e:
    # Report through the module logger (with traceback) instead of print so
    # the failure appears alongside the rest of the application's log output.
    logger.error(f"Model loading failed: {str(e)}", exc_info=True)
    visualization_model = None

# Thread pool used with run_in_executor.
# NOTE: this name is rebound further down in this module.
executor = ThreadPoolExecutor(max_workers=2)
def safe_read_file(file_content, file_ext):
    """Parse uploaded bytes into a DataFrame, reading at most MAX_ROWS rows.

    ``file_ext == 'csv'`` selects the CSV reader; every other value is handed
    to the Excel reader.
    """
    reader = pd.read_csv if file_ext == 'csv' else pd.read_excel
    return reader(io.BytesIO(file_content), nrows=MAX_ROWS)
def generate_simple_plot(df, chart_type):
    """Draw a basic fallback chart of *df* on a new 8x5 figure.

    Column selection:
      * two or more numeric columns: the first two are plotted using
        *chart_type* when it is 'bar', 'line' or 'scatter', else 'bar';
      * exactly one numeric column: bar chart of that column;
      * no numeric column: bar chart of value counts of the first column.

    The figure is left open as the current pyplot figure so the caller can
    save and close it.
    """
    # Draw on an explicit Axes. Without ``ax=`` pandas opens its own figure,
    # leaving the one created here empty and never closed.
    _, ax = plt.subplots(figsize=(8, 5))
    numeric_cols = df.select_dtypes(include='number').columns
    kind = chart_type if chart_type in ('bar', 'line', 'scatter') else 'bar'

    if len(numeric_cols) >= 2:
        pair = df[numeric_cols[:2]]
        if kind == 'scatter':
            # Scatter plots require explicit x and y columns.
            pair.plot(kind='scatter', x=numeric_cols[0], y=numeric_cols[1], ax=ax)
        else:
            pair.plot(kind=kind, ax=ax)
    elif len(numeric_cols) == 1:
        df[numeric_cols[0]].plot(kind='bar', ax=ax)
    else:
        df.iloc[:, 0].value_counts().plot(kind='bar', ax=ax)

    plt.tight_layout()
# File extensions accepted by the upload validation.
SUPPORTED_FILE_TYPES = {
    "txt",
    "pdf",
    "docx",
    "xlsx",
    "pptx",
    "jpg",
    "jpeg",
    "png",
}

# Cached model handles; all start out unloaded.
summarizer = qa_model = image_captioner = None
def get_summarizer():
    """Return the shared summarization pipeline, creating it on first call."""
    global summarizer
    if summarizer is not None:
        return summarizer
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    return summarizer
| #def get_qa_model(): | |
| # global qa_model | |
| # if qa_model is None: | |
| # qa_model= pipe = pipeline("question-answering", model="deepset/roberta-base-squad2") | |
| #return qa_model | |
# Candidate text2text models, listed in the order they are attempted.
MODEL_CHOICES = [
    "cmarkea/flan-t5-base-fr",  # Best for French
    "moussaKam/barthez-orangesum-abstract",  # French summarization
    "google/flan-t5-xl",  # Higher quality fallback
]

# Active QA pipeline and the name of the model behind it (None until loaded).
qa_pipeline = current_model = None
def initialize_qa():
    """Load the first model from MODEL_CHOICES that can be instantiated.

    On success the module-level ``qa_pipeline`` and ``current_model`` are set
    and True is returned. If every candidate fails, an error is logged and
    False is returned.
    """
    global qa_pipeline, current_model

    for candidate in MODEL_CHOICES:
        try:
            logger.info(f"Attempting to load {candidate}")
            # GPU with half precision when CUDA is available, else CPU/float32.
            on_gpu = torch.cuda.is_available()
            qa_pipeline = pipeline(
                "text2text-generation",
                model=candidate,
                device=0 if on_gpu else -1,
                torch_dtype=torch.float16 if on_gpu else torch.float32,
            )
            current_model = candidate
            logger.info(f"Successfully loaded {candidate}")
        except Exception as e:
            # Fall through to the next candidate.
            logger.warning(f"Failed to load {candidate}: {str(e)}")
        else:
            return True

    logger.error("All model loading attempts failed")
    return False
async def startup_event():
    """Initialize the QA pipeline and log an error if no model could be loaded.

    NOTE(review): no startup-hook registration is visible on this function
    here; confirm it is actually invoked when the application starts.
    """
    loaded = initialize_qa()
    if loaded:
        return
    logger.error("QA system failed to initialize")
def get_image_captioner():
    """Return the shared image-to-text pipeline, creating it on first call."""
    global image_captioner
    if image_captioner is not None:
        return image_captioner
    image_captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-large",
    )
    return image_captioner
async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
    """Validate an uploaded file and return ``(extension, content)``.

    The extension is the lower-cased text after the last dot of the filename.

    Raises:
        HTTPException 400: no filename, or extension not in
            SUPPORTED_FILE_TYPES.
        HTTPException 413: content is larger than MAX_FILE_SIZE.
        HTTPException 422: a PDF cannot be opened, is encrypted with a
            non-empty password, or has more than 50 pages.
    """
    if not file.filename:
        raise HTTPException(400, "No filename provided")

    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in SUPPORTED_FILE_TYPES:
        # Sorted so the message is stable (set iteration order is not).
        supported = ', '.join(sorted(SUPPORTED_FILE_TYPES))
        raise HTTPException(400, f"Unsupported file type. Supported: {supported}")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading an arbitrarily large body into memory.
    content = await file.read(MAX_FILE_SIZE + 1)
    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")

    if file_ext == "pdf":
        try:
            with fitz.open(stream=content, filetype="pdf") as doc:
                # Encrypted documents are accepted only if the empty password
                # unlocks them.
                if doc.is_encrypted and not doc.authenticate(""):
                    raise ValueError("Encrypted PDF - cannot extract text")
                if len(doc) > 50:
                    raise ValueError("PDF too large (max 50 pages)")
        except Exception as e:
            logger.error(f"PDF validation failed: {str(e)}")
            raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}") from e

    # Rewind so the upload can be read again by the caller if needed.
    await file.seek(0)
    return file_ext, content
def extract_text(content: bytes, file_ext: str) -> str:
    """Extract plain text from *content* according to *file_ext*.

    Handled extensions (case-insensitive): txt, docx, xlsx/xls, pptx, pdf,
    jpg/jpeg/png. Images are OCR'd first; if OCR yields no text, an image
    caption is generated instead.

    Raises:
        HTTPException 422: extraction failed, or the extension is not one of
            the handled types (previously this case silently returned None).
    """
    # Callers do not always lower-case the extension before passing it in.
    file_ext = file_ext.lower()
    try:
        if file_ext == "txt":
            return content.decode("utf-8", errors="replace").strip()

        if file_ext == "docx":
            doc = Document(io.BytesIO(content))
            return "\n".join(para.text for para in doc.paragraphs if para.text.strip())

        if file_ext in {"xlsx", "xls"}:
            # sheet_name=None returns a mapping of sheet name -> DataFrame.
            sheets = pd.read_excel(
                io.BytesIO(content),
                sheet_name=None,
                engine='openpyxl',
                na_values=['', 'NA', 'N/A', 'NaN', 'null'],
                keep_default_na=False,
                parse_dates=True
            )
            all_text = []
            for sheet_name, sheet_data in sheets.items():
                sheet_text = []
                for column in sheet_data.columns:
                    # Render datetime columns in a fixed textual format.
                    if pd.api.types.is_datetime64_any_dtype(sheet_data[column]):
                        sheet_data[column] = sheet_data[column].dt.strftime('%Y-%m-%d %H:%M:%S')
                    col_text = sheet_data[column].astype(str).replace(['nan', 'None', 'NaT'], '').tolist()
                    sheet_text.extend(x for x in col_text if x.strip())
                all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
            return "\n\n".join(all_text)

        if file_ext == "pptx":
            ppt = Presentation(io.BytesIO(content))
            text = []
            for slide in ppt.slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text") and shape.text.strip():
                        text.append(shape.text)
            return "\n".join(text)

        if file_ext == "pdf":
            # Context manager closes the document once the text is collected.
            with fitz.open(stream=content, filetype="pdf") as pdf:
                return "\n".join(page.get_text("text") for page in pdf)

        if file_ext in {"jpg", "jpeg", "png"}:
            try:
                image = Image.open(io.BytesIO(content))
                text = pytesseract.image_to_string(image, config='--psm 6')
                if text.strip():
                    return text
                # No OCR text: fall back to a generated caption.
                captioner = get_image_captioner()
                result = captioner(image)
                return result[0]['generated_text']
            except Exception as img_e:
                logger.error(f"Image processing failed: {str(img_e)}")
                raise ValueError("Could not extract text or caption from image") from img_e

        # Fail loudly instead of returning None from a function typed -> str.
        raise ValueError(f"Unsupported file type: {file_ext}")
    except Exception as e:
        logger.error(f"Text extraction failed for {file_ext}: {str(e)}", exc_info=True)
        raise HTTPException(422, f"Failed to extract text from {file_ext} file: {str(e)}") from e
| from concurrent.futures import ThreadPoolExecutor | |
| import asyncio | |
# Shared thread pool handed to run_in_executor by the request handlers.
# NOTE: this rebinds the ``executor`` name assigned earlier in this module.
_POOL_SIZE = 4
executor = ThreadPoolExecutor(max_workers=_POOL_SIZE)
def _chunk_sentences(text, max_chars=800):
    """Split *text* on sentence boundaries and group sentences into chunks.

    Sentences are accumulated until adding the next one would push the summed
    sentence lengths (joining spaces not counted) past *max_chars*. A single
    sentence longer than *max_chars* becomes a chunk of its own. Empty chunks
    are never produced.
    """
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
    chunks = []
    current, current_length = [], 0
    for sentence in sentences:
        sent_length = len(sentence)
        # Flush only a non-empty chunk; an oversized first sentence must not
        # emit an empty string for the summarizer to choke on.
        if current and current_length + sent_length > max_chars:
            chunks.append(' '.join(current))
            current, current_length = [], 0
        current.append(sentence)
        current_length += sent_length
    if current:
        chunks.append(' '.join(current))
    return chunks


async def summarize_document(request: Request, file: UploadFile = File(...)):
    """Summarize an uploaded document and return ``{"summary": <text>}``.

    Steps: validate the upload, extract its text in the shared thread pool,
    normalize whitespace, split into sentence-aligned chunks, summarize each
    chunk, join the partial summaries, and run one more summarization pass if
    the joined text exceeds 200 words.

    Raises:
        HTTPException 400: no text could be extracted.
        HTTPException 500: any unexpected failure.
        HTTPExceptions raised by the validation/extraction helpers propagate
        unchanged.
    """
    try:
        file_ext, content = await process_uploaded_file(file)

        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(executor, extract_text, content, file_ext)
        if not text or not text.strip():
            raise HTTPException(400, "No extractable text found")

        # Collapse all whitespace runs to single spaces.
        text = re.sub(r'\s+', ' ', text).strip()
        chunks = _chunk_sentences(text)

        model = get_summarizer()

        def summarize_chunk(chunk):
            return model(
                chunk,
                max_length=120,
                min_length=40,
                do_sample=False,
                truncation=True
            )[0]["summary_text"]

        # Run model inference in the shared pool and await it, so the event
        # loop stays free while summaries are computed. gather() keeps the
        # results in chunk order.
        summaries = await asyncio.gather(
            *(loop.run_in_executor(executor, summarize_chunk, chunk) for chunk in chunks)
        )
        combined = ' '.join(summaries)

        # Single refinement pass on long results; input capped at 3000 chars.
        if len(combined.split()) > 200:
            combined = await loop.run_in_executor(executor, summarize_chunk, combined[:3000])

        return {"summary": combined}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)}", exc_info=True)
        raise HTTPException(500, "Document summarization failed")
| from typing import Optional | |
async def question_answering(
    question: str = Form(...),
    file: Optional[UploadFile] = File(None)
):
    """Answer *question* in French, optionally using an uploaded document.

    When a file is supplied, its text is extracted, cleaned and truncated to
    1200 characters to serve as context. Questions containing "thème" or
    "theme" are routed to the theme-extraction helper; all others go through
    a single prompt to the QA pipeline.

    Returns a dict with ``question``, ``answer``, ``context_used`` and (except
    for the "no document" theme reply) ``model``.

    Raises:
        HTTPException 503: the QA pipeline has not been loaded.
        HTTPException 500: any unexpected failure.
        HTTPExceptions raised while validating or reading the upload
        (400/413/422) propagate unchanged.
    """
    if qa_pipeline is None:
        raise HTTPException(503, detail="Service temporairement indisponible")

    try:
        context = None
        if file:
            # Use the extension validated (and lower-cased) by the upload
            # helper rather than re-deriving it from the raw filename.
            file_ext, content = await process_uploaded_file(file)
            raw_text = extract_text(content, file_ext)
            context = clean_and_translate_to_french(raw_text)[:1200]

        lowered = question.lower()
        if "thème" in lowered or "theme" in lowered:
            if not context:
                return {
                    "question": question,
                    "answer": "Aucun document fourni pour identifier le thème",
                    "context_used": False
                }
            response = generate_theme_answer(context)
            return {
                "question": question,
                "answer": response,
                "model": current_model,
                "context_used": True
            }

        # Standard QA prompt instructing the model to answer in French.
        input_text = (
            f"En tant qu'expert, réponds en français à la question suivante "
            f"en utilisant exclusivement le contexte fourni:\n"
            f"Question: {question}\n"
            f"Contexte: {context[:1000] if context else 'Aucun contexte disponible'}\n"
            f"Réponse concise:"
        )
        result = qa_pipeline(
            input_text,
            max_length=80,
            num_beams=2,
            temperature=0.2,
            repetition_penalty=3.0
        )

        final_answer = validate_french_response(result[0]["generated_text"])
        return {
            "question": question,
            "answer": final_answer,
            "model": current_model,
            "context_used": context is not None
        }
    except HTTPException:
        # Keep client errors from upload validation instead of masking them
        # as a generic 500.
        raise
    except Exception as e:
        logger.error(f"Erreur: {str(e)}", exc_info=True)
        raise HTTPException(500, "Erreur lors du traitement")
| # New helper functions | |
# English terms and the French word each is replaced with.
_EN_FR_TERMS = {
    "welcome": "bienvenue",
    "introduction": "introduction",
    "chapter": "chapitre",
    "section": "section",
}
# Whole-word match (an optional plural "s" may follow), case-insensitive.
_EN_FR_PATTERN = re.compile(
    r"\b(?:" + "|".join(map(re.escape, _EN_FR_TERMS)) + r")(?=s?\b)",
    re.IGNORECASE,
)


def clean_and_translate_to_french(text: str) -> str:
    """Clean *text* and replace a few common English terms with French ones.

    Lines consisting only of digits (e.g. page numbers) are blanked, the terms
    in ``_EN_FR_TERMS`` are replaced, and the result is truncated to 2000
    characters.

    Replacement matches whole words regardless of case and keeps the original
    capitalization ("Chapter" -> "Chapitre"); text embedded in another word
    ("unwelcome") is left alone.
    """
    # Blank out lines that contain nothing but a number.
    text = re.sub(r'^\s*\d+\s*$', '', text, flags=re.MULTILINE)

    def _swap(match):
        word = match.group(0)
        french = _EN_FR_TERMS[word.lower()]
        if word.isupper():
            return french.upper()
        if word[0].isupper():
            return french.capitalize()
        return french

    return _EN_FR_PATTERN.sub(_swap, text)[:2000]
def generate_theme_answer(context: str) -> str:
    """Produce a one-sentence French theme statement for *context*.

    Two calls are made to the QA pipeline: the first asks for keywords, the
    second asks for a one-sentence summary given those keywords. The reply is
    reduced to the text after its last colon, up to the first period, then
    capitalized and terminated with a period.
    """
    def _ask(prompt, limit):
        # Single generation call; returns only the generated text.
        return qa_pipeline(prompt, max_length=limit)[0]["generated_text"]

    # Step 1: keywords describing the main topics.
    keywords = _ask(
        "Liste 3-5 mots-clés en français représentant les sujets principaux "
        f"de ce texte:\n{context[:1000]}",
        50,
    )

    # Step 2: one-sentence summary guided by those keywords.
    sentence = _ask(
        "Résume en une phrase en français pour un étudiant:\n"
        f"Mots-clés: {keywords}\nTexte: {context[:800]}",
        60,
    )

    # Step 3: keep what follows the last ':' and precedes the first '.'.
    after_colon = sentence.rpartition(":")[2]
    first_clause = after_colon.partition(".")[0]
    return first_clause.strip().capitalize() + "."
def validate_french_response(text: str) -> str:
    """Tidy a generated answer into a single well-formed sentence.

    Whitespace is collapsed, an answer lacking terminal punctuation is cut at
    its first period (or given one), and the first character is upper-cased.

    The previous implementation deleted every run of three or more ASCII
    letters to "remove English", which also removed most French words
    (e.g. "est", "Paris", the middle of "sécurité"). No letter-based filter
    can tell the two languages apart, so that step is intentionally gone.
    """
    text = re.sub(r'\s+', ' ', text).strip()

    if not text.endswith(('.', '!', '?')):
        # Keep only the first sentence and close it with a period.
        text = text.split('.')[0].rstrip() + '.'

    # Upper-case just the first character; str.capitalize() would also
    # lower-case proper nouns and acronyms in the rest of the sentence.
    return text[:1].upper() + text[1:]
def _fallback_plot_code(df, chart_type):
    """Return a standalone Python snippet reproducing the fallback chart.

    Mirrors the column selection of the fallback plot: first two numeric
    columns, else the single numeric column, else value counts of the first
    column. The data is embedded in the snippet as a dict literal.
    """
    header = ["import pandas as pd", "import matplotlib.pyplot as plt"]
    footer = ["plt.tight_layout()", "plt.show()"]
    numeric_cols = df.select_dtypes(include='number').columns.tolist()
    kind = chart_type if chart_type in ('bar', 'line', 'scatter') else 'bar'

    if len(numeric_cols) >= 2:
        cols = numeric_cols[:2]
        body = [f"data = {df[cols].to_dict()}", "df = pd.DataFrame(data)"]
        if kind == 'scatter':
            # Scatter plots need explicit x/y columns to run.
            body.append(f"df.plot(kind='scatter', x={cols[0]!r}, y={cols[1]!r})")
        else:
            body.append(f"df.plot(kind='{kind}')")
    elif len(numeric_cols) == 1:
        # A {index: value} dict must be loaded as a Series; pd.DataFrame()
        # rejects a dict of scalars.
        body = [
            f"data = {df[numeric_cols[0]].to_dict()}",
            "series = pd.Series(data)",
            "series.plot(kind='bar')",
        ]
    else:
        body = [
            f"data = {df.iloc[:, 0].value_counts().to_dict()}",
            "df = pd.DataFrame(list(data.items()), columns=['Category', 'Count'])",
            "df.plot(x='Category', y='Count', kind='bar')",
        ]
    return "\n" + "\n".join(header + body + footer) + "\n"


async def generate_visualization(
    file: UploadFile = File(...),
    request: str = Form(...),
    chart_type: Optional[str] = Form("auto")
):
    """Render a chart from an uploaded CSV/Excel file and return its URL.

    The data is limited to MAX_COLUMNS columns. If the text-generation model
    is loaded, code it produces is executed to draw the chart; if that fails
    or yields no code, a simple fallback chart is drawn instead. The figure is
    saved as ``plot_<id>.png`` in OUTPUT_FOLDER.

    Returns a JSONResponse with ``image_url``, ``python_code``, ``columns``
    and ``note``.

    Raises:
        HTTPException 400: missing filename, unsupported extension, oversized
            file, or no plottable data.
        HTTPException 500: any other failure.

    NOTE(review): the ``request`` form field is accepted but not used.
    """
    try:
        # 1. Validate input
        if not file.filename:
            raise HTTPException(400, "No filename provided")
        file_ext = file.filename.split('.')[-1].lower()
        if file_ext not in ['csv', 'xlsx', 'xls']:
            raise HTTPException(400, "Only CSV/Excel files accepted")

        file_content = await file.read()
        if len(file_content) > MAX_FILE_SIZE:
            raise HTTPException(400, f"File size exceeds {MAX_FILE_SIZE//1024}KB limit")

        # 2. Parse the data in the thread pool, then trim it.
        df = await asyncio.get_running_loop().run_in_executor(
            executor, safe_read_file, file_content, file_ext
        )
        df = df.iloc[:, :MAX_COLUMNS].dropna(how='all')
        if df.empty:
            raise HTTPException(400, "No plottable data found")

        # 3. Generate visualization with a non-interactive backend.
        plt.switch_backend('Agg')
        generated_code = None
        try:
            if visualization_model:
                try:
                    prompt = f"Create {chart_type} chart for {list(df.columns)}. Python code only:"
                    code = visualization_model(
                        prompt,
                        max_length=300,
                        num_return_sequences=1,
                        temperature=0.3
                    )[0]['generated_text'].split("```python")[-1].split("```")[0].strip()
                    if code:
                        # SECURITY(review): this executes model-generated code
                        # whose prompt embeds user-supplied column names, with
                        # full builtins available. Consider removing or
                        # sandboxing it.
                        exec(code, {'df': df, 'plt': plt})
                        generated_code = code
                except Exception as e:
                    logger.warning(f"Model failed, using fallback: {e}")
                    # Drop any half-drawn figures left by the failed attempt.
                    plt.close('all')
                    generated_code = None

            # No model, model failure, or empty model output: fallback chart.
            if generated_code is None:
                generate_simple_plot(df, chart_type)
                generated_code = _fallback_plot_code(df, chart_type)

            # 4. Save output
            output_id = uuid.uuid4().hex[:8]
            image_path = f"{OUTPUT_FOLDER}/plot_{output_id}.png"
            plt.savefig(image_path, bbox_inches='tight', dpi=80)
        finally:
            # Always release figures, including when plotting/saving fails.
            plt.close('all')

        return JSONResponse({
            "image_url": f"/static/plot_{output_id}.png",
            "python_code": generated_code,
            "columns": list(df.columns),
            "note": "Visualization generated successfully"
        })
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(500, f"Processing error: {str(e)}")
async def serve_static(filename: str):
    """Return the file named *filename* from OUTPUT_FOLDER.

    Raises:
        HTTPException 404: the file does not exist, is not a regular file, or
            the resolved path falls outside OUTPUT_FOLDER.
    """
    # FileResponse is not imported at module level; import it here so the
    # handler does not fail with NameError.
    from fastapi.responses import FileResponse

    # Resolve the path and make sure it stays inside the output folder, so a
    # crafted name such as "../secret" cannot reach other files.
    base_dir = os.path.realpath(OUTPUT_FOLDER)
    file_path = os.path.realpath(os.path.join(base_dir, filename))
    inside_base = file_path.startswith(base_dir + os.sep)
    if not inside_base or not os.path.isfile(file_path):
        raise HTTPException(404, "Image not found")
    return FileResponse(file_path)
# Script entry point: serve the application with uvicorn.
if __name__ == "__main__":
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "timeout_keep_alive": 15,
    }
    uvicorn.run(app, **server_options)
| # ===== TESTING CODE (OPTIONAL) ===== | |
| # This should be in a separate test file, not in main.py | |
| """ | |
| def test_visualization(): | |
| from fastapi.testclient import TestClient | |
| from io import BytesIO | |
| import base64 | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| client = TestClient(app) | |
| test_file = "test.xlsx" | |
| test_prompt = "Show me a bar chart of sales by region" | |
| with open(test_file, "rb") as f: | |
| response = client.post( | |
| "/visualize/natural", | |
| files={"file": ("test.xlsx", f, "application/vnd.ms-excel")}, | |
| data={"prompt": test_prompt} | |
| ) | |
| if response.status_code == 200: | |
| result = response.json() | |
| print("Visualization generated successfully!") | |
| image_data = result["image"].split(",")[1] | |
| image_bytes = base64.b64decode(image_data) | |
| image = Image.open(BytesIO(image_bytes)) | |
| plt.imshow(image) | |
| plt.axis("off") | |
| plt.show() | |
| else: | |
| print(f"Error: {response.status_code}\n{response.text}") | |
| """ |