chenguittiMaroua commited on
Commit
360e004
·
verified ·
1 Parent(s): 7a504cf

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +109 -71
main.py CHANGED
@@ -433,12 +433,79 @@ def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest)
433
  from typing import Optional
434
 
435
  def interpret_natural_language(prompt: str, df_columns: list) -> Optional[VisualizationRequest]:
436
- """Convert natural language prompt to visualization parameters"""
437
  if not prompt or not df_columns:
438
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
 
440
- prompt = prompt.lower()
441
- # [rest of your existing function...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
 
443
  # ===== DYNAMIC VISUALIZATION FUNCTIONS =====
444
  def read_any_excel(content: bytes) -> pd.DataFrame:
@@ -782,85 +849,56 @@ async def visualize_with_natural_language(
782
  style: str = Form("seaborn-v0_8")
783
  ):
784
  try:
785
- # Debugging: Log incoming request
786
- logger.info(f"Incoming request with file: {file.filename if file else 'None'}")
787
-
788
- # Verify file exists and has content
789
- if not file or not file.filename:
790
- logger.error("No file uploaded")
791
- raise HTTPException(400, "Please upload an Excel file")
792
 
793
- # Read file content
794
- content = await file.read()
795
- if not content:
796
- logger.error("Empty file uploaded")
797
- raise HTTPException(400, "The uploaded file is empty")
798
-
799
- # Verify Excel file extension
800
- file_ext = file.filename.split('.')[-1].lower()
801
- if file_ext not in {"xlsx", "xls"}:
802
- logger.error(f"Unsupported file type: {file_ext}")
803
- raise HTTPException(400, "Only Excel files (.xlsx, .xls) are supported")
804
-
805
- # Read Excel file with multiple engine fallbacks
806
- try:
807
- df = pd.read_excel(BytesIO(content), engine='openpyxl')
808
- except Exception as e:
809
- logger.warning(f"Openpyxl failed, trying xlrd: {str(e)}")
810
- try:
811
- df = pd.read_excel(BytesIO(content), engine='xlrd')
812
- except Exception as e:
813
- logger.error(f"Excel read failed: {str(e)}")
814
- raise HTTPException(400, "Failed to read Excel file - may be corrupt or password protected")
815
-
816
- if df.empty:
817
- logger.error("Empty DataFrame after reading Excel")
818
- raise HTTPException(400, "Excel file contains no data")
819
-
820
- # Generate prompt if empty
821
  if not prompt.strip():
822
  prompt = generate_smart_prompt(df)
823
  logger.info(f"Auto-generated prompt: {prompt}")
824
 
825
- # Create visualization
826
- vis_request = interpret_natural_language(prompt, df.columns.tolist())
827
- if not vis_request:
828
- logger.error("Could not interpret visualization request")
829
- raise HTTPException(400, "Could not understand your visualization request")
830
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831
  vis_request.style = style
832
 
833
- # Generate visualization
834
- try:
835
- visualization_code = generate_dynamic_visualization_code(df, vis_request)
836
-
837
- plt.style.use(vis_request.style)
838
- fig, ax = plt.subplots(figsize=(10, 6))
839
-
840
- exec(visualization_code, {'plt': plt, 'sns': sns, 'df': df, 'np': np})
841
-
842
- buffer = BytesIO()
843
- plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
844
- plt.close()
845
- buffer.seek(0)
846
-
847
- return {
848
- "status": "success",
849
- "image_data": base64.b64encode(buffer.getvalue()).decode('utf-8'),
850
- "code": visualization_code,
851
- "columns": list(df.columns),
852
- "prompt": prompt
853
- }
854
-
855
- except Exception as e:
856
- logger.error(f"Visualization failed: {str(e)}")
857
- raise HTTPException(400, f"Failed to generate visualization: {str(e)}")
858
-
859
  except HTTPException as he:
860
  raise
861
  except Exception as e:
862
  logger.error(f"Unexpected error: {traceback.format_exc()}")
863
- raise HTTPException(500, "Internal server error")
 
 
 
864
 
865
 
866
 
 
433
  from typing import Optional
434
 
435
  def interpret_natural_language(prompt: str, df_columns: list) -> Optional[VisualizationRequest]:
436
+ """Convert natural language prompt to visualization parameters with enhanced parsing"""
437
  if not prompt or not df_columns:
438
  return None
