chenguittiMaroua committed on
Commit
dd793b9
·
verified ·
1 Parent(s): 3a67355

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +187 -1
main.py CHANGED
@@ -101,12 +101,58 @@ app.add_middleware(SlowAPIMiddleware)
101
  app.add_middleware(
102
  CORSMiddleware,
103
  allow_origins=["*"],
 
104
  allow_methods=["*"],
105
  allow_headers=["*"],
106
  )
107
 
108
- # Constants
 
 
 
 
 
 
 
 
 
 
109
  MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  SUPPORTED_FILE_TYPES = {
111
  "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png", "txt"
112
  }
@@ -498,6 +544,146 @@ def validate_french_response(text: str) -> str:
498
  return text.capitalize()
499
 
500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
 
502
 
503
 
 
101
  app.add_middleware(
102
  CORSMiddleware,
103
  allow_origins=["*"],
104
+ allow_credentials=True,
105
  allow_methods=["*"],
106
  allow_headers=["*"],
107
  )
108
 
109
+
110
# Working directories, created at import time if missing.
# OUTPUT_FOLDER is where rendered plots are written and served from.
UPLOAD_FOLDER = "uploads"
OUTPUT_FOLDER = "static"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Lightweight model configuration
MODEL_NAME = "distilgpt2"
TIMEOUT = 10  # seconds
MAX_ROWS = 100  # row cap applied when parsing an upload
MAX_COLUMNS = 5  # only the leading columns of an upload are kept
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB

# Load the text-generation pipeline once at import time. Loading is
# best-effort: on any failure the name stays None and callers are expected
# to fall back to the non-model plotting path.
visualization_model = None
try:
    visualization_model = pipeline(
        "text-generation",
        model=MODEL_NAME,
        device=-1,  # CPU
        framework="pt",
    )
except Exception as e:
    print(f"Model loading failed: {str(e)}")

# Small shared pool used to run blocking work off the event loop.
executor = ThreadPoolExecutor(max_workers=2)
133
+
134
def safe_read_file(file_content, file_ext, max_rows=None):
    """Parse an uploaded CSV or Excel payload into a row-capped DataFrame.

    Only the number of rows is limited here; the byte-size limit has to be
    enforced by the caller before the content is handed over.

    Args:
        file_content: Raw bytes of the uploaded file.
        file_ext: File extension. Case, surrounding whitespace and a leading
            dot are ignored, so "csv", "CSV" and ".csv" are equivalent.
            "csv" selects the CSV parser; any other value is handed to the
            Excel reader, as before.
        max_rows: Maximum number of data rows to load. Defaults to the
            module-level MAX_ROWS when omitted.

    Returns:
        A pandas DataFrame holding at most ``max_rows`` rows.

    Raises:
        Whatever the underlying pandas reader raises for malformed content.
    """
    # Resolve the default lazily so an explicit max_rows never needs the global.
    row_limit = MAX_ROWS if max_rows is None else max_rows
    normalized_ext = (file_ext or "").strip().lower().lstrip(".")

    file_like = io.BytesIO(file_content)
    if normalized_ext == 'csv':
        return pd.read_csv(file_like, nrows=row_limit)
    return pd.read_excel(file_like, nrows=row_limit)
