Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -433,12 +433,79 @@ def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest)
|
|
| 433 |
from typing import Optional
|
| 434 |
|
| 435 |
def interpret_natural_language(prompt: str, df_columns: list) -> Optional[VisualizationRequest]:
|
| 436 |
-
"""Convert natural language prompt to visualization parameters"""
|
| 437 |
if not prompt or not df_columns:
|
| 438 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
|
| 440 |
-
|
| 441 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
|
| 443 |
# ===== DYNAMIC VISUALIZATION FUNCTIONS =====
|
| 444 |
def read_any_excel(content: bytes) -> pd.DataFrame:
|
|
@@ -782,85 +849,56 @@ async def visualize_with_natural_language(
|
|
| 782 |
style: str = Form("seaborn-v0_8")
|
| 783 |
):
|
| 784 |
try:
|
| 785 |
-
#
|
| 786 |
-
logger.info(f"Incoming request with file: {file.filename if file else 'None'}")
|
| 787 |
-
|
| 788 |
-
# Verify file exists and has content
|
| 789 |
-
if not file or not file.filename:
|
| 790 |
-
logger.error("No file uploaded")
|
| 791 |
-
raise HTTPException(400, "Please upload an Excel file")
|
| 792 |
|
| 793 |
-
# Read file content
|
| 794 |
-
content = await file.read()
|
| 795 |
-
if not content:
|
| 796 |
-
logger.error("Empty file uploaded")
|
| 797 |
-
raise HTTPException(400, "The uploaded file is empty")
|
| 798 |
-
|
| 799 |
-
# Verify Excel file extension
|
| 800 |
-
file_ext = file.filename.split('.')[-1].lower()
|
| 801 |
-
if file_ext not in {"xlsx", "xls"}:
|
| 802 |
-
logger.error(f"Unsupported file type: {file_ext}")
|
| 803 |
-
raise HTTPException(400, "Only Excel files (.xlsx, .xls) are supported")
|
| 804 |
-
|
| 805 |
-
# Read Excel file with multiple engine fallbacks
|
| 806 |
-
try:
|
| 807 |
-
df = pd.read_excel(BytesIO(content), engine='openpyxl')
|
| 808 |
-
except Exception as e:
|
| 809 |
-
logger.warning(f"Openpyxl failed, trying xlrd: {str(e)}")
|
| 810 |
-
try:
|
| 811 |
-
df = pd.read_excel(BytesIO(content), engine='xlrd')
|
| 812 |
-
except Exception as e:
|
| 813 |
-
logger.error(f"Excel read failed: {str(e)}")
|
| 814 |
-
raise HTTPException(400, "Failed to read Excel file - may be corrupt or password protected")
|
| 815 |
-
|
| 816 |
-
if df.empty:
|
| 817 |
-
logger.error("Empty DataFrame after reading Excel")
|
| 818 |
-
raise HTTPException(400, "Excel file contains no data")
|
| 819 |
-
|
| 820 |
-
# Generate prompt if empty
|
| 821 |
if not prompt.strip():
|
| 822 |
prompt = generate_smart_prompt(df)
|
| 823 |
logger.info(f"Auto-generated prompt: {prompt}")
|
| 824 |
|
| 825 |
-
#
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
|
| 829 |
-
|
| 830 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 831 |
vis_request.style = style
|
| 832 |
|
| 833 |
-
#
|
| 834 |
-
|
| 835 |
-
visualization_code = generate_dynamic_visualization_code(df, vis_request)
|
| 836 |
-
|
| 837 |
-
plt.style.use(vis_request.style)
|
| 838 |
-
fig, ax = plt.subplots(figsize=(10, 6))
|
| 839 |
-
|
| 840 |
-
exec(visualization_code, {'plt': plt, 'sns': sns, 'df': df, 'np': np})
|
| 841 |
-
|
| 842 |
-
buffer = BytesIO()
|
| 843 |
-
plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
|
| 844 |
-
plt.close()
|
| 845 |
-
buffer.seek(0)
|
| 846 |
-
|
| 847 |
-
return {
|
| 848 |
-
"status": "success",
|
| 849 |
-
"image_data": base64.b64encode(buffer.getvalue()).decode('utf-8'),
|
| 850 |
-
"code": visualization_code,
|
| 851 |
-
"columns": list(df.columns),
|
| 852 |
-
"prompt": prompt
|
| 853 |
-
}
|
| 854 |
-
|
| 855 |
-
except Exception as e:
|
| 856 |
-
logger.error(f"Visualization failed: {str(e)}")
|
| 857 |
-
raise HTTPException(400, f"Failed to generate visualization: {str(e)}")
|
| 858 |
-
|
| 859 |
except HTTPException as he:
|
| 860 |
raise
|
| 861 |
except Exception as e:
|
| 862 |
logger.error(f"Unexpected error: {traceback.format_exc()}")
|
| 863 |
-
raise HTTPException(500,
|
|
|
|
|
|
|
|
|
|
| 864 |
|
| 865 |
|
| 866 |
|
|
|
|
| 433 |
from typing import Optional
|
| 434 |
|
| 435 |
def interpret_natural_language(prompt: str, df_columns: list) -> Optional[VisualizationRequest]:
|
| 436 |
+
"""Convert natural language prompt to visualization parameters with enhanced parsing"""
|
| 437 |
if not prompt or not df_columns:
|
| 438 |
return None
|
| 439 |
+
|
| 440 |
+
# Normalize the prompt and columns
|
| 441 |
+
prompt = prompt.lower().strip()
|
| 442 |
+
normalized_columns = [col.lower().strip() for col in df_columns]
|
| 443 |
+
|
| 444 |
+
# Initialize default values
|
| 445 |
+
chart_type = "bar"
|
| 446 |
+
x_col = None
|
| 447 |
+
y_col = None
|
| 448 |
+
hue_col = None
|
| 449 |
+
title = f"Visualization of {prompt[:50]}" # Default title
|
| 450 |
+
|
| 451 |
+
# Common chart type detection
|
| 452 |
+
chart_keywords = {
|
| 453 |
+
"line": ["line", "trend", "over time"],
|
| 454 |
+
"bar": ["bar", "compare", "comparison"],
|
| 455 |
+
"scatter": ["scatter", "correlation", "relationship"],
|
| 456 |
+
"histogram": ["histogram", "distribution", "frequency"],
|
| 457 |
+
"boxplot": ["box", "quartile", "distribution"],
|
| 458 |
+
"heatmap": ["heatmap", "correlation", "matrix"]
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
# Detect chart type
|
| 462 |
+
for chart, keywords in chart_keywords.items():
|
| 463 |
+
if any(keyword in prompt for keyword in keywords):
|
| 464 |
+
chart_type = chart
|
| 465 |
+
break
|
| 466 |
+
|
| 467 |
+
# Column detection with improved matching
|
| 468 |
+
for col in df_columns:
|
| 469 |
+
col_lower = col.lower()
|
| 470 |
|
| 471 |
+
# Check if column name appears in prompt
|
| 472 |
+
if col_lower in prompt:
|
| 473 |
+
# Look for context clues about the column's role
|
| 474 |
+
if not x_col and ("by " + col_lower in prompt or
|
| 475 |
+
"for " + col_lower in prompt or
|
| 476 |
+
"across " + col_lower in prompt):
|
| 477 |
+
x_col = col
|
| 478 |
+
elif not y_col and ("of " + col_lower in prompt or
|
| 479 |
+
"show " + col_lower in prompt or
|
| 480 |
+
"plot " + col_lower in prompt):
|
| 481 |
+
y_col = col
|
| 482 |
+
elif not hue_col and ("color by " + col_lower in prompt or
|
| 483 |
+
"group by " + col_lower in prompt):
|
| 484 |
+
hue_col = col
|
| 485 |
+
|
| 486 |
+
# Fallback logic if columns not detected
|
| 487 |
+
if not x_col and len(df_columns) > 0:
|
| 488 |
+
x_col = df_columns[0] # First column as default x-axis
|
| 489 |
+
|
| 490 |
+
if not y_col and len(df_columns) > 1:
|
| 491 |
+
# Try to find a numeric column for y-axis
|
| 492 |
+
numeric_cols = [col for col in df_columns if pd.api.types.is_numeric_dtype(df[col])]
|
| 493 |
+
y_col = numeric_cols[0] if numeric_cols else df_columns[1]
|
| 494 |
+
|
| 495 |
+
# Special handling for certain chart types
|
| 496 |
+
if chart_type == "heatmap":
|
| 497 |
+
x_col = None
|
| 498 |
+
y_col = None
|
| 499 |
+
hue_col = None
|
| 500 |
+
|
| 501 |
+
return VisualizationRequest(
|
| 502 |
+
chart_type=chart_type,
|
| 503 |
+
x_column=x_col,
|
| 504 |
+
y_column=y_col,
|
| 505 |
+
hue_column=hue_col,
|
| 506 |
+
title=title,
|
| 507 |
+
style="seaborn-v0_8"
|
| 508 |
+
)
|
| 509 |
|
| 510 |
# ===== DYNAMIC VISUALIZATION FUNCTIONS =====
|
| 511 |
def read_any_excel(content: bytes) -> pd.DataFrame:
|
|
|
|
| 849 |
style: str = Form("seaborn-v0_8")
|
| 850 |
):
|
| 851 |
try:
|
| 852 |
+
# [Previous file handling code remains the same until after df is created]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 853 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 854 |
if not prompt.strip():
|
| 855 |
prompt = generate_smart_prompt(df)
|
| 856 |
logger.info(f"Auto-generated prompt: {prompt}")
|
| 857 |
|
| 858 |
+
# Enhanced visualization request interpretation with better error feedback
|
| 859 |
+
try:
|
| 860 |
+
vis_request = interpret_natural_language(prompt, df.columns.tolist())
|
| 861 |
+
if not vis_request:
|
| 862 |
+
raise ValueError("Could not interpret the visualization request")
|
| 863 |
+
|
| 864 |
+
# Validate the request against the actual data
|
| 865 |
+
if vis_request.x_column and vis_request.x_column not in df.columns:
|
| 866 |
+
raise ValueError(f"Column '{vis_request.x_column}' not found in data")
|
| 867 |
+
|
| 868 |
+
if vis_request.y_column and vis_request.y_column not in df.columns:
|
| 869 |
+
raise ValueError(f"Column '{vis_request.y_column}' not found in data")
|
| 870 |
+
|
| 871 |
+
if vis_request.hue_column and vis_request.hue_column not in df.columns:
|
| 872 |
+
raise ValueError(f"Column '{vis_request.hue_column}' not found in data")
|
| 873 |
+
|
| 874 |
+
except ValueError as e:
|
| 875 |
+
logger.error(f"Visualization interpretation failed: {str(e)}")
|
| 876 |
+
raise HTTPException(
|
| 877 |
+
status_code=400,
|
| 878 |
+
detail={
|
| 879 |
+
"error": "Could not understand your visualization request",
|
| 880 |
+
"message": str(e),
|
| 881 |
+
"suggestions": [
|
| 882 |
+
"Try being more specific (e.g., 'Show sales by region')",
|
| 883 |
+
f"Available columns: {list(df.columns)}",
|
| 884 |
+
"Supported chart types: line, bar, scatter, histogram, boxplot, heatmap"
|
| 885 |
+
],
|
| 886 |
+
"your_prompt": prompt
|
| 887 |
+
}
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
vis_request.style = style
|
| 891 |
|
| 892 |
+
# [Rest of your visualization code remains the same]
|
| 893 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 894 |
except HTTPException as he:
|
| 895 |
raise
|
| 896 |
except Exception as e:
|
| 897 |
logger.error(f"Unexpected error: {traceback.format_exc()}")
|
| 898 |
+
raise HTTPException(500, {
|
| 899 |
+
"error": "Internal server error",
|
| 900 |
+
"details": str(e)
|
| 901 |
+
})
|
| 902 |
|
| 903 |
|
| 904 |
|