439
+
440
+ # Normalize the prompt and columns
441
+ prompt = prompt.lower().strip()
442
+ normalized_columns = [col.lower().strip() for col in df_columns]
443
+
444
+ # Initialize default values
445
+ chart_type = "bar"
446
+ x_col = None
447
+ y_col = None
448
+ hue_col = None
449
+ title = f"Visualization of {prompt[:50]}" # Default title
450
+
451
+ # Common chart type detection
452
+ chart_keywords = {
453
+ "line": ["line", "trend", "over time"],
454
+ "bar": ["bar", "compare", "comparison"],
455
+ "scatter": ["scatter", "correlation", "relationship"],
456
+ "histogram": ["histogram", "distribution", "frequency"],
457
+ "boxplot": ["box", "quartile", "distribution"],
458
+ "heatmap": ["heatmap", "correlation", "matrix"]
459
+ }
460
+
461
+ # Detect chart type
462
+ for chart, keywords in chart_keywords.items():
463
+ if any(keyword in prompt for keyword in keywords):
464
+ chart_type = chart
465
+ break
466
+
467
+ # Column detection with improved matching
468
+ for col in df_columns:
469
+ col_lower = col.lower()
470
 
471
+ # Check if column name appears in prompt
472
+ if col_lower in prompt:
473
+ # Look for context clues about the column's role
474
+ if not x_col and ("by " + col_lower in prompt or
475
+ "for " + col_lower in prompt or
476
+ "across " + col_lower in prompt):
477
+ x_col = col
478
+ elif not y_col and ("of " + col_lower in prompt or
479
+ "show " + col_lower in prompt or
480
+ "plot " + col_lower in prompt):
481
+ y_col = col
482
+ elif not hue_col and ("color by " + col_lower in prompt or
483
+ "group by " + col_lower in prompt):
484
+ hue_col = col
485
+
486
+ # Fallback logic if columns not detected
487
+ if not x_col and len(df_columns) > 0:
488
+ x_col = df_columns[0] # First column as default x-axis
489
+
490
+ if not y_col and len(df_columns) > 1:
491
+ # Try to find a numeric column for y-axis
492
+ numeric_cols = [col for col in df_columns if pd.api.types.is_numeric_dtype(df[col])]
493
+ y_col = numeric_cols[0] if numeric_cols else df_columns[1]
494
+
495
+ # Special handling for certain chart types
496
+ if chart_type == "heatmap":
497
+ x_col = None
498
+ y_col = None
499
+ hue_col = None
500
+
501
+ return VisualizationRequest(
502
+ chart_type=chart_type,
503
+ x_column=x_col,
504
+ y_column=y_col,
505
+ hue_column=hue_col,
506
+ title=title,
507
+ style="seaborn-v0_8"
508
+ )
509
 
510
  # ===== DYNAMIC VISUALIZATION FUNCTIONS =====
511
  def read_any_excel(content: bytes) -> pd.DataFrame:
 
849
  style: str = Form("seaborn-v0_8")
850
  ):
851
  try:
852
+ # [Previous file handling code remains the same until after df is created]
 
 
 
 
 
 
853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
854
  if not prompt.strip():
855
  prompt = generate_smart_prompt(df)
856
  logger.info(f"Auto-generated prompt: {prompt}")
857
 
858
+ # Enhanced visualization request interpretation with better error feedback
859
+ try:
860
+ vis_request = interpret_natural_language(prompt, df.columns.tolist())
861
+ if not vis_request:
862
+ raise ValueError("Could not interpret the visualization request")
863
+
864
+ # Validate the request against the actual data
865
+ if vis_request.x_column and vis_request.x_column not in df.columns:
866
+ raise ValueError(f"Column '{vis_request.x_column}' not found in data")
867
+
868
+ if vis_request.y_column and vis_request.y_column not in df.columns:
869
+ raise ValueError(f"Column '{vis_request.y_column}' not found in data")
870
+
871
+ if vis_request.hue_column and vis_request.hue_column not in df.columns:
872
+ raise ValueError(f"Column '{vis_request.hue_column}' not found in data")
873
+
874
+ except ValueError as e:
875
+ logger.error(f"Visualization interpretation failed: {str(e)}")
876
+ raise HTTPException(
877
+ status_code=400,
878
+ detail={
879
+ "error": "Could not understand your visualization request",
880
+ "message": str(e),
881
+ "suggestions": [
882
+ "Try being more specific (e.g., 'Show sales by region')",
883
+ f"Available columns: {list(df.columns)}",
884
+ "Supported chart types: line, bar, scatter, histogram, boxplot, heatmap"
885
+ ],
886
+ "your_prompt": prompt
887
+ }
888
+ )
889
+
890
  vis_request.style = style
891
 
892
+ # [Rest of your visualization code remains the same]
893
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
894
  except HTTPException as he:
895
  raise
896
  except Exception as e:
897
  logger.error(f"Unexpected error: {traceback.format_exc()}")
898
+ raise HTTPException(500, {
899
+ "error": "Internal server error",
900
+ "details": str(e)
901
+ })
902
 
903
 
904