140
+
141
def generate_simple_plot(df, chart_type):
    """Draw a basic chart of ``df`` on a fresh 8x5 figure (fallback plotting).

    The figure created here becomes pyplot's current figure, so a subsequent
    ``plt.savefig(...)`` by the caller saves this chart.

    Column selection:
      * two or more numeric columns: the first two are plotted using
        ``chart_type`` when it is 'bar', 'line' or 'scatter', otherwise 'bar';
      * exactly one numeric column: bar chart of that column;
      * no numeric columns: bar chart of the value counts of the first column.

    Args:
        df: DataFrame to plot; must have at least one column.
        chart_type: Requested chart kind; unsupported values fall back to 'bar'.
    """
    # Create the axes explicitly and pass them to pandas. DataFrame.plot
    # opens its own figure when no ``ax`` is given, which would leave the
    # sized figure empty and orphaned.
    fig, ax = plt.subplots(figsize=(8, 5))
    numeric_cols = df.select_dtypes(include='number').columns

    if len(numeric_cols) >= 2:
        first_col, second_col = numeric_cols[:2]
        kind = chart_type if chart_type in ('bar', 'line', 'scatter') else 'bar'
        if kind == 'scatter':
            # Scatter plots need explicit x/y columns; without them pandas raises.
            df.plot(kind='scatter', x=first_col, y=second_col, ax=ax)
        else:
            df[[first_col, second_col]].plot(kind=kind, ax=ax)
    elif len(numeric_cols) == 1:
        df[numeric_cols[0]].plot(kind='bar', ax=ax)
    else:
        df.iloc[:, 0].value_counts().plot(kind='bar', ax=ax)

    fig.tight_layout()
154
+
155
+
156
  SUPPORTED_FILE_TYPES = {
157
  "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png", "txt"
158
  }
 
544
  return text.capitalize()
545
 
546
 
547
+
548
+
549
def _fallback_snippet(df, chart_type):
    """Return a standalone pandas/matplotlib script approximating the fallback chart."""
    kind = chart_type if chart_type in ('bar', 'line', 'scatter') else 'bar'
    numeric_cols = df.select_dtypes(include='number').columns.tolist()

    if len(numeric_cols) >= 2:
        cols = numeric_cols[:2]
        if kind == 'scatter':
            # Scatter plots need explicit x/y columns to be runnable.
            plot_call = f"df.plot(kind='scatter', x={cols[0]!r}, y={cols[1]!r})"
        else:
            plot_call = f"df.plot(kind='{kind}')"
        body = [
            f"data = {df[cols].to_dict()}",
            "df = pd.DataFrame(data)",
            plot_call,
        ]
    elif len(numeric_cols) == 1:
        # Keep the column as a one-column frame so to_dict() yields
        # {column: {index: value}}; a flat {index: value} mapping cannot be
        # passed to pd.DataFrame (all-scalar values without an index).
        body = [
            f"data = {df[numeric_cols[:1]].to_dict()}",
            "df = pd.DataFrame(data)",
            "df.plot(kind='bar')",
        ]
    else:
        body = [
            f"data = {df.iloc[:, 0].value_counts().to_dict()}",
            "df = pd.DataFrame(list(data.items()), columns=['Category', 'Count'])",
            "df.plot(x='Category', y='Count', kind='bar')",
        ]

    lines = [
        "import pandas as pd",
        "import matplotlib.pyplot as plt",
        *body,
        "plt.tight_layout()",
        "plt.show()",
    ]
    return "\n" + "\n".join(lines) + "\n"


@app.post("/generate-visualization")
async def generate_visualization(
    file: UploadFile = File(...),
    request: str = Form(...),
    chart_type: Optional[str] = Form("auto")
):
    """Render a chart from an uploaded CSV/Excel file and return its URL.

    The upload is parsed (row-capped by ``safe_read_file``), reduced to its
    first MAX_COLUMNS columns, and plotted. When the text-generation model is
    available its output is tried first; if that yields no code or fails to
    run, ``generate_simple_plot`` is used instead and an equivalent
    standalone script is returned as ``python_code``.

    Args:
        file: Uploaded file; the extension must be csv, xlsx or xls.
        request: Required form field. It is not used by this handler.
        chart_type: Requested chart kind; used in the model prompt and by the
            fallback plot.

    Returns:
        JSONResponse with ``image_url``, ``python_code``, ``columns`` and
        ``note``.

    Raises:
        HTTPException(400): unsupported or missing extension, oversized file,
            or no plottable data.
        HTTPException(500): any other processing failure.
    """
    figures_before = None
    try:
        # 1. Validate input
        # The filename may be absent; treat that as an unsupported file
        # instead of failing with an attribute error.
        file_ext = (file.filename or "").split('.')[-1].lower()
        if file_ext not in ['csv', 'xlsx', 'xls']:
            raise HTTPException(400, "Only CSV/Excel files accepted")

        file_content = await file.read()
        if len(file_content) > MAX_FILE_SIZE:
            raise HTTPException(400, f"File size exceeds {MAX_FILE_SIZE//1024}KB limit")

        # 2. Process data (parsing is blocking, so run it on the worker pool)
        df = await asyncio.get_running_loop().run_in_executor(
            executor, safe_read_file, file_content, file_ext
        )

        # Simplify dataframe
        df = df.iloc[:, :MAX_COLUMNS].dropna(how='all')
        if df.empty:
            raise HTTPException(400, "No plottable data found")

        # 3. Generate visualization
        plt.switch_backend('Agg')
        # Remember which figures already exist so that only the ones opened
        # by this request are closed during cleanup.
        figures_before = set(plt.get_fignums())
        generated_code = None

        if visualization_model:
            try:
                prompt = f"Create {chart_type} chart for {list(df.columns)}. Python code only:"
                code = visualization_model(
                    prompt,
                    max_length=300,
                    num_return_sequences=1,
                    temperature=0.3
                )[0]['generated_text'].split("```python")[-1].split("```")[0].strip()

                if code:
                    # SECURITY NOTE(review): this executes model-generated
                    # text in-process, and the prompt embeds column names
                    # taken from the uploaded file. Left in place to preserve
                    # behaviour; it should be sandboxed or removed.
                    exec(code, {'df': df, 'plt': plt})
                    generated_code = code
            except Exception as e:
                print(f"Model failed, using fallback: {e}")

        if generated_code is None:
            # Model unavailable, produced nothing, or its code failed: draw
            # the fallback chart so a blank image is never saved.
            generate_simple_plot(df, chart_type)
            generated_code = _fallback_snippet(df, chart_type)

        # 4. Save output
        output_id = uuid.uuid4().hex[:8]
        image_path = f"{OUTPUT_FOLDER}/plot_{output_id}.png"
        plt.savefig(image_path, bbox_inches='tight', dpi=80)

        return JSONResponse({
            "image_url": f"/static/plot_{output_id}.png",
            "python_code": generated_code,
            "columns": list(df.columns),
            "note": "Visualization generated successfully"
        })

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(500, f"Processing error: {str(e)}") from e
    finally:
        # Close every figure this request opened, on success and on failure,
        # so figures do not accumulate across requests.
        if figures_before is not None:
            for fig_num in set(plt.get_fignums()) - figures_before:
                plt.close(fig_num)
680
+
681
@app.get("/static/{filename}")
async def serve_static(filename: str):
    """Serve a generated image from OUTPUT_FOLDER by file name.

    Args:
        filename: Plain file name of the image (no directory components).

    Returns:
        FileResponse streaming the requested file.

    Raises:
        HTTPException(404): the name is not a plain file name, or no regular
            file with that name exists in OUTPUT_FOLDER.
    """
    # Accept only a bare file name: reject "." / ".." and anything carrying a
    # path separator (including backslashes, which are separators on Windows)
    # so the lookup can never leave OUTPUT_FOLDER.
    is_plain_name = (
        filename not in ('.', '..')
        and '\\' not in filename
        and filename == os.path.basename(filename)
    )
    if not is_plain_name:
        raise HTTPException(404, "Image not found")

    file_path = f"{OUTPUT_FOLDER}/{filename}"
    # isfile (not exists) so that directories yield a 404 instead of an
    # error when the response tries to read them.
    if not os.path.isfile(file_path):
        raise HTTPException(404, "Image not found")
    return FileResponse(file_path)
687
 
688
 
689