Upload 82 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- DB_store_backup/agentic_context_convoid_management.py +1157 -0
- DB_store_backup/rough.py +510 -0
- DB_store_backup/stored_convoId_data.py +423 -0
- Dockerfile +74 -0
- Redis/redis_agent_memory.py +0 -0
- Redis/rough.py +132 -0
- Redis/sessions.py +121 -0
- Redis/sessions_new.py +203 -0
- Redis/sessions_old.py +598 -0
- Routes/generate_report.py +71 -0
- Routes/helpers/autovis_tool.py +132 -0
- Routes/helpers/df_to_vis.py +255 -0
- Routes/helpers/duck_db_agent.py +321 -0
- Routes/helpers/main_agent_best_as_of_now.py +700 -0
- Routes/helpers/main_agent_best_as_of_now.py-update.txt +1110 -0
- Routes/helpers/main_agent_helpers.py +255 -0
- Routes/helpers/pandas_ai_agent.py +39 -0
- Routes/helpers/plotly_langchain_agent.py +207 -0
- Routes/helpers/report_generation_helpers.py +242 -0
- Routes/helpers/rough.py +161 -0
- Routes/main_agent_chat_bot_v2.py +178 -0
- Routes/main_chat_bot.py +588 -0
- Routes/main_chat_bot.py_main.txt +24 -0
- agent_tools/Hybrid_Rag_agent.py +93 -0
- agent_tools/Reflection_agent.py +80 -0
- agent_tools/autovis_tool.py +124 -0
- agent_tools/code_intepreter.py +67 -0
- alembic/README +1 -0
- alembic/__pycache__/env.cpython-312.pyc +0 -0
- alembic/env.py +94 -0
- alembic/script.py.mako +26 -0
- alembic/versions/049cf12dc407_removed_data_column_and_added_file_.py +52 -0
- alembic/versions/0f27f624c0f9_make_created_by_non_nullable.py +30 -0
- alembic/versions/2cb6dd9a9f5b_improve_models_add_indexes_cascades_.py +242 -0
- alembic/versions/3b084b14f4b1_add_created_by_field_to_datasets.py +34 -0
- alembic/versions/519b15d0dca6_added_two_columns_in_convo_table.py +50 -0
- alembic/versions/872b723d49c9_added_created_by_fields_in_required_.py +46 -0
- alembic/versions/__pycache__/049cf12dc407_removed_data_column_and_added_file_.cpython-312.pyc +0 -0
- alembic/versions/__pycache__/0f27f624c0f9_make_created_by_non_nullable.cpython-312.pyc +0 -0
- alembic/versions/__pycache__/2cb6dd9a9f5b_improve_models_add_indexes_cascades_.cpython-312.pyc +0 -0
- alembic/versions/__pycache__/3b084b14f4b1_add_created_by_field_to_datasets.cpython-312.pyc +0 -0
- alembic/versions/__pycache__/519b15d0dca6_added_two_columns_in_convo_table.cpython-312.pyc +0 -0
- alembic/versions/__pycache__/872b723d49c9_added_created_by_fields_in_required_.cpython-312.pyc +0 -0
- alembic/versions/__pycache__/ac5b502d055a_added_userdatasetsmetadata_table.cpython-312.pyc +0 -0
- alembic/versions/__pycache__/c378ad11cd73_initial_migration_capture_current_schema.cpython-312.pyc +0 -0
- alembic/versions/ac5b502d055a_added_userdatasetsmetadata_table.py +98 -0
- alembic/versions/c378ad11cd73_initial_migration_capture_current_schema.py +30 -0
- app.py +91 -0
- backend/__init__.py +0 -0
- backend/__pycache__/__init__.cpython-312.pyc +0 -0
DB_store_backup/agentic_context_convoid_management.py
ADDED
|
@@ -0,0 +1,1157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from fastapi import FastAPI, HTTPException, Query
|
| 2 |
+
# from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
# from pydantic import BaseModel
|
| 4 |
+
# from typing import Optional, List, Dict, Any
|
| 5 |
+
# import psycopg2
|
| 6 |
+
# from psycopg2 import sql
|
| 7 |
+
# import json
|
| 8 |
+
# from datetime import datetime
|
| 9 |
+
# from fastapi import APIRouter
|
| 10 |
+
# from fastapi.responses import JSONResponse
|
| 11 |
+
|
| 12 |
+
# # PostgreSQL credentials
|
| 13 |
+
# PGHOST = 'ep-steep-dream-adqtvjel-pooler.c-2.us-east-1.aws.neon.tech'
|
| 14 |
+
# PGDATABASE = 'neondb'
|
| 15 |
+
# PGUSER = 'neondb_owner'
|
| 16 |
+
# PGPASSWORD = 'npg_Qq0B1uWRXavx'
|
| 17 |
+
# PGSSLMODE = 'require'
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# # Add CORS middleware
|
| 21 |
+
|
| 22 |
+
# Db_store_router= APIRouter(prefix="/Db_store_router", tags=["Db_store_router"])
|
| 23 |
+
|
| 24 |
+
# # Pydantic models
|
| 25 |
+
# class UserMetadata(BaseModel):
|
| 26 |
+
# location: Optional[str] = None
|
| 27 |
+
# language: Optional[str] = None
|
| 28 |
+
|
| 29 |
+
# class UserQuery(BaseModel):
|
| 30 |
+
# query_id: str
|
| 31 |
+
# text: str
|
| 32 |
+
# user_metadata: Optional[UserMetadata] = None
|
| 33 |
+
|
| 34 |
+
# class ArtifactMetadata(BaseModel):
|
| 35 |
+
# created_by: Optional[str] = None
|
| 36 |
+
# associated_query_id: Optional[str] = None
|
| 37 |
+
# associated_session_id: Optional[str] = None
|
| 38 |
+
# artifact_timestamp: Optional[str] = None
|
| 39 |
+
|
| 40 |
+
# class Artifact(BaseModel):
|
| 41 |
+
# artifact_id: str
|
| 42 |
+
# file_id: str
|
| 43 |
+
# file_name: str
|
| 44 |
+
# file_type: str
|
| 45 |
+
# file_size: int
|
| 46 |
+
# file_url: str
|
| 47 |
+
# upload_timestamp: str
|
| 48 |
+
# metadata: Optional[ArtifactMetadata] = None
|
| 49 |
+
|
| 50 |
+
# class Response(BaseModel):
|
| 51 |
+
# text: str
|
| 52 |
+
# status: str
|
| 53 |
+
# response_time: str
|
| 54 |
+
# duration: str
|
| 55 |
+
# artifacts: List[Artifact]
|
| 56 |
+
|
| 57 |
+
# class conversation_data(BaseModel):
|
| 58 |
+
# user_query: UserQuery
|
| 59 |
+
# response: Response
|
| 60 |
+
# metadata: Optional[Dict[str, Any]] = None
|
| 61 |
+
|
| 62 |
+
# class ConversationInput(BaseModel):
|
| 63 |
+
# convo_id: str
|
| 64 |
+
# data: conversation_data
|
| 65 |
+
# is_saved: Optional[bool] = False
|
| 66 |
+
|
| 67 |
+
# class UpdateSavedStatus(BaseModel):
|
| 68 |
+
# is_saved: bool
|
| 69 |
+
|
| 70 |
+
# # Database connection
|
| 71 |
+
# def connect_to_db():
|
| 72 |
+
# try:
|
| 73 |
+
# conn = psycopg2.connect(
|
| 74 |
+
# host=PGHOST,
|
| 75 |
+
# database=PGDATABASE,
|
| 76 |
+
# user=PGUSER,
|
| 77 |
+
# password=PGPASSWORD,
|
| 78 |
+
# sslmode=PGSSLMODE
|
| 79 |
+
# )
|
| 80 |
+
# return conn
|
| 81 |
+
# except Exception as e:
|
| 82 |
+
# raise HTTPException(status_code=500, detail=f"Database connection error: {str(e)}")
|
| 83 |
+
|
| 84 |
+
# # Database initialization
|
| 85 |
+
# def create_table():
|
| 86 |
+
# conn = connect_to_db()
|
| 87 |
+
# cursor = conn.cursor()
|
| 88 |
+
|
| 89 |
+
# create_table_query = """
|
| 90 |
+
# CREATE TABLE IF NOT EXISTS "conversation_data" (
|
| 91 |
+
# id SERIAL PRIMARY KEY,
|
| 92 |
+
# convo_id VARCHAR(255) NOT NULL,
|
| 93 |
+
# user_query JSONB NOT NULL,
|
| 94 |
+
# response JSONB NOT NULL,
|
| 95 |
+
# file_metadata JSONB,
|
| 96 |
+
# is_saved BOOLEAN DEFAULT FALSE,
|
| 97 |
+
# created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 98 |
+
# updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 99 |
+
# )
|
| 100 |
+
# """
|
| 101 |
+
|
| 102 |
+
# create_index_query = """
|
| 103 |
+
# CREATE INDEX IF NOT EXISTS idx_convo_id_is_saved
|
| 104 |
+
# ON "conversation_data" (convo_id, is_saved);
|
| 105 |
+
# """
|
| 106 |
+
|
| 107 |
+
# try:
|
| 108 |
+
# cursor.execute(create_table_query)
|
| 109 |
+
# cursor.execute(create_index_query)
|
| 110 |
+
# conn.commit()
|
| 111 |
+
# except Exception as e:
|
| 112 |
+
# conn.rollback()
|
| 113 |
+
# raise HTTPException(status_code=500, detail=f"Table creation error: {str(e)}")
|
| 114 |
+
# finally:
|
| 115 |
+
# cursor.close()
|
| 116 |
+
# conn.close()
|
| 117 |
+
|
| 118 |
+
# # Initialize table on startup
|
| 119 |
+
# @Db_store_router.on_event("startup")
|
| 120 |
+
# async def startup_event():
|
| 121 |
+
# create_table()
|
| 122 |
+
|
| 123 |
+
# # API Endpoints
|
| 124 |
+
|
| 125 |
+
# @Db_store_router.get("/")
|
| 126 |
+
# async def root():
|
| 127 |
+
# return {
|
| 128 |
+
# "message": "Conversation Data API",
|
| 129 |
+
# "version": "1.0.0",
|
| 130 |
+
# "endpoints": {
|
| 131 |
+
# "POST /conversations": "Create a new conversation",
|
| 132 |
+
# "GET /conversations/{convo_id}": "Get conversation by ID",
|
| 133 |
+
# "GET /conversations": "Get conversations with filters",
|
| 134 |
+
# "PUT /conversations/{convo_id}/saved": "Update saved status",
|
| 135 |
+
# "DELETE /conversations/{convo_id}": "Delete conversation",
|
| 136 |
+
# "GET /health": "Health check"
|
| 137 |
+
# }
|
| 138 |
+
# }
|
| 139 |
+
|
| 140 |
+
# @Db_store_router.get("/health")
|
| 141 |
+
# async def health_check():
|
| 142 |
+
# try:
|
| 143 |
+
# conn = connect_to_db()
|
| 144 |
+
# conn.close()
|
| 145 |
+
# return {"status": "healthy", "database": "connected"}
|
| 146 |
+
# except Exception as e:
|
| 147 |
+
# raise HTTPException(status_code=503, detail=f"Database unhealthy: {str(e)}")
|
| 148 |
+
|
| 149 |
+
# @Db_store_router.post("/conversations", status_code=201)
|
| 150 |
+
# async def create_conversation(conversation: ConversationInput):
|
| 151 |
+
# """Create a new conversation entry"""
|
| 152 |
+
# conn = connect_to_db()
|
| 153 |
+
# cursor = conn.cursor()
|
| 154 |
+
|
| 155 |
+
# insert_query = sql.SQL("""
|
| 156 |
+
# INSERT INTO "conversation_data" (convo_id, user_query, response, file_metadata, is_saved)
|
| 157 |
+
# VALUES (%s, %s, %s, %s, %s)
|
| 158 |
+
# RETURNING id, created_at
|
| 159 |
+
# """)
|
| 160 |
+
|
| 161 |
+
# # Convert data to JSON
|
| 162 |
+
# user_query = json.dumps(conversation.data.user_query.dict())
|
| 163 |
+
# response = json.dumps(conversation.data.response.dict())
|
| 164 |
+
# file_metadata = json.dumps([artifact.dict() for artifact in conversation.data.response.artifacts])
|
| 165 |
+
|
| 166 |
+
# try:
|
| 167 |
+
# cursor.execute(insert_query, (
|
| 168 |
+
# conversation.convo_id,
|
| 169 |
+
# user_query,
|
| 170 |
+
# response,
|
| 171 |
+
# file_metadata,
|
| 172 |
+
# conversation.is_saved
|
| 173 |
+
# ))
|
| 174 |
+
# result = cursor.fetchone()
|
| 175 |
+
# conn.commit()
|
| 176 |
+
|
| 177 |
+
# return {
|
| 178 |
+
# "message": "Conversation created successfully",
|
| 179 |
+
# "convo_id": conversation.convo_id,
|
| 180 |
+
# "id": result[0],
|
| 181 |
+
# "created_at": str(result[1]),
|
| 182 |
+
# "is_saved": conversation.is_saved
|
| 183 |
+
# }
|
| 184 |
+
# except Exception as e:
|
| 185 |
+
# conn.rollback()
|
| 186 |
+
# raise HTTPException(status_code=500, detail=f"Error creating conversation: {str(e)}")
|
| 187 |
+
# finally:
|
| 188 |
+
# cursor.close()
|
| 189 |
+
# conn.close()
|
| 190 |
+
|
| 191 |
+
# @Db_store_router.get("/conversations/{convo_id}")
|
| 192 |
+
# async def get_conversation_by_id(
|
| 193 |
+
# convo_id: str,
|
| 194 |
+
# is_saved: Optional[bool] = Query(None, description="Filter by saved status")
|
| 195 |
+
# ):
|
| 196 |
+
# """Get conversation(s) by convo_id, optionally filtered by saved status"""
|
| 197 |
+
# conn = connect_to_db()
|
| 198 |
+
# cursor = conn.cursor()
|
| 199 |
+
|
| 200 |
+
# if is_saved is not None:
|
| 201 |
+
# select_query = """
|
| 202 |
+
# SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 203 |
+
# FROM "conversation_data"
|
| 204 |
+
# WHERE convo_id = %s AND is_saved = %s
|
| 205 |
+
# ORDER BY created_at DESC
|
| 206 |
+
# """
|
| 207 |
+
# cursor.execute(select_query, (convo_id, is_saved))
|
| 208 |
+
# else:
|
| 209 |
+
# select_query = """
|
| 210 |
+
# SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 211 |
+
# FROM "conversation_data"
|
| 212 |
+
# WHERE convo_id = %s
|
| 213 |
+
# ORDER BY created_at DESC
|
| 214 |
+
# """
|
| 215 |
+
# cursor.execute(select_query, (convo_id,))
|
| 216 |
+
|
| 217 |
+
# try:
|
| 218 |
+
# results = cursor.fetchall()
|
| 219 |
+
|
| 220 |
+
# if not results:
|
| 221 |
+
# raise HTTPException(status_code=404, detail=f"No conversations found for convo_id: {convo_id}")
|
| 222 |
+
|
| 223 |
+
# conversations = []
|
| 224 |
+
# for result in results:
|
| 225 |
+
# stored_convo_id, user_query_json, response_json, file_metadata_json, saved, created_at, updated_at = result
|
| 226 |
+
|
| 227 |
+
# conversations.append({
|
| 228 |
+
# "convo_id": stored_convo_id,
|
| 229 |
+
# "user_query": user_query_json,
|
| 230 |
+
# "response": response_json,
|
| 231 |
+
# "file_metadata": file_metadata_json,
|
| 232 |
+
# "is_saved": saved,
|
| 233 |
+
# "created_at": str(created_at),
|
| 234 |
+
# "updated_at": str(updated_at)
|
| 235 |
+
# })
|
| 236 |
+
|
| 237 |
+
# return {
|
| 238 |
+
# "count": len(conversations),
|
| 239 |
+
# "conversations": conversations
|
| 240 |
+
# }
|
| 241 |
+
# except HTTPException:
|
| 242 |
+
# raise
|
| 243 |
+
# except Exception as e:
|
| 244 |
+
# raise HTTPException(status_code=500, detail=f"Error retrieving conversation: {str(e)}")
|
| 245 |
+
# finally:
|
| 246 |
+
# cursor.close()
|
| 247 |
+
# conn.close()
|
| 248 |
+
|
| 249 |
+
# @Db_store_router.get("/conversations")
|
| 250 |
+
# async def get_conversations(
|
| 251 |
+
# is_saved: Optional[bool] = Query(None, description="Filter by saved status"),
|
| 252 |
+
# limit: int = Query(100, ge=1, le=1000, description="Maximum number of results"),
|
| 253 |
+
# offset: int = Query(0, ge=0, description="Number of results to skip")
|
| 254 |
+
# ):
|
| 255 |
+
# """Get all conversations with optional filters"""
|
| 256 |
+
# conn = connect_to_db()
|
| 257 |
+
# cursor = conn.cursor()
|
| 258 |
+
|
| 259 |
+
# if is_saved is not None:
|
| 260 |
+
# select_query = """
|
| 261 |
+
# SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 262 |
+
# FROM "conversation_data"
|
| 263 |
+
# WHERE is_saved = %s
|
| 264 |
+
# ORDER BY updated_at DESC
|
| 265 |
+
# LIMIT %s OFFSET %s
|
| 266 |
+
# """
|
| 267 |
+
# cursor.execute(select_query, (is_saved, limit, offset))
|
| 268 |
+
# else:
|
| 269 |
+
# select_query = """
|
| 270 |
+
# SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 271 |
+
# FROM "conversation_data"
|
| 272 |
+
# ORDER BY updated_at DESC
|
| 273 |
+
# LIMIT %s OFFSET %s
|
| 274 |
+
# """
|
| 275 |
+
# cursor.execute(select_query, (limit, offset))
|
| 276 |
+
|
| 277 |
+
# try:
|
| 278 |
+
# results = cursor.fetchall()
|
| 279 |
+
|
| 280 |
+
# if not results:
|
| 281 |
+
# return {"count": 0, "conversations": []}
|
| 282 |
+
|
| 283 |
+
# conversations = []
|
| 284 |
+
# for result in results:
|
| 285 |
+
# stored_convo_id, user_query_json, response_json, file_metadata_json, saved, created_at, updated_at = result
|
| 286 |
+
|
| 287 |
+
# conversations.append({
|
| 288 |
+
# "convo_id": stored_convo_id,
|
| 289 |
+
# "user_query": user_query_json,
|
| 290 |
+
# "response": response_json,
|
| 291 |
+
# "file_metadata": file_metadata_json,
|
| 292 |
+
# "is_saved": saved,
|
| 293 |
+
# "created_at": str(created_at),
|
| 294 |
+
# "updated_at": str(updated_at)
|
| 295 |
+
# })
|
| 296 |
+
|
| 297 |
+
# return {
|
| 298 |
+
# "count": len(conversations),
|
| 299 |
+
# "limit": limit,
|
| 300 |
+
# "offset": offset,
|
| 301 |
+
# "conversations": conversations
|
| 302 |
+
# }
|
| 303 |
+
# except Exception as e:
|
| 304 |
+
# raise HTTPException(status_code=500, detail=f"Error retrieving conversations: {str(e)}")
|
| 305 |
+
# finally:
|
| 306 |
+
# cursor.close()
|
| 307 |
+
# conn.close()
|
| 308 |
+
|
| 309 |
+
# @Db_store_router.put("/conversations/{convo_id}/saved")
|
| 310 |
+
# async def update_saved_status(convo_id: str, status: UpdateSavedStatus):
|
| 311 |
+
# """Update the is_saved status for a conversation"""
|
| 312 |
+
# conn = connect_to_db()
|
| 313 |
+
# cursor = conn.cursor()
|
| 314 |
+
|
| 315 |
+
# update_query = """
|
| 316 |
+
# UPDATE "conversation_data"
|
| 317 |
+
# SET is_saved = %s, updated_at = CURRENT_TIMESTAMP
|
| 318 |
+
# WHERE convo_id = %s
|
| 319 |
+
# """
|
| 320 |
+
|
| 321 |
+
# try:
|
| 322 |
+
# cursor.execute(update_query, (status.is_saved, convo_id))
|
| 323 |
+
# conn.commit()
|
| 324 |
+
|
| 325 |
+
# if cursor.rowcount == 0:
|
| 326 |
+
# raise HTTPException(status_code=404, detail=f"No conversation found with convo_id: {convo_id}")
|
| 327 |
+
|
| 328 |
+
# return {
|
| 329 |
+
# "message": f"Updated {cursor.rowcount} record(s)",
|
| 330 |
+
# "convo_id": convo_id,
|
| 331 |
+
# "is_saved": status.is_saved
|
| 332 |
+
# }
|
| 333 |
+
# except HTTPException:
|
| 334 |
+
# raise
|
| 335 |
+
# except Exception as e:
|
| 336 |
+
# conn.rollback()
|
| 337 |
+
# raise HTTPException(status_code=500, detail=f"Error updating saved status: {str(e)}")
|
| 338 |
+
# finally:
|
| 339 |
+
# cursor.close()
|
| 340 |
+
# conn.close()
|
| 341 |
+
|
| 342 |
+
# @Db_store_router.delete("/conversations/{convo_id}")
|
| 343 |
+
# async def delete_conversation(
|
| 344 |
+
# convo_id: str,
|
| 345 |
+
# delete_all: bool = Query(False, description="Delete all records with this convo_id")
|
| 346 |
+
# ):
|
| 347 |
+
# """Delete conversation(s) by convo_id"""
|
| 348 |
+
# conn = connect_to_db()
|
| 349 |
+
# cursor = conn.cursor()
|
| 350 |
+
|
| 351 |
+
# if delete_all:
|
| 352 |
+
# delete_query = 'DELETE FROM "conversation_data" WHERE convo_id = %s'
|
| 353 |
+
# cursor.execute(delete_query, (convo_id,))
|
| 354 |
+
# else:
|
| 355 |
+
# # Delete only the most recent one
|
| 356 |
+
# delete_query = """
|
| 357 |
+
# DELETE FROM "conversation_data"
|
| 358 |
+
# WHERE id = (
|
| 359 |
+
# SELECT id FROM "conversation_data"
|
| 360 |
+
# WHERE convo_id = %s
|
| 361 |
+
# ORDER BY created_at DESC
|
| 362 |
+
# LIMIT 1
|
| 363 |
+
# )
|
| 364 |
+
# """
|
| 365 |
+
# cursor.execute(delete_query, (convo_id,))
|
| 366 |
+
|
| 367 |
+
# try:
|
| 368 |
+
# conn.commit()
|
| 369 |
+
|
| 370 |
+
# if cursor.rowcount == 0:
|
| 371 |
+
# raise HTTPException(status_code=404, detail=f"No conversation found with convo_id: {convo_id}")
|
| 372 |
+
|
| 373 |
+
# return {
|
| 374 |
+
# "message": f"Deleted {cursor.rowcount} record(s)",
|
| 375 |
+
# "convo_id": convo_id
|
| 376 |
+
# }
|
| 377 |
+
# except HTTPException:
|
| 378 |
+
# raise
|
| 379 |
+
# except Exception as e:
|
| 380 |
+
# conn.rollback()
|
| 381 |
+
# raise HTTPException(status_code=500, detail=f"Error deleting conversation: {str(e)}")
|
| 382 |
+
# finally:
|
| 383 |
+
# cursor.close()
|
| 384 |
+
# conn.close()
|
| 385 |
+
|
| 386 |
+
# @Db_store_router.get("/stats")
|
| 387 |
+
# async def get_statistics():
|
| 388 |
+
# """Get statistics about stored conversations"""
|
| 389 |
+
# conn = connect_to_db()
|
| 390 |
+
# cursor = conn.cursor()
|
| 391 |
+
|
| 392 |
+
# stats_query = """
|
| 393 |
+
# SELECT
|
| 394 |
+
# COUNT(*) as total_conversations,
|
| 395 |
+
# COUNT(DISTINCT convo_id) as unique_convo_ids,
|
| 396 |
+
# SUM(CASE WHEN is_saved = TRUE THEN 1 ELSE 0 END) as saved_conversations,
|
| 397 |
+
# SUM(CASE WHEN is_saved = FALSE THEN 1 ELSE 0 END) as unsaved_conversations
|
| 398 |
+
# FROM "conversation_data"
|
| 399 |
+
# """
|
| 400 |
+
|
| 401 |
+
# try:
|
| 402 |
+
# cursor.execute(stats_query)
|
| 403 |
+
# result = cursor.fetchone()
|
| 404 |
+
|
| 405 |
+
# return {
|
| 406 |
+
# "total_conversations": result[0],
|
| 407 |
+
# "unique_convo_ids": result[1],
|
| 408 |
+
# "saved_conversations": result[2],
|
| 409 |
+
# "unsaved_conversations": result[3]
|
| 410 |
+
# }
|
| 411 |
+
# except Exception as e:
|
| 412 |
+
# raise HTTPException(status_code=500, detail=f"Error getting statistics: {str(e)}")
|
| 413 |
+
# finally:
|
| 414 |
+
# cursor.close()
|
| 415 |
+
# conn.close()
|
| 416 |
+
|
| 417 |
+
# # ------------------ FETCH DATA FUNCTION ------------------
|
| 418 |
+
# def get_all_convo_data_filtered(is_saved=None):
|
| 419 |
+
# conn = connect_to_db()
|
| 420 |
+
# if not conn:
|
| 421 |
+
# return None
|
| 422 |
+
|
| 423 |
+
# cursor = conn.cursor()
|
| 424 |
+
|
| 425 |
+
# if is_saved is None:
|
| 426 |
+
# select_query = """
|
| 427 |
+
# SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 428 |
+
# FROM "conversation_data"
|
| 429 |
+
# ORDER BY updated_at DESC
|
| 430 |
+
# """
|
| 431 |
+
# params = ()
|
| 432 |
+
# else:
|
| 433 |
+
# select_query = """
|
| 434 |
+
# SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 435 |
+
# FROM "conversation_data"
|
| 436 |
+
# WHERE is_saved = %s
|
| 437 |
+
# ORDER BY updated_at DESC
|
| 438 |
+
# """
|
| 439 |
+
# params = (is_saved,)
|
| 440 |
+
|
| 441 |
+
# try:
|
| 442 |
+
# cursor.execute(select_query, params)
|
| 443 |
+
# results = cursor.fetchall()
|
| 444 |
+
|
| 445 |
+
# if not results:
|
| 446 |
+
# return []
|
| 447 |
+
|
| 448 |
+
# conversations = []
|
| 449 |
+
# for result in results:
|
| 450 |
+
# stored_convo_id, user_query_json, response_json, file_metadata_json, is_saved_flag, created_at, updated_at = result
|
| 451 |
+
|
| 452 |
+
# try:
|
| 453 |
+
# user_query = user_query_json if isinstance(user_query_json, dict) else json.loads(user_query_json)
|
| 454 |
+
# except Exception:
|
| 455 |
+
# user_query = {}
|
| 456 |
+
|
| 457 |
+
# try:
|
| 458 |
+
# response = response_json if isinstance(response_json, dict) else json.loads(response_json)
|
| 459 |
+
# except Exception:
|
| 460 |
+
# response = {}
|
| 461 |
+
|
| 462 |
+
# try:
|
| 463 |
+
# file_metadata = file_metadata_json if isinstance(file_metadata_json, (dict, list)) else json.loads(file_metadata_json)
|
| 464 |
+
# except Exception:
|
| 465 |
+
# file_metadata = {}
|
| 466 |
+
|
| 467 |
+
# conversations.append({
|
| 468 |
+
# "convo_id": stored_convo_id,
|
| 469 |
+
# "user_query": user_query,
|
| 470 |
+
# "response": response,
|
| 471 |
+
# "file_metadata": file_metadata,
|
| 472 |
+
# "is_saved": is_saved_flag,
|
| 473 |
+
# "created_at": str(created_at),
|
| 474 |
+
# "updated_at": str(updated_at)
|
| 475 |
+
# })
|
| 476 |
+
|
| 477 |
+
# return conversations
|
| 478 |
+
|
| 479 |
+
# except Exception as e:
|
| 480 |
+
# print(f"❌ Error retrieving data: {e}")
|
| 481 |
+
# return None
|
| 482 |
+
|
| 483 |
+
# finally:
|
| 484 |
+
# cursor.close()
|
| 485 |
+
# conn.close()
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
# # ------------------ FASTAPI ENDPOINT ------------------
|
| 489 |
+
# @Db_store_router.get("/conversations")
|
| 490 |
+
# def fetch_conversations(is_saved: str = Query(None, description="true | false | leave empty for all")):
|
| 491 |
+
# """
|
| 492 |
+
# Endpoint: GET /conversations?is_saved=true|false
|
| 493 |
+
# Returns conversations filtered by is_saved flag or all if not provided.
|
| 494 |
+
# """
|
| 495 |
+
# try:
|
| 496 |
+
# # Convert query param to Python boolean or None
|
| 497 |
+
# if is_saved is None:
|
| 498 |
+
# is_saved_value = None
|
| 499 |
+
# elif is_saved.lower() == "true":
|
| 500 |
+
# is_saved_value = True
|
| 501 |
+
# elif is_saved.lower() == "false":
|
| 502 |
+
# is_saved_value = False
|
| 503 |
+
# else:
|
| 504 |
+
# return JSONResponse(
|
| 505 |
+
# status_code=400,
|
| 506 |
+
# content={"error": "Invalid value for is_saved. Use true, false, or omit it."}
|
| 507 |
+
# )
|
| 508 |
+
|
| 509 |
+
# data = get_all_convo_data_filtered(is_saved_value)
|
| 510 |
+
|
| 511 |
+
# if data is None:
|
| 512 |
+
# return JSONResponse(status_code=500, content={"error": "Database query failed."})
|
| 513 |
+
|
| 514 |
+
# if len(data) == 0:
|
| 515 |
+
# return JSONResponse(status_code=404, content={"message": "No conversations found."})
|
| 516 |
+
|
| 517 |
+
# return JSONResponse(
|
| 518 |
+
# status_code=200,
|
| 519 |
+
# content={
|
| 520 |
+
# "count": len(data),
|
| 521 |
+
# "filter": is_saved_value,
|
| 522 |
+
# "conversations": data
|
| 523 |
+
# }
|
| 524 |
+
# )
|
| 525 |
+
|
| 526 |
+
# except Exception as e:
|
| 527 |
+
# print(f"❌ API Error: {e}")
|
| 528 |
+
# return JSONResponse(status_code=500, content={"error": str(e)})
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
from fastapi import FastAPI, HTTPException, Query , Depends
|
| 532 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 533 |
+
from pydantic import BaseModel
|
| 534 |
+
from typing import Optional, List, Dict, Any
|
| 535 |
+
import psycopg2
|
| 536 |
+
from psycopg2 import sql
|
| 537 |
+
import json
|
| 538 |
+
from datetime import datetime
|
| 539 |
+
from fastapi import APIRouter
|
| 540 |
+
from fastapi.responses import JSONResponse
|
| 541 |
+
from backend.database import get_db
|
| 542 |
+
from sqlalchemy.orm import Session
|
| 543 |
+
from datetime import timedelta, datetime, timezone
|
| 544 |
+
from sqlalchemy import text
|
| 545 |
+
|
| 546 |
+
# PostgreSQL credentials
|
| 547 |
+
# PGHOST = 'ep-steep-dream-adqtvjel-pooler.c-2.us-east-1.aws.neon.tech'
|
| 548 |
+
# PGDATABASE = 'neondb'
|
| 549 |
+
# PGUSER = 'neondb_owner'
|
| 550 |
+
# PGPASSWORD = 'npg_Qq0B1uWRXavx'
|
| 551 |
+
# PGSSLMODE = 'require'
|
| 552 |
+
|
| 553 |
+
PGHOST ='igwapp-proddb.ingenspark.com'
|
| 554 |
+
PGUSER = 'postgres'
|
| 555 |
+
PGPASSWORD = '0gMaLjyKeO5PUAznprQ1'
|
| 556 |
+
PGDATABASE = 'mrapp-demodb'
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
# Add CORS middleware
|
| 560 |
+
Db_store_router = APIRouter(prefix="/Db_store_router", tags=["Db_store_router"])
|
| 561 |
+
|
| 562 |
+
# Pydantic models
|
| 563 |
+
class UserMetadata(BaseModel):
|
| 564 |
+
location: Optional[str] = None
|
| 565 |
+
language: Optional[str] = None
|
| 566 |
+
|
| 567 |
+
class UserQuery(BaseModel):
|
| 568 |
+
query_id: str
|
| 569 |
+
text: str
|
| 570 |
+
user_metadata: Optional[UserMetadata] = None
|
| 571 |
+
|
| 572 |
+
class ArtifactMetadata(BaseModel):
|
| 573 |
+
created_by: Optional[str] = None
|
| 574 |
+
associated_query_id: Optional[str] = None
|
| 575 |
+
associated_session_id: Optional[str] = None
|
| 576 |
+
artifact_timestamp: Optional[str] = None
|
| 577 |
+
|
| 578 |
+
class Artifact(BaseModel):
|
| 579 |
+
artifact_id: str
|
| 580 |
+
file_id: str
|
| 581 |
+
file_name: str
|
| 582 |
+
file_type: str
|
| 583 |
+
file_size: int
|
| 584 |
+
file_url: str
|
| 585 |
+
upload_timestamp: str
|
| 586 |
+
metadata: Optional[ArtifactMetadata] = None
|
| 587 |
+
|
| 588 |
+
class Response(BaseModel):
|
| 589 |
+
text: str
|
| 590 |
+
status: str
|
| 591 |
+
response_time: str
|
| 592 |
+
duration: str
|
| 593 |
+
artifacts: List[Artifact]
|
| 594 |
+
|
| 595 |
+
class conversation_data(BaseModel):
|
| 596 |
+
user_query: UserQuery
|
| 597 |
+
response: Response
|
| 598 |
+
metadata: Optional[Dict[str, Any]] = None
|
| 599 |
+
|
| 600 |
+
class ConversationInput(BaseModel):
|
| 601 |
+
user_id: str
|
| 602 |
+
convo_id: str
|
| 603 |
+
data: conversation_data
|
| 604 |
+
is_saved: Optional[bool] = False
|
| 605 |
+
|
| 606 |
+
class UpdateSavedStatus(BaseModel):
|
| 607 |
+
is_saved: bool
|
| 608 |
+
|
| 609 |
+
# Database connection
|
| 610 |
+
def connect_to_db():
|
| 611 |
+
try:
|
| 612 |
+
conn = psycopg2.connect(
|
| 613 |
+
host=PGHOST,
|
| 614 |
+
database=PGDATABASE,
|
| 615 |
+
user=PGUSER,
|
| 616 |
+
password=PGPASSWORD
|
| 617 |
+
# sslmode=PGSSLMODE
|
| 618 |
+
)
|
| 619 |
+
return conn
|
| 620 |
+
except Exception as e:
|
| 621 |
+
raise HTTPException(status_code=500, detail=f"Database connection error: {str(e)}")
|
| 622 |
+
|
| 623 |
+
# Database initialization
|
| 624 |
+
def create_table():
|
| 625 |
+
conn = connect_to_db()
|
| 626 |
+
cursor = conn.cursor()
|
| 627 |
+
|
| 628 |
+
create_table_query = """
|
| 629 |
+
CREATE TABLE IF NOT EXISTS "conversation_data" (
|
| 630 |
+
id SERIAL PRIMARY KEY,
|
| 631 |
+
user_id VARCHAR(255) NOT NULL,
|
| 632 |
+
convo_id VARCHAR(255) NOT NULL,
|
| 633 |
+
user_query JSONB NOT NULL,
|
| 634 |
+
response JSONB NOT NULL,
|
| 635 |
+
file_metadata JSONB,
|
| 636 |
+
is_saved BOOLEAN DEFAULT FALSE,
|
| 637 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 638 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 639 |
+
)
|
| 640 |
+
"""
|
| 641 |
+
|
| 642 |
+
create_index_query = """
|
| 643 |
+
CREATE INDEX IF NOT EXISTS idx_user_id_convo_id_is_saved
|
| 644 |
+
ON "conversation_data" (user_id, convo_id, is_saved);
|
| 645 |
+
"""
|
| 646 |
+
|
| 647 |
+
try:
|
| 648 |
+
# cursor.execute(create_table_query)
|
| 649 |
+
# cursor.execute(create_index_query)
|
| 650 |
+
conn.commit()
|
| 651 |
+
except Exception as e:
|
| 652 |
+
conn.rollback()
|
| 653 |
+
raise HTTPException(status_code=500, detail=f"Table creation error: {str(e)}")
|
| 654 |
+
finally:
|
| 655 |
+
cursor.close()
|
| 656 |
+
conn.close()
|
| 657 |
+
|
| 658 |
+
# Initialize table on startup
|
| 659 |
+
# @Db_store_router.on_event("startup")
|
| 660 |
+
# async def startup_event():
|
| 661 |
+
# create_table()
|
| 662 |
+
|
| 663 |
+
# API Endpoints
|
| 664 |
+
@Db_store_router.get("/")
|
| 665 |
+
async def root():
|
| 666 |
+
return {
|
| 667 |
+
"message": "Conversation Data API",
|
| 668 |
+
"version": "1.0.0",
|
| 669 |
+
"endpoints": {
|
| 670 |
+
"POST /conversations": "Create a new conversation",
|
| 671 |
+
"GET /conversations/{convo_id}": "Get conversation by ID",
|
| 672 |
+
"GET /conversations": "Get conversations with filters",
|
| 673 |
+
"PUT /conversations/{convo_id}/saved": "Update saved status",
|
| 674 |
+
"DELETE /conversations/{convo_id}": "Delete conversation",
|
| 675 |
+
"GET /health": "Health check",
|
| 676 |
+
"GET /stats": "Get conversation statistics"
|
| 677 |
+
}
|
| 678 |
+
}
|
| 679 |
+
|
| 680 |
+
@Db_store_router.get("/health")
|
| 681 |
+
async def health_check():
|
| 682 |
+
try:
|
| 683 |
+
conn = connect_to_db()
|
| 684 |
+
conn.close()
|
| 685 |
+
return {"status": "healthy", "database": "connected"}
|
| 686 |
+
except Exception as e:
|
| 687 |
+
raise HTTPException(status_code=503, detail=f"Database unhealthy: {str(e)}")
|
| 688 |
+
|
| 689 |
+
# @Db_store_router.post("/conversations", status_code=201)
|
| 690 |
+
# async def create_conversation(conversation: ConversationInput):
|
| 691 |
+
# """Create a new conversation entry"""
|
| 692 |
+
# conn = connect_to_db()
|
| 693 |
+
# cursor = conn.cursor()
|
| 694 |
+
# print("vastav conversation api hitted")
|
| 695 |
+
|
| 696 |
+
# insert_query = sql.SQL("""
|
| 697 |
+
# INSERT INTO "conversation_data" (user_id, convo_id, user_query, response, file_metadata, is_saved)
|
| 698 |
+
# VALUES (%s, %s, %s, %s, %s, %s)
|
| 699 |
+
# RETURNING id, created_at
|
| 700 |
+
# """)
|
| 701 |
+
|
| 702 |
+
# # Convert data to JSON
|
| 703 |
+
# user_query = json.dumps(conversation.data.user_query.dict())
|
| 704 |
+
# response = json.dumps(conversation.data.response.dict())
|
| 705 |
+
# file_metadata = json.dumps([artifact.dict() for artifact in conversation.data.response.artifacts])
|
| 706 |
+
# print("file_metadata", file_metadata)
|
| 707 |
+
|
| 708 |
+
# try:
|
| 709 |
+
# cursor.execute(insert_query, (
|
| 710 |
+
# conversation.user_id,
|
| 711 |
+
# conversation.convo_id,
|
| 712 |
+
# user_query,
|
| 713 |
+
# response,
|
| 714 |
+
# file_metadata,
|
| 715 |
+
# conversation.is_saved
|
| 716 |
+
# ))
|
| 717 |
+
# result = cursor.fetchone()
|
| 718 |
+
# conn.commit()
|
| 719 |
+
|
| 720 |
+
# return {
|
| 721 |
+
# "message": "Conversation created successfully",
|
| 722 |
+
# "user_id": conversation.user_id,
|
| 723 |
+
# "convo_id": conversation.convo_id,
|
| 724 |
+
# "id": result[0],
|
| 725 |
+
# "created_at": str(result[1]),
|
| 726 |
+
# "is_saved": conversation.is_saved
|
| 727 |
+
# }
|
| 728 |
+
# except Exception as e:
|
| 729 |
+
# conn.rollback()
|
| 730 |
+
# raise HTTPException(status_code=500, detail=f"Error creating conversation: {str(e)}")
|
| 731 |
+
# finally:
|
| 732 |
+
# cursor.close()
|
| 733 |
+
# conn.close()
|
| 734 |
+
|
| 735 |
+
@Db_store_router.post("/conversations", status_code=201)
|
| 736 |
+
async def create_conversation(payload: dict, db: Session = Depends(get_db)):
|
| 737 |
+
"""
|
| 738 |
+
Create a new conversation entry
|
| 739 |
+
"""
|
| 740 |
+
try:
|
| 741 |
+
user_id = payload.get("user_id")
|
| 742 |
+
convo_id = payload.get("convo_id")
|
| 743 |
+
data = payload.get("data", {})
|
| 744 |
+
is_saved = payload.get("is_saved", False)
|
| 745 |
+
|
| 746 |
+
# Convert data fields to JSON
|
| 747 |
+
user_query = json.dumps(data.get("user_query", {}))
|
| 748 |
+
response = json.dumps(data.get("response", {}))
|
| 749 |
+
|
| 750 |
+
# Extract file_metadata (if present inside response.artifacts)
|
| 751 |
+
artifacts = data.get("response", {}).get("artifacts", [])
|
| 752 |
+
file_metadata = json.dumps(artifacts) if artifacts else json.dumps([])
|
| 753 |
+
|
| 754 |
+
created_at = datetime.now(timezone.utc)
|
| 755 |
+
updated_at = created_at
|
| 756 |
+
|
| 757 |
+
print("vastav conversation api hitted")
|
| 758 |
+
|
| 759 |
+
insert_query = text("""
|
| 760 |
+
INSERT INTO conversation_data (
|
| 761 |
+
user_id, convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 762 |
+
)
|
| 763 |
+
VALUES (
|
| 764 |
+
:user_id, :convo_id, :user_query, :response, :file_metadata, :is_saved, :created_at, :updated_at
|
| 765 |
+
)
|
| 766 |
+
RETURNING id, created_at
|
| 767 |
+
""")
|
| 768 |
+
print("insert_query", insert_query)
|
| 769 |
+
|
| 770 |
+
result = db.execute(insert_query, {
|
| 771 |
+
"user_id": user_id,
|
| 772 |
+
"convo_id": convo_id,
|
| 773 |
+
"user_query": user_query,
|
| 774 |
+
"response": response,
|
| 775 |
+
"file_metadata": file_metadata,
|
| 776 |
+
"is_saved": is_saved,
|
| 777 |
+
"created_at": created_at,
|
| 778 |
+
"updated_at": updated_at
|
| 779 |
+
})
|
| 780 |
+
|
| 781 |
+
db.commit()
|
| 782 |
+
inserted = result.fetchone()
|
| 783 |
+
|
| 784 |
+
return {
|
| 785 |
+
"message": "Conversation created successfully",
|
| 786 |
+
"id": str(inserted[0]),
|
| 787 |
+
"convo_id": convo_id,
|
| 788 |
+
"created_at": str(inserted[1]),
|
| 789 |
+
"is_saved": is_saved
|
| 790 |
+
}
|
| 791 |
+
|
| 792 |
+
except Exception as e:
|
| 793 |
+
db.rollback()
|
| 794 |
+
raise HTTPException(status_code=500, detail=f"Error creating conversation: {str(e)}")
|
| 795 |
+
|
| 796 |
+
|
| 797 |
+
@Db_store_router.get("/conversations/{convo_id}")
|
| 798 |
+
async def get_conversation_by_id(
|
| 799 |
+
convo_id: str,
|
| 800 |
+
user_id: str = Query(..., description="User ID"),
|
| 801 |
+
is_saved: Optional[bool] = Query(None, description="Filter by saved status")
|
| 802 |
+
):
|
| 803 |
+
"""Get conversation(s) by convo_id and user_id, optionally filtered by saved status"""
|
| 804 |
+
conn = connect_to_db()
|
| 805 |
+
cursor = conn.cursor()
|
| 806 |
+
|
| 807 |
+
if is_saved is not None:
|
| 808 |
+
select_query = """
|
| 809 |
+
SELECT convo_id, user_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 810 |
+
FROM "conversation_data"
|
| 811 |
+
WHERE convo_id = %s AND user_id = %s AND is_saved = %s
|
| 812 |
+
ORDER BY created_at DESC
|
| 813 |
+
"""
|
| 814 |
+
cursor.execute(select_query, (convo_id, user_id, is_saved))
|
| 815 |
+
else:
|
| 816 |
+
select_query = """
|
| 817 |
+
SELECT convo_id, user_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 818 |
+
FROM "conversation_data"
|
| 819 |
+
WHERE convo_id = %s AND user_id = %s
|
| 820 |
+
ORDER BY created_at DESC
|
| 821 |
+
"""
|
| 822 |
+
cursor.execute(select_query, (convo_id, user_id))
|
| 823 |
+
|
| 824 |
+
try:
|
| 825 |
+
results = cursor.fetchall()
|
| 826 |
+
|
| 827 |
+
if not results:
|
| 828 |
+
raise HTTPException(status_code=404, detail=f"No conversations found for convo_id: {convo_id} and user_id: {user_id}")
|
| 829 |
+
|
| 830 |
+
conversations = []
|
| 831 |
+
for result in results:
|
| 832 |
+
stored_convo_id, stored_user_id, user_query_json, response_json, file_metadata_json, saved, created_at, updated_at = result
|
| 833 |
+
|
| 834 |
+
conversations.append({
|
| 835 |
+
"convo_id": stored_convo_id,
|
| 836 |
+
"user_id": stored_user_id,
|
| 837 |
+
"user_query": user_query_json,
|
| 838 |
+
"response": response_json,
|
| 839 |
+
"file_metadata": file_metadata_json,
|
| 840 |
+
"is_saved": saved,
|
| 841 |
+
"created_at": str(created_at),
|
| 842 |
+
"updated_at": str(updated_at)
|
| 843 |
+
})
|
| 844 |
+
|
| 845 |
+
return {
|
| 846 |
+
"count": len(conversations),
|
| 847 |
+
"conversations": conversations
|
| 848 |
+
}
|
| 849 |
+
except HTTPException:
|
| 850 |
+
raise
|
| 851 |
+
except Exception as e:
|
| 852 |
+
raise HTTPException(status_code=500, detail=f"Error retrieving conversation: {str(e)}")
|
| 853 |
+
finally:
|
| 854 |
+
cursor.close()
|
| 855 |
+
conn.close()
|
| 856 |
+
|
| 857 |
+
@Db_store_router.get("/conversations")
|
| 858 |
+
async def get_conversations(
|
| 859 |
+
user_id: str = Query(..., description="User ID"),
|
| 860 |
+
is_saved: Optional[bool] = Query(None, description="Filter by saved status"),
|
| 861 |
+
limit: int = Query(100, ge=1, le=1000, description="Maximum number of results"),
|
| 862 |
+
offset: int = Query(0, ge=0, description="Number of results to skip")
|
| 863 |
+
):
|
| 864 |
+
"""Get all conversations for a user with optional filters"""
|
| 865 |
+
conn = connect_to_db()
|
| 866 |
+
cursor = conn.cursor()
|
| 867 |
+
|
| 868 |
+
if is_saved is not None:
|
| 869 |
+
select_query = """
|
| 870 |
+
SELECT convo_id, user_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 871 |
+
FROM "conversation_data"
|
| 872 |
+
WHERE user_id = %s AND is_saved = %s
|
| 873 |
+
ORDER BY updated_at DESC
|
| 874 |
+
LIMIT %s OFFSET %s
|
| 875 |
+
"""
|
| 876 |
+
cursor.execute(select_query, (user_id, is_saved, limit, offset))
|
| 877 |
+
else:
|
| 878 |
+
select_query = """
|
| 879 |
+
SELECT convo_id, user_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 880 |
+
FROM "conversation_data"
|
| 881 |
+
WHERE user_id = %s
|
| 882 |
+
ORDER BY updated_at DESC
|
| 883 |
+
LIMIT %s OFFSET %s
|
| 884 |
+
"""
|
| 885 |
+
cursor.execute(select_query, (user_id, limit, offset))
|
| 886 |
+
|
| 887 |
+
try:
|
| 888 |
+
results = cursor.fetchall()
|
| 889 |
+
|
| 890 |
+
if not results:
|
| 891 |
+
return {"count": 0, "conversations": []}
|
| 892 |
+
|
| 893 |
+
conversations = []
|
| 894 |
+
for result in results:
|
| 895 |
+
stored_convo_id, stored_user_id, user_query_json, response_json, file_metadata_json, saved, created_at, updated_at = result
|
| 896 |
+
|
| 897 |
+
conversations.append({
|
| 898 |
+
"convo_id": stored_convo_id,
|
| 899 |
+
"user_id": stored_user_id,
|
| 900 |
+
"user_query": user_query_json,
|
| 901 |
+
"response": response_json,
|
| 902 |
+
"file_metadata": file_metadata_json,
|
| 903 |
+
"is_saved": saved,
|
| 904 |
+
"created_at": str(created_at),
|
| 905 |
+
"updated_at": str(updated_at)
|
| 906 |
+
})
|
| 907 |
+
|
| 908 |
+
return {
|
| 909 |
+
"count": len(conversations),
|
| 910 |
+
"limit": limit,
|
| 911 |
+
"offset": offset,
|
| 912 |
+
"conversations": conversations
|
| 913 |
+
}
|
| 914 |
+
except Exception as e:
|
| 915 |
+
raise HTTPException(status_code=500, detail=f"Error retrieving conversations: {str(e)}")
|
| 916 |
+
finally:
|
| 917 |
+
cursor.close()
|
| 918 |
+
conn.close()
|
| 919 |
+
|
| 920 |
+
@Db_store_router.get("/conversations", response_class=JSONResponse)
|
| 921 |
+
async def fetch_conversations(
|
| 922 |
+
user_id: str = Query(..., description="User ID"),
|
| 923 |
+
is_saved: Optional[str] = Query(None, description="true | false | leave empty for all")
|
| 924 |
+
):
|
| 925 |
+
"""
|
| 926 |
+
Endpoint: GET /conversations?user_id={user_id}&is_saved=true|false
|
| 927 |
+
Returns conversations filtered by user_id and is_saved flag or all if not provided.
|
| 928 |
+
"""
|
| 929 |
+
try:
|
| 930 |
+
# Convert query param to Python boolean or None
|
| 931 |
+
if is_saved is None:
|
| 932 |
+
is_saved_value = None
|
| 933 |
+
elif is_saved.lower() == "true":
|
| 934 |
+
is_saved_value = True
|
| 935 |
+
elif is_saved.lower() == "false":
|
| 936 |
+
is_saved_value = False
|
| 937 |
+
else:
|
| 938 |
+
return JSONResponse(
|
| 939 |
+
status_code=400,
|
| 940 |
+
content={"error": "Invalid value for is_saved. Use true, false, or omit it."}
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
data = get_all_convo_data_filtered(user_id, is_saved_value)
|
| 944 |
+
|
| 945 |
+
if data is None:
|
| 946 |
+
return JSONResponse(status_code=500, content={"error": "Database query failed."})
|
| 947 |
+
|
| 948 |
+
if len(data) == 0:
|
| 949 |
+
return JSONResponse(status_code=404, content={"message": f"No conversations found for user_id: {user_id}."})
|
| 950 |
+
|
| 951 |
+
return JSONResponse(
|
| 952 |
+
status_code=200,
|
| 953 |
+
content={
|
| 954 |
+
"count": len(data),
|
| 955 |
+
"filter": {"user_id": user_id, "is_saved": is_saved_value},
|
| 956 |
+
"conversations": data
|
| 957 |
+
}
|
| 958 |
+
)
|
| 959 |
+
except Exception as e:
|
| 960 |
+
print(f"❌ API Error: {e}")
|
| 961 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|
| 962 |
+
|
| 963 |
+
@Db_store_router.delete("/conversations/{convo_id}")
|
| 964 |
+
async def delete_conversation(
|
| 965 |
+
convo_id: str,
|
| 966 |
+
user_id: str = Query(..., description="User ID"),
|
| 967 |
+
delete_all: bool = Query(False, description="Delete all records with this convo_id for the user")
|
| 968 |
+
):
|
| 969 |
+
"""Delete conversation(s) by convo_id and user_id"""
|
| 970 |
+
conn = connect_to_db()
|
| 971 |
+
cursor = conn.cursor()
|
| 972 |
+
|
| 973 |
+
if delete_all:
|
| 974 |
+
delete_query = 'DELETE FROM "conversation_data" WHERE convo_id = %s AND user_id = %s'
|
| 975 |
+
cursor.execute(delete_query, (convo_id, user_id))
|
| 976 |
+
else:
|
| 977 |
+
# Delete only the most recent one
|
| 978 |
+
delete_query = """
|
| 979 |
+
DELETE FROM "conversation_data"
|
| 980 |
+
WHERE id = (
|
| 981 |
+
SELECT id FROM "conversation_data"
|
| 982 |
+
WHERE convo_id = %s AND user_id = %s
|
| 983 |
+
ORDER BY created_at DESC
|
| 984 |
+
LIMIT 1
|
| 985 |
+
)
|
| 986 |
+
"""
|
| 987 |
+
cursor.execute(delete_query, (convo_id, user_id))
|
| 988 |
+
|
| 989 |
+
try:
|
| 990 |
+
conn.commit()
|
| 991 |
+
|
| 992 |
+
if cursor.rowcount == 0:
|
| 993 |
+
raise HTTPException(status_code=404, detail=f"No conversation found with convo_id: {convo_id} and user_id: {user_id}")
|
| 994 |
+
|
| 995 |
+
return {
|
| 996 |
+
"message": f"Deleted {cursor.rowcount} record(s)",
|
| 997 |
+
"user_id": user_id,
|
| 998 |
+
"convo_id": convo_id
|
| 999 |
+
}
|
| 1000 |
+
except HTTPException:
|
| 1001 |
+
raise
|
| 1002 |
+
except Exception as e:
|
| 1003 |
+
conn.rollback()
|
| 1004 |
+
raise HTTPException(status_code=500, detail=f"Error deleting conversation: {str(e)}")
|
| 1005 |
+
finally:
|
| 1006 |
+
cursor.close()
|
| 1007 |
+
conn.close()
|
| 1008 |
+
|
| 1009 |
+
@Db_store_router.get("/stats")
|
| 1010 |
+
async def get_statistics(user_id: str = Query(..., description="User ID")):
|
| 1011 |
+
"""Get statistics about stored conversations for a user"""
|
| 1012 |
+
conn = connect_to_db()
|
| 1013 |
+
cursor = conn.cursor()
|
| 1014 |
+
|
| 1015 |
+
stats_query = """
|
| 1016 |
+
SELECT
|
| 1017 |
+
COUNT(*) as total_conversations,
|
| 1018 |
+
COUNT(DISTINCT convo_id) as unique_convo_ids,
|
| 1019 |
+
SUM(CASE WHEN is_saved = TRUE THEN 1 ELSE 0 END) as saved_conversations,
|
| 1020 |
+
SUM(CASE WHEN is_saved = FALSE THEN 1 ELSE 0 END) as unsaved_conversations
|
| 1021 |
+
FROM "conversation_data"
|
| 1022 |
+
WHERE user_id = %s
|
| 1023 |
+
"""
|
| 1024 |
+
|
| 1025 |
+
try:
|
| 1026 |
+
cursor.execute(stats_query, (user_id,))
|
| 1027 |
+
result = cursor.fetchone()
|
| 1028 |
+
|
| 1029 |
+
return {
|
| 1030 |
+
"user_id": user_id,
|
| 1031 |
+
"total_conversations": result[0],
|
| 1032 |
+
"unique_convo_ids": result[1],
|
| 1033 |
+
"saved_conversations": result[2],
|
| 1034 |
+
"unsaved_conversations": result[3]
|
| 1035 |
+
}
|
| 1036 |
+
except Exception as e:
|
| 1037 |
+
raise HTTPException(status_code=500, detail=f"Error getting statistics: {str(e)}")
|
| 1038 |
+
finally:
|
| 1039 |
+
cursor.close()
|
| 1040 |
+
conn.close()
|
| 1041 |
+
|
| 1042 |
+
# ------------------ FETCH DATA FUNCTION ------------------
|
| 1043 |
+
def get_all_convo_data_filtered(user_id: str, is_saved: Optional[bool] = None):
|
| 1044 |
+
conn = connect_to_db()
|
| 1045 |
+
if not conn:
|
| 1046 |
+
return None
|
| 1047 |
+
|
| 1048 |
+
cursor = conn.cursor()
|
| 1049 |
+
|
| 1050 |
+
if is_saved is None:
|
| 1051 |
+
select_query = """
|
| 1052 |
+
SELECT convo_id, user_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 1053 |
+
FROM "conversation_data"
|
| 1054 |
+
WHERE user_id = %s
|
| 1055 |
+
ORDER BY updated_at DESC
|
| 1056 |
+
"""
|
| 1057 |
+
params = (user_id,)
|
| 1058 |
+
else:
|
| 1059 |
+
select_query = """
|
| 1060 |
+
SELECT convo_id, user_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 1061 |
+
FROM "conversation_data"
|
| 1062 |
+
WHERE user_id = %s AND is_saved = %s
|
| 1063 |
+
ORDER BY updated_at DESC
|
| 1064 |
+
"""
|
| 1065 |
+
params = (user_id, is_saved)
|
| 1066 |
+
|
| 1067 |
+
try:
|
| 1068 |
+
cursor.execute(select_query, params)
|
| 1069 |
+
results = cursor.fetchall()
|
| 1070 |
+
|
| 1071 |
+
if not results:
|
| 1072 |
+
return []
|
| 1073 |
+
|
| 1074 |
+
conversations = []
|
| 1075 |
+
for result in results:
|
| 1076 |
+
stored_convo_id, stored_user_id, user_query_json, response_json, file_metadata_json, is_saved_flag, created_at, updated_at = result
|
| 1077 |
+
|
| 1078 |
+
try:
|
| 1079 |
+
user_query = user_query_json if isinstance(user_query_json, dict) else json.loads(user_query_json)
|
| 1080 |
+
except Exception:
|
| 1081 |
+
user_query = {}
|
| 1082 |
+
|
| 1083 |
+
try:
|
| 1084 |
+
response = response_json if isinstance(response_json, dict) else json.loads(response_json)
|
| 1085 |
+
except Exception:
|
| 1086 |
+
response = {}
|
| 1087 |
+
|
| 1088 |
+
try:
|
| 1089 |
+
file_metadata = file_metadata_json if isinstance(file_metadata_json, (dict, list)) else json.loads(file_metadata_json)
|
| 1090 |
+
except Exception:
|
| 1091 |
+
file_metadata = {}
|
| 1092 |
+
|
| 1093 |
+
conversations.append({
|
| 1094 |
+
"convo_id": stored_convo_id,
|
| 1095 |
+
"user_id": stored_user_id,
|
| 1096 |
+
"user_query": user_query,
|
| 1097 |
+
"response": response,
|
| 1098 |
+
"file_metadata": file_metadata,
|
| 1099 |
+
"is_saved": is_saved_flag,
|
| 1100 |
+
"created_at": str(created_at),
|
| 1101 |
+
"updated_at": str(updated_at)
|
| 1102 |
+
})
|
| 1103 |
+
|
| 1104 |
+
return conversations
|
| 1105 |
+
|
| 1106 |
+
except Exception as e:
|
| 1107 |
+
print(f"❌ Error retrieving data: {e}")
|
| 1108 |
+
return None
|
| 1109 |
+
|
| 1110 |
+
finally:
|
| 1111 |
+
cursor.close()
|
| 1112 |
+
conn.close()
|
| 1113 |
+
|
| 1114 |
+
# ------------------ FASTAPI ENDPOINT ------------------
|
| 1115 |
+
@Db_store_router.get("/conversations", response_class=JSONResponse)
|
| 1116 |
+
async def fetch_conversations(
|
| 1117 |
+
user_id: str = Query(..., description="User ID"),
|
| 1118 |
+
is_saved: Optional[str] = Query(None, description="true | false | leave empty for all")
|
| 1119 |
+
):
|
| 1120 |
+
"""
|
| 1121 |
+
Endpoint: GET /conversations?user_id={user_id}&is_saved=true|false
|
| 1122 |
+
Returns conversations filtered by user_id and is_saved flag or all if not provided.
|
| 1123 |
+
"""
|
| 1124 |
+
try:
|
| 1125 |
+
# Convert query param to Python boolean or None
|
| 1126 |
+
if is_saved is None:
|
| 1127 |
+
is_saved_value = None
|
| 1128 |
+
elif is_saved.lower() == "true":
|
| 1129 |
+
is_saved_value = True
|
| 1130 |
+
elif is_saved.lower() == "false":
|
| 1131 |
+
is_saved_value = False
|
| 1132 |
+
else:
|
| 1133 |
+
return JSONResponse(
|
| 1134 |
+
status_code=400,
|
| 1135 |
+
content={"error": "Invalid value for is_saved. Use true, false, or omit it."}
|
| 1136 |
+
)
|
| 1137 |
+
|
| 1138 |
+
data = get_all_convo_data_filtered(user_id, is_saved_value)
|
| 1139 |
+
|
| 1140 |
+
if data is None:
|
| 1141 |
+
return JSONResponse(status_code=500, content={"error": "Database query failed."})
|
| 1142 |
+
|
| 1143 |
+
if len(data) == 0:
|
| 1144 |
+
return JSONResponse(status_code=404, content={"message": f"No conversations found for user_id: {user_id}."})
|
| 1145 |
+
|
| 1146 |
+
return JSONResponse(
|
| 1147 |
+
status_code=200,
|
| 1148 |
+
content={
|
| 1149 |
+
"count": len(data),
|
| 1150 |
+
"filter": {"user_id": user_id, "is_saved": is_saved_value},
|
| 1151 |
+
"conversations": data
|
| 1152 |
+
}
|
| 1153 |
+
)
|
| 1154 |
+
|
| 1155 |
+
except Exception as e:
|
| 1156 |
+
print(f"❌ API Error: {e}")
|
| 1157 |
+
return JSONResponse(status_code=500, content={"error": str(e)})
|
DB_store_backup/rough.py
ADDED
|
@@ -0,0 +1,510 @@
| 1 |
+
import psycopg2
|
| 2 |
+
from psycopg2 import sql
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
# PostgreSQL credentials
|
| 6 |
+
PGHOST = 'ep-steep-dream-adqtvjel-pooler.c-2.us-east-1.aws.neon.tech'
|
| 7 |
+
PGDATABASE = 'neondb'
|
| 8 |
+
PGUSER = 'neondb_owner'
|
| 9 |
+
PGPASSWORD = 'npg_Qq0B1uWRXavx'
|
| 10 |
+
PGSSLMODE = 'require'
|
| 11 |
+
|
| 12 |
+
# Function to connect to the PostgreSQL database
|
| 13 |
+
def connect_to_db():
|
| 14 |
+
try:
|
| 15 |
+
conn = psycopg2.connect(
|
| 16 |
+
host=PGHOST,
|
| 17 |
+
database=PGDATABASE,
|
| 18 |
+
user=PGUSER,
|
| 19 |
+
password=PGPASSWORD,
|
| 20 |
+
sslmode=PGSSLMODE
|
| 21 |
+
)
|
| 22 |
+
return conn
|
| 23 |
+
except Exception as e:
|
| 24 |
+
print(f"Error connecting to PostgreSQL database: {e}")
|
| 25 |
+
return None
|
| 26 |
+
|
| 27 |
+
# Function to delete the table
|
| 28 |
+
def delete_table_by_name(table_name: str):
|
| 29 |
+
conn = connect_to_db()
|
| 30 |
+
if not conn:
|
| 31 |
+
print("Failed to connect to the database")
|
| 32 |
+
return
|
| 33 |
+
|
| 34 |
+
cursor = conn.cursor()
|
| 35 |
+
|
| 36 |
+
# Create the DROP TABLE query dynamically using the table_name argument
|
| 37 |
+
drop_table_query = sql.SQL("DROP TABLE IF EXISTS {}").format(sql.Identifier(table_name))
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
cursor.execute(drop_table_query)
|
| 41 |
+
conn.commit()
|
| 42 |
+
print(f"Table '{table_name}' deleted successfully.")
|
| 43 |
+
except Exception as e:
|
| 44 |
+
conn.rollback()
|
| 45 |
+
print(f"Error deleting table '{table_name}': {e}")
|
| 46 |
+
finally:
|
| 47 |
+
cursor.close()
|
| 48 |
+
conn.close()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Function to create the table if it doesn't exist
|
| 53 |
+
def create_table():
|
| 54 |
+
conn = connect_to_db()
|
| 55 |
+
if not conn:
|
| 56 |
+
return False
|
| 57 |
+
|
| 58 |
+
cursor = conn.cursor()
|
| 59 |
+
|
| 60 |
+
create_table_query = """
|
| 61 |
+
CREATE TABLE IF NOT EXISTS "stored_convoId_data" (
|
| 62 |
+
id SERIAL PRIMARY KEY,
|
| 63 |
+
convo_id VARCHAR(255) NOT NULL,
|
| 64 |
+
user_query JSONB NOT NULL,
|
| 65 |
+
response JSONB NOT NULL,
|
| 66 |
+
file_metadata JSONB,
|
| 67 |
+
is_saved BOOLEAN DEFAULT FALSE,
|
| 68 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 69 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 70 |
+
)
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
# Create index for faster queries
|
| 74 |
+
create_index_query = """
|
| 75 |
+
CREATE INDEX IF NOT EXISTS idx_convo_id_is_saved
|
| 76 |
+
ON "stored_convoId_data" (convo_id, is_saved);
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
try:
|
| 80 |
+
cursor.execute(create_table_query)
|
| 81 |
+
cursor.execute(create_index_query)
|
| 82 |
+
conn.commit()
|
| 83 |
+
print("Table 'stored_convoId_data' created successfully or already exists.")
|
| 84 |
+
return True
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"Error creating table: {e}")
|
| 87 |
+
conn.rollback()
|
| 88 |
+
return False
|
| 89 |
+
finally:
|
| 90 |
+
cursor.close()
|
| 91 |
+
conn.close()
|
| 92 |
+
|
| 93 |
+
# Function to insert the data into PostgreSQL table
|
| 94 |
+
def insert_convo_data(convo_id, data, is_saved=False):
|
| 95 |
+
conn = connect_to_db()
|
| 96 |
+
if not conn:
|
| 97 |
+
return
|
| 98 |
+
|
| 99 |
+
cursor = conn.cursor()
|
| 100 |
+
|
| 101 |
+
insert_query = sql.SQL("""
|
| 102 |
+
INSERT INTO "stored_convoId_data" (convo_id, user_query, response, file_metadata, is_saved)
|
| 103 |
+
VALUES (%s, %s, %s, %s, %s)
|
| 104 |
+
""")
|
| 105 |
+
|
| 106 |
+
# Prepare data to insert (use json.dumps for JSONB columns)
|
| 107 |
+
user_query = json.dumps(data["user_query"])
|
| 108 |
+
response = json.dumps(data["response"])
|
| 109 |
+
file_metadata = json.dumps(data["response"]["artifacts"])
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
cursor.execute(insert_query, (convo_id, user_query, response, file_metadata, is_saved))
|
| 113 |
+
conn.commit()
|
| 114 |
+
print(f"Data for convo_id {convo_id} inserted successfully with is_saved={is_saved}.")
|
| 115 |
+
except Exception as e:
|
| 116 |
+
print(f"Error inserting data: {e}")
|
| 117 |
+
conn.rollback()
|
| 118 |
+
finally:
|
| 119 |
+
cursor.close()
|
| 120 |
+
conn.close()
|
| 121 |
+
|
| 122 |
+
# Function to update the is_saved flag
|
| 123 |
+
def update_saved_status(convo_id, is_saved=True):
|
| 124 |
+
conn = connect_to_db()
|
| 125 |
+
if not conn:
|
| 126 |
+
return False
|
| 127 |
+
|
| 128 |
+
cursor = conn.cursor()
|
| 129 |
+
|
| 130 |
+
update_query = """
|
| 131 |
+
UPDATE "stored_convoId_data"
|
| 132 |
+
SET is_saved = %s, updated_at = CURRENT_TIMESTAMP
|
| 133 |
+
WHERE convo_id = %s
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
cursor.execute(update_query, (is_saved, convo_id))
|
| 138 |
+
conn.commit()
|
| 139 |
+
if cursor.rowcount > 0:
|
| 140 |
+
print(f"Updated {cursor.rowcount} record(s) for convo_id {convo_id} to is_saved={is_saved}.")
|
| 141 |
+
return True
|
| 142 |
+
else:
|
| 143 |
+
print(f"No records found for convo_id {convo_id}.")
|
| 144 |
+
return False
|
| 145 |
+
except Exception as e:
|
| 146 |
+
print(f"Error updating saved status: {e}")
|
| 147 |
+
conn.rollback()
|
| 148 |
+
return False
|
| 149 |
+
finally:
|
| 150 |
+
cursor.close()
|
| 151 |
+
conn.close()
|
| 152 |
+
|
| 153 |
+
# Function to retrieve data by convo_id only
|
| 154 |
+
def get_convo_data(convo_id):
|
| 155 |
+
conn = connect_to_db()
|
| 156 |
+
if not conn:
|
| 157 |
+
return None
|
| 158 |
+
|
| 159 |
+
cursor = conn.cursor()
|
| 160 |
+
|
| 161 |
+
select_query = """
|
| 162 |
+
SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 163 |
+
FROM "stored_convoId_data"
|
| 164 |
+
WHERE convo_id = %s
|
| 165 |
+
ORDER BY created_at DESC
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
cursor.execute(select_query, (convo_id,))
|
| 170 |
+
results = cursor.fetchall()
|
| 171 |
+
|
| 172 |
+
if results:
|
| 173 |
+
conversations = []
|
| 174 |
+
for result in results:
|
| 175 |
+
stored_convo_id, user_query_json, response_json, file_metadata_json, is_saved, created_at, updated_at = result
|
| 176 |
+
|
| 177 |
+
# JSONB columns are already returned as dict/list by psycopg2
|
| 178 |
+
user_query = user_query_json if isinstance(user_query_json, dict) else json.loads(user_query_json)
|
| 179 |
+
response = response_json if isinstance(response_json, dict) else json.loads(response_json)
|
| 180 |
+
file_metadata = file_metadata_json if isinstance(file_metadata_json, (dict, list)) else json.loads(file_metadata_json)
|
| 181 |
+
|
| 182 |
+
conversations.append({
|
| 183 |
+
"convo_id": stored_convo_id,
|
| 184 |
+
"user_query": user_query,
|
| 185 |
+
"response": response,
|
| 186 |
+
"file_metadata": file_metadata,
|
| 187 |
+
"is_saved": is_saved,
|
| 188 |
+
"created_at": str(created_at),
|
| 189 |
+
"updated_at": str(updated_at)
|
| 190 |
+
})
|
| 191 |
+
|
| 192 |
+
return conversations
|
| 193 |
+
else:
|
| 194 |
+
print(f"No data found for convo_id: {convo_id}")
|
| 195 |
+
return None
|
| 196 |
+
|
| 197 |
+
except Exception as e:
|
| 198 |
+
print(f"Error retrieving data: {e}")
|
| 199 |
+
return None
|
| 200 |
+
finally:
|
| 201 |
+
cursor.close()
|
| 202 |
+
conn.close()
|
| 203 |
+
|
| 204 |
+
# Function to retrieve data by convo_id and is_saved status
|
| 205 |
+
def get_convo_data_by_saved_status(convo_id, is_saved=True):
|
| 206 |
+
conn = connect_to_db()
|
| 207 |
+
if not conn:
|
| 208 |
+
return None
|
| 209 |
+
|
| 210 |
+
cursor = conn.cursor()
|
| 211 |
+
|
| 212 |
+
select_query = """
|
| 213 |
+
SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 214 |
+
FROM "stored_convoId_data"
|
| 215 |
+
WHERE convo_id = %s AND is_saved = %s
|
| 216 |
+
ORDER BY created_at DESC
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
try:
|
| 220 |
+
cursor.execute(select_query, (convo_id, is_saved))
|
| 221 |
+
results = cursor.fetchall()
|
| 222 |
+
|
| 223 |
+
if results:
|
| 224 |
+
conversations = []
|
| 225 |
+
for result in results:
|
| 226 |
+
stored_convo_id, user_query_json, response_json, file_metadata_json, is_saved_flag, created_at, updated_at = result
|
| 227 |
+
|
| 228 |
+
user_query = user_query_json if isinstance(user_query_json, dict) else json.loads(user_query_json)
|
| 229 |
+
response = response_json if isinstance(response_json, dict) else json.loads(response_json)
|
| 230 |
+
file_metadata = file_metadata_json if isinstance(file_metadata_json, (dict, list)) else json.loads(file_metadata_json)
|
| 231 |
+
|
| 232 |
+
conversations.append({
|
| 233 |
+
"convo_id": stored_convo_id,
|
| 234 |
+
"user_query": user_query,
|
| 235 |
+
"response": response,
|
| 236 |
+
"file_metadata": file_metadata,
|
| 237 |
+
"is_saved": is_saved_flag,
|
| 238 |
+
"created_at": str(created_at),
|
| 239 |
+
"updated_at": str(updated_at)
|
| 240 |
+
})
|
| 241 |
+
|
| 242 |
+
return conversations
|
| 243 |
+
else:
|
| 244 |
+
print(f"No data found for convo_id: {convo_id} with is_saved={is_saved}")
|
| 245 |
+
return None
|
| 246 |
+
|
| 247 |
+
except Exception as e:
|
| 248 |
+
print(f"Error retrieving data: {e}")
|
| 249 |
+
return None
|
| 250 |
+
finally:
|
| 251 |
+
cursor.close()
|
| 252 |
+
conn.close()
|
| 253 |
+
|
| 254 |
+
# Function to get all saved conversations (across all convo_ids)
|
| 255 |
+
def get_all_saved_conversations():
|
| 256 |
+
conn = connect_to_db()
|
| 257 |
+
if not conn:
|
| 258 |
+
return None
|
| 259 |
+
|
| 260 |
+
cursor = conn.cursor()
|
| 261 |
+
|
| 262 |
+
select_query = """
|
| 263 |
+
SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 264 |
+
FROM "stored_convoId_data"
|
| 265 |
+
WHERE is_saved = TRUE
|
| 266 |
+
ORDER BY updated_at DESC
|
| 267 |
+
"""
|
| 268 |
+
|
| 269 |
+
try:
|
| 270 |
+
cursor.execute(select_query)
|
| 271 |
+
results = cursor.fetchall()
|
| 272 |
+
|
| 273 |
+
if results:
|
| 274 |
+
conversations = []
|
| 275 |
+
for result in results:
|
| 276 |
+
stored_convo_id, user_query_json, response_json, file_metadata_json, is_saved, created_at, updated_at = result
|
| 277 |
+
|
| 278 |
+
user_query = user_query_json if isinstance(user_query_json, dict) else json.loads(user_query_json)
|
| 279 |
+
response = response_json if isinstance(response_json, dict) else json.loads(response_json)
|
| 280 |
+
file_metadata = file_metadata_json if isinstance(file_metadata_json, (dict, list)) else json.loads(file_metadata_json)
|
| 281 |
+
|
| 282 |
+
conversations.append({
|
| 283 |
+
"convo_id": stored_convo_id,
|
| 284 |
+
"user_query": user_query,
|
| 285 |
+
"response": response,
|
| 286 |
+
"file_metadata": file_metadata,
|
| 287 |
+
"is_saved": is_saved,
|
| 288 |
+
"created_at": str(created_at),
|
| 289 |
+
"updated_at": str(updated_at)
|
| 290 |
+
})
|
| 291 |
+
|
| 292 |
+
return conversations
|
| 293 |
+
else:
|
| 294 |
+
print("No saved conversations found.")
|
| 295 |
+
return None
|
| 296 |
+
|
| 297 |
+
except Exception as e:
|
| 298 |
+
print(f"Error retrieving saved conversations: {e}")
|
| 299 |
+
return None
|
| 300 |
+
finally:
|
| 301 |
+
cursor.close()
|
| 302 |
+
conn.close()
|
| 303 |
+
|
| 304 |
+
# Sample data
|
| 305 |
+
data = {
|
| 306 |
+
"user_query": {
|
| 307 |
+
"query_id": "12345",
|
| 308 |
+
"text": "What is the weather like today?",
|
| 309 |
+
"user_metadata": {
|
| 310 |
+
"location": "New York",
|
| 311 |
+
"language": "en"
|
| 312 |
+
}
|
| 313 |
+
},
|
| 314 |
+
"response": {
|
| 315 |
+
"text": "The weather today in New York is sunny and 75°F. Would you like to see a detailed report or a forecast?",
|
| 316 |
+
"status": "success",
|
| 317 |
+
"response_time": "2025-10-13T08:10:00Z",
|
| 318 |
+
"duration": "2s",
|
| 319 |
+
"artifacts": [
|
| 320 |
+
{
|
| 321 |
+
"artifact_id": "artifact_1",
|
| 322 |
+
"file_id": "file_1",
|
| 323 |
+
"file_name": "weather_report.pdf",
|
| 324 |
+
"file_type": "pdf",
|
| 325 |
+
"file_size": 1024,
|
| 326 |
+
"file_url": "path_to_file",
|
| 327 |
+
"upload_timestamp": "2025-10-13T08:00:00Z",
|
| 328 |
+
"metadata": {
|
| 329 |
+
"created_by": "system",
|
| 330 |
+
"associated_query_id": "12345",
|
| 331 |
+
"associated_session_id": "session_001",
|
| 332 |
+
"artifact_timestamp": "2025-10-13T08:10:00Z"
|
| 333 |
+
}
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"artifact_id": "artifact_2",
|
| 337 |
+
"file_id": "file_2",
|
| 338 |
+
"file_name": "weather_forecast_image.jpg",
|
| 339 |
+
"file_type": "jpg",
|
| 340 |
+
"file_size": 2048,
|
| 341 |
+
"file_url": "path_to_image",
|
| 342 |
+
"upload_timestamp": "2025-10-13T08:05:00Z",
|
| 343 |
+
"metadata": {
|
| 344 |
+
"created_by": "system",
|
| 345 |
+
"associated_query_id": "12345",
|
| 346 |
+
"associated_session_id": "session_001",
|
| 347 |
+
"artifact_timestamp": "2025-10-13T08:10:00Z"
|
| 348 |
+
}
|
| 349 |
+
}
|
| 350 |
+
]
|
| 351 |
+
},
|
| 352 |
+
"metadata": {
|
| 353 |
+
"query_response_metadata": {
|
| 354 |
+
"response_time": "2025-10-13T08:10:00Z",
|
| 355 |
+
"response_status": "success",
|
| 356 |
+
"response_duration": "2s"
|
| 357 |
+
},
|
| 358 |
+
"file_metadata": [
|
| 359 |
+
{
|
| 360 |
+
"file_id": "file_1",
|
| 361 |
+
"file_name": "weather_report.pdf",
|
| 362 |
+
"file_size": 1024,
|
| 363 |
+
"upload_timestamp": "2025-10-13T08:00:00Z"
|
| 364 |
+
},
|
| 365 |
+
{
|
| 366 |
+
"file_id": "file_2",
|
| 367 |
+
"file_name": "weather_forecast_image.jpg",
|
| 368 |
+
"file_size": 2048,
|
| 369 |
+
"upload_timestamp": "2025-10-13T08:05:00Z"
|
| 370 |
+
}
|
| 371 |
+
]
|
| 372 |
+
}
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
def get_all_convo_data_filtered(is_saved=None):
|
| 376 |
+
"""
|
| 377 |
+
Retrieve all conversations across all convo_ids.
|
| 378 |
+
- If is_saved=True → returns only saved conversations
|
| 379 |
+
- If is_saved=False → returns only unsaved conversations
|
| 380 |
+
- If is_saved=None → returns all conversations
|
| 381 |
+
"""
|
| 382 |
+
conn = connect_to_db()
|
| 383 |
+
if not conn:
|
| 384 |
+
return None
|
| 385 |
+
cursor = conn.cursor()
|
| 386 |
+
|
| 387 |
+
# Build the query dynamically
|
| 388 |
+
if is_saved is None:
|
| 389 |
+
select_query = """
|
| 390 |
+
SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 391 |
+
FROM "stored_convoId_data"
|
| 392 |
+
ORDER BY updated_at DESC
|
| 393 |
+
"""
|
| 394 |
+
params = ()
|
| 395 |
+
else:
|
| 396 |
+
select_query = """
|
| 397 |
+
SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 398 |
+
FROM "stored_convoId_data"
|
| 399 |
+
WHERE is_saved = %s
|
| 400 |
+
ORDER BY updated_at DESC
|
| 401 |
+
"""
|
| 402 |
+
params = (is_saved,)
|
| 403 |
+
|
| 404 |
+
try:
|
| 405 |
+
cursor.execute(select_query, params)
|
| 406 |
+
results = cursor.fetchall()
|
| 407 |
+
if not results:
|
| 408 |
+
print(f"No conversations found with is_saved={is_saved}")
|
| 409 |
+
return None
|
| 410 |
+
|
| 411 |
+
conversations = []
|
| 412 |
+
for result in results:
|
| 413 |
+
stored_convo_id, user_query_json, response_json, file_metadata_json, is_saved_flag, created_at, updated_at = result
|
| 414 |
+
|
| 415 |
+
user_query = user_query_json if isinstance(user_query_json, dict) else json.loads(user_query_json)
|
| 416 |
+
response = response_json if isinstance(response_json, dict) else json.loads(response_json)
|
| 417 |
+
file_metadata = file_metadata_json if isinstance(file_metadata_json, (dict, list)) else json.loads(file_metadata_json)
|
| 418 |
+
|
| 419 |
+
conversations.append({
|
| 420 |
+
"convo_id": stored_convo_id,
|
| 421 |
+
"user_query": user_query,
|
| 422 |
+
"response": response,
|
| 423 |
+
"file_metadata": file_metadata,
|
| 424 |
+
"is_saved": is_saved_flag,
|
| 425 |
+
"created_at": str(created_at),
|
| 426 |
+
"updated_at": str(updated_at)
|
| 427 |
+
})
|
| 428 |
+
|
| 429 |
+
return conversations
|
| 430 |
+
|
| 431 |
+
except Exception as e:
|
| 432 |
+
print(f"Error retrieving conversations: {e}")
|
| 433 |
+
return None
|
| 434 |
+
|
| 435 |
+
finally:
|
| 436 |
+
cursor.close()
|
| 437 |
+
conn.close()
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
if __name__ == "__main__":
|
| 441 |
+
# Step 1: Create the table
|
| 442 |
+
#deleting the table
|
| 443 |
+
print("Step none: Deleting table...")
|
| 444 |
+
delete_table_by_name("stored_convoId_data") # Deletes the 'stored_convoId_data' table
|
| 445 |
+
print("Table 'stored_convoId_data' deleted successfully or does not exist.")
|
| 446 |
+
|
| 447 |
+
print("Step 1: Creating table...")
|
| 448 |
+
if create_table():
|
| 449 |
+
# Step 2: Insert data with is_saved=False
|
| 450 |
+
print("\nStep 2: Inserting data with is_saved=False...")
|
| 451 |
+
insert_convo_data("12345", data, is_saved=False)
|
| 452 |
+
|
| 453 |
+
# Step 3: Insert another conversation with is_saved=True
|
| 454 |
+
print("\nStep 3: Inserting another conversation with is_saved=True...")
|
| 455 |
+
insert_convo_data("67890", data, is_saved=True)
|
| 456 |
+
|
| 457 |
+
# Step 4: Retrieve all data by convo_id
|
| 458 |
+
print("\n" + "="*60)
|
| 459 |
+
print("Step 4: Retrieving all data for convo_id='12345'...")
|
| 460 |
+
print("="*60)
|
| 461 |
+
convo_data = get_convo_data("12345")
|
| 462 |
+
if convo_data:
|
| 463 |
+
print(json.dumps(convo_data, indent=2))
|
| 464 |
+
|
| 465 |
+
# Step 5: Update is_saved status
|
| 466 |
+
print("\n" + "="*60)
|
| 467 |
+
print("Step 5: Updating is_saved status to True for convo_id='12345'...")
|
| 468 |
+
print("="*60)
|
| 469 |
+
update_saved_status("12345", is_saved=True)
|
| 470 |
+
|
| 471 |
+
# Step 6: Retrieve saved conversations for specific convo_id
|
| 472 |
+
print("\n" + "="*60)
|
| 473 |
+
print("Step 6: Retrieving saved conversations for convo_id='12345'...")
|
| 474 |
+
print("="*60)
|
| 475 |
+
saved_convo = get_convo_data_by_saved_status("12345", is_saved=True)
|
| 476 |
+
if saved_convo:
|
| 477 |
+
print(json.dumps(saved_convo, indent=2))
|
| 478 |
+
|
| 479 |
+
# Step 7: Retrieve all saved conversations
|
| 480 |
+
print("\n" + "="*60)
|
| 481 |
+
print("Step 7: Retrieving ALL saved conversations...")
|
| 482 |
+
print("="*60)
|
| 483 |
+
all_saved = get_all_saved_conversations()
|
| 484 |
+
if all_saved:
|
| 485 |
+
print(f"Found {len(all_saved)} saved conversation(s):")
|
| 486 |
+
print(json.dumps(all_saved, indent=2))
|
| 487 |
+
else:
|
| 488 |
+
print("Failed to create table. Exiting...")
|
| 489 |
+
|
| 490 |
+
# Step 8: Retrieve all conversations
|
| 491 |
+
# Get all saved conversations
|
| 492 |
+
print("\n" + "="*60)
|
| 493 |
+
saved_convos = get_all_convo_data_filtered(is_saved=True)
|
| 494 |
+
if saved_convos:
|
| 495 |
+
print(f"Found {len(saved_convos)} saved conversation(s):")
|
| 496 |
+
print(json.dumps(saved_convos, indent=2))
|
| 497 |
+
print("="*60)
|
| 498 |
+
# Get all unsaved conversations
|
| 499 |
+
unsaved_convos = get_all_convo_data_filtered(is_saved=False)
|
| 500 |
+
if unsaved_convos:
|
| 501 |
+
print(f"Found {len(unsaved_convos)} unsaved conversation(s):")
|
| 502 |
+
print(json.dumps(unsaved_convos, indent=2))
|
| 503 |
+
|
| 504 |
+
print("="*60)
|
| 505 |
+
|
| 506 |
+
# Get all conversations (both saved and unsaved)
|
| 507 |
+
all_convos = get_all_convo_data_filtered(is_saved=None)
|
| 508 |
+
if all_convos:
|
| 509 |
+
print(f"Found {len(all_convos)} conversation(s):")
|
| 510 |
+
print(json.dumps(all_convos, indent=2))
|
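The connection settings in this script are hard-coded, including the database password. A safer variant reads the same values from the environment; the sketch below is an assumption about deployment (a .env file or exported variables), not part of the upload, and reuses the constant names defined in the file above.

# Sketch: load the same Postgres settings from environment variables instead of hard-coding them.
# Assumes python-dotenv is installed and a .env file (or exported variables) provides the values.
import os
from dotenv import load_dotenv

load_dotenv()
PGHOST = os.getenv("PGHOST")
PGDATABASE = os.getenv("PGDATABASE", "neondb")
PGUSER = os.getenv("PGUSER")
PGPASSWORD = os.getenv("PGPASSWORD")
PGSSLMODE = os.getenv("PGSSLMODE", "require")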
DB_store_backup/stored_convoId_data.py
ADDED
@@ -0,0 +1,423 @@
| 1 |
+
import psycopg2
|
| 2 |
+
from psycopg2 import sql
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
# PostgreSQL credentials
|
| 6 |
+
PGHOST = 'ep-steep-dream-adqtvjel-pooler.c-2.us-east-1.aws.neon.tech'
|
| 7 |
+
PGDATABASE = 'neondb'
|
| 8 |
+
PGUSER = 'neondb_owner'
|
| 9 |
+
PGPASSWORD = 'npg_Qq0B1uWRXavx'
|
| 10 |
+
PGSSLMODE = 'require'
|
| 11 |
+
|
| 12 |
+
# Function to connect to the PostgreSQL database
|
| 13 |
+
def connect_to_db():
|
| 14 |
+
try:
|
| 15 |
+
conn = psycopg2.connect(
|
| 16 |
+
host=PGHOST,
|
| 17 |
+
database=PGDATABASE,
|
| 18 |
+
user=PGUSER,
|
| 19 |
+
password=PGPASSWORD,
|
| 20 |
+
sslmode=PGSSLMODE
|
| 21 |
+
)
|
| 22 |
+
return conn
|
| 23 |
+
except Exception as e:
|
| 24 |
+
print(f"Error connecting to PostgreSQL database: {e}")
|
| 25 |
+
return None
|
| 26 |
+
|
| 27 |
+
# Function to delete the table
|
| 28 |
+
def delete_table_by_name(table_name: str):
|
| 29 |
+
conn = connect_to_db()
|
| 30 |
+
if not conn:
|
| 31 |
+
print("Failed to connect to the database")
|
| 32 |
+
return
|
| 33 |
+
|
| 34 |
+
cursor = conn.cursor()
|
| 35 |
+
|
| 36 |
+
# Create the DROP TABLE query dynamically using the table_name argument
|
| 37 |
+
drop_table_query = sql.SQL("DROP TABLE IF EXISTS {}").format(sql.Identifier(table_name))
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
cursor.execute(drop_table_query)
|
| 41 |
+
conn.commit()
|
| 42 |
+
print(f"Table '{table_name}' deleted successfully.")
|
| 43 |
+
except Exception as e:
|
| 44 |
+
conn.rollback()
|
| 45 |
+
print(f"Error deleting table '{table_name}': {e}")
|
| 46 |
+
finally:
|
| 47 |
+
cursor.close()
|
| 48 |
+
conn.close()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Function to create the table if it doesn't exist
|
| 53 |
+
def create_table():
|
| 54 |
+
conn = connect_to_db()
|
| 55 |
+
if not conn:
|
| 56 |
+
return False
|
| 57 |
+
|
| 58 |
+
cursor = conn.cursor()
|
| 59 |
+
|
| 60 |
+
create_table_query = """
|
| 61 |
+
CREATE TABLE IF NOT EXISTS "stored_convoId_data" (
|
| 62 |
+
id SERIAL PRIMARY KEY,
|
| 63 |
+
convo_id VARCHAR(255) NOT NULL,
|
| 64 |
+
user_query JSONB NOT NULL,
|
| 65 |
+
response JSONB NOT NULL,
|
| 66 |
+
file_metadata JSONB,
|
| 67 |
+
is_saved BOOLEAN DEFAULT FALSE,
|
| 68 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 69 |
+
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 70 |
+
)
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
# Create index for faster queries
|
| 74 |
+
create_index_query = """
|
| 75 |
+
CREATE INDEX IF NOT EXISTS idx_convo_id_is_saved
|
| 76 |
+
ON "stored_convoId_data" (convo_id, is_saved);
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
try:
|
| 80 |
+
cursor.execute(create_table_query)
|
| 81 |
+
cursor.execute(create_index_query)
|
| 82 |
+
conn.commit()
|
| 83 |
+
print("Table 'stored_convoId_data' created successfully or already exists.")
|
| 84 |
+
return True
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"Error creating table: {e}")
|
| 87 |
+
conn.rollback()
|
| 88 |
+
return False
|
| 89 |
+
finally:
|
| 90 |
+
cursor.close()
|
| 91 |
+
conn.close()
|
| 92 |
+
|
| 93 |
+
# Function to insert the data into PostgreSQL table
|
| 94 |
+
def insert_convo_data(convo_id, data, is_saved=False):
|
| 95 |
+
conn = connect_to_db()
|
| 96 |
+
if not conn:
|
| 97 |
+
return
|
| 98 |
+
|
| 99 |
+
cursor = conn.cursor()
|
| 100 |
+
|
| 101 |
+
insert_query = sql.SQL("""
|
| 102 |
+
INSERT INTO "stored_convoId_data" (convo_id, user_query, response, file_metadata, is_saved)
|
| 103 |
+
VALUES (%s, %s, %s, %s, %s)
|
| 104 |
+
""")
|
| 105 |
+
|
| 106 |
+
# Prepare data to insert (use json.dumps for JSONB columns)
|
| 107 |
+
user_query = json.dumps(data["user_query"])
|
| 108 |
+
response = json.dumps(data["response"])
|
| 109 |
+
file_metadata = json.dumps(data["response"]["artifacts"])
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
cursor.execute(insert_query, (convo_id, user_query, response, file_metadata, is_saved))
|
| 113 |
+
conn.commit()
|
| 114 |
+
print(f"Data for convo_id {convo_id} inserted successfully with is_saved={is_saved}.")
|
| 115 |
+
except Exception as e:
|
| 116 |
+
print(f"Error inserting data: {e}")
|
| 117 |
+
conn.rollback()
|
| 118 |
+
finally:
|
| 119 |
+
cursor.close()
|
| 120 |
+
conn.close()
|
| 121 |
+
|
| 122 |
+
# Function to update the is_saved flag
|
| 123 |
+
def update_saved_status(convo_id, is_saved=True):
|
| 124 |
+
conn = connect_to_db()
|
| 125 |
+
if not conn:
|
| 126 |
+
return False
|
| 127 |
+
|
| 128 |
+
cursor = conn.cursor()
|
| 129 |
+
|
| 130 |
+
update_query = """
|
| 131 |
+
UPDATE "stored_convoId_data"
|
| 132 |
+
SET is_saved = %s, updated_at = CURRENT_TIMESTAMP
|
| 133 |
+
WHERE convo_id = %s
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
cursor.execute(update_query, (is_saved, convo_id))
|
| 138 |
+
conn.commit()
|
| 139 |
+
if cursor.rowcount > 0:
|
| 140 |
+
print(f"Updated {cursor.rowcount} record(s) for convo_id {convo_id} to is_saved={is_saved}.")
|
| 141 |
+
return True
|
| 142 |
+
else:
|
| 143 |
+
print(f"No records found for convo_id {convo_id}.")
|
| 144 |
+
return False
|
| 145 |
+
except Exception as e:
|
| 146 |
+
print(f"Error updating saved status: {e}")
|
| 147 |
+
conn.rollback()
|
| 148 |
+
return False
|
| 149 |
+
finally:
|
| 150 |
+
cursor.close()
|
| 151 |
+
conn.close()
|
| 152 |
+
|
| 153 |
+
# Function to retrieve data by convo_id only
|
| 154 |
+
def get_convo_data(convo_id):
|
| 155 |
+
conn = connect_to_db()
|
| 156 |
+
if not conn:
|
| 157 |
+
return None
|
| 158 |
+
|
| 159 |
+
cursor = conn.cursor()
|
| 160 |
+
|
| 161 |
+
select_query = """
|
| 162 |
+
SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 163 |
+
FROM "stored_convoId_data"
|
| 164 |
+
WHERE convo_id = %s
|
| 165 |
+
ORDER BY created_at DESC
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
try:
|
| 169 |
+
cursor.execute(select_query, (convo_id,))
|
| 170 |
+
results = cursor.fetchall()
|
| 171 |
+
|
| 172 |
+
if results:
|
| 173 |
+
conversations = []
|
| 174 |
+
for result in results:
|
| 175 |
+
stored_convo_id, user_query_json, response_json, file_metadata_json, is_saved, created_at, updated_at = result
|
| 176 |
+
|
| 177 |
+
# JSONB columns are already returned as dict/list by psycopg2
|
| 178 |
+
user_query = user_query_json if isinstance(user_query_json, dict) else json.loads(user_query_json)
|
| 179 |
+
response = response_json if isinstance(response_json, dict) else json.loads(response_json)
|
| 180 |
+
file_metadata = file_metadata_json if isinstance(file_metadata_json, (dict, list)) else json.loads(file_metadata_json)
|
| 181 |
+
|
| 182 |
+
conversations.append({
|
| 183 |
+
"convo_id": stored_convo_id,
|
| 184 |
+
"user_query": user_query,
|
| 185 |
+
"response": response,
|
| 186 |
+
"file_metadata": file_metadata,
|
| 187 |
+
"is_saved": is_saved,
|
| 188 |
+
"created_at": str(created_at),
|
| 189 |
+
"updated_at": str(updated_at)
|
| 190 |
+
})
|
| 191 |
+
|
| 192 |
+
return conversations
|
| 193 |
+
else:
|
| 194 |
+
print(f"No data found for convo_id: {convo_id}")
|
| 195 |
+
return None
|
| 196 |
+
|
| 197 |
+
except Exception as e:
|
| 198 |
+
print(f"Error retrieving data: {e}")
|
| 199 |
+
return None
|
| 200 |
+
finally:
|
| 201 |
+
cursor.close()
|
| 202 |
+
conn.close()
|
| 203 |
+
|
| 204 |
+
# Function to retrieve data by convo_id and is_saved status
|
| 205 |
+
def get_convo_data_by_saved_status(convo_id, is_saved=True):
|
| 206 |
+
conn = connect_to_db()
|
| 207 |
+
if not conn:
|
| 208 |
+
return None
|
| 209 |
+
|
| 210 |
+
cursor = conn.cursor()
|
| 211 |
+
|
| 212 |
+
select_query = """
|
| 213 |
+
SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 214 |
+
FROM "stored_convoId_data"
|
| 215 |
+
WHERE convo_id = %s AND is_saved = %s
|
| 216 |
+
ORDER BY created_at DESC
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
try:
|
| 220 |
+
cursor.execute(select_query, (convo_id, is_saved))
|
| 221 |
+
results = cursor.fetchall()
|
| 222 |
+
|
| 223 |
+
if results:
|
| 224 |
+
conversations = []
|
| 225 |
+
for result in results:
|
| 226 |
+
stored_convo_id, user_query_json, response_json, file_metadata_json, is_saved_flag, created_at, updated_at = result
|
| 227 |
+
|
| 228 |
+
user_query = user_query_json if isinstance(user_query_json, dict) else json.loads(user_query_json)
|
| 229 |
+
response = response_json if isinstance(response_json, dict) else json.loads(response_json)
|
| 230 |
+
file_metadata = file_metadata_json if isinstance(file_metadata_json, (dict, list)) else json.loads(file_metadata_json)
|
| 231 |
+
|
| 232 |
+
conversations.append({
|
| 233 |
+
"convo_id": stored_convo_id,
|
| 234 |
+
"user_query": user_query,
|
| 235 |
+
"response": response,
|
| 236 |
+
"file_metadata": file_metadata,
|
| 237 |
+
"is_saved": is_saved_flag,
|
| 238 |
+
"created_at": str(created_at),
|
| 239 |
+
"updated_at": str(updated_at)
|
| 240 |
+
})
|
| 241 |
+
|
| 242 |
+
return conversations
|
| 243 |
+
else:
|
| 244 |
+
print(f"No data found for convo_id: {convo_id} with is_saved={is_saved}")
|
| 245 |
+
return None
|
| 246 |
+
|
| 247 |
+
except Exception as e:
|
| 248 |
+
print(f"Error retrieving data: {e}")
|
| 249 |
+
return None
|
| 250 |
+
finally:
|
| 251 |
+
cursor.close()
|
| 252 |
+
conn.close()
|
| 253 |
+
|
| 254 |
+
# Function to get all saved conversations (across all convo_ids)
|
| 255 |
+
def get_all_saved_conversations():
|
| 256 |
+
conn = connect_to_db()
|
| 257 |
+
if not conn:
|
| 258 |
+
return None
|
| 259 |
+
|
| 260 |
+
cursor = conn.cursor()
|
| 261 |
+
|
| 262 |
+
select_query = """
|
| 263 |
+
SELECT convo_id, user_query, response, file_metadata, is_saved, created_at, updated_at
|
| 264 |
+
FROM "stored_convoId_data"
|
| 265 |
+
WHERE is_saved = TRUE
|
| 266 |
+
ORDER BY updated_at DESC
|
| 267 |
+
"""
|
| 268 |
+
|
| 269 |
+
try:
|
| 270 |
+
cursor.execute(select_query)
|
| 271 |
+
results = cursor.fetchall()
|
| 272 |
+
|
| 273 |
+
if results:
|
| 274 |
+
conversations = []
|
| 275 |
+
for result in results:
|
| 276 |
+
stored_convo_id, user_query_json, response_json, file_metadata_json, is_saved, created_at, updated_at = result
|
| 277 |
+
|
| 278 |
+
user_query = user_query_json if isinstance(user_query_json, dict) else json.loads(user_query_json)
|
| 279 |
+
response = response_json if isinstance(response_json, dict) else json.loads(response_json)
|
| 280 |
+
file_metadata = file_metadata_json if isinstance(file_metadata_json, (dict, list)) else json.loads(file_metadata_json)
|
| 281 |
+
|
| 282 |
+
conversations.append({
|
| 283 |
+
"convo_id": stored_convo_id,
|
| 284 |
+
"user_query": user_query,
|
| 285 |
+
"response": response,
|
| 286 |
+
"file_metadata": file_metadata,
|
| 287 |
+
"is_saved": is_saved,
|
| 288 |
+
"created_at": str(created_at),
|
| 289 |
+
"updated_at": str(updated_at)
|
| 290 |
+
})
|
| 291 |
+
|
| 292 |
+
return conversations
|
| 293 |
+
else:
|
| 294 |
+
print("No saved conversations found.")
|
| 295 |
+
return None
|
| 296 |
+
|
| 297 |
+
except Exception as e:
|
| 298 |
+
print(f"Error retrieving saved conversations: {e}")
|
| 299 |
+
return None
|
| 300 |
+
finally:
|
| 301 |
+
cursor.close()
|
| 302 |
+
conn.close()
|
| 303 |
+
|
| 304 |
+
# Sample data
|
| 305 |
+
data = {
|
| 306 |
+
"user_query": {
|
| 307 |
+
"query_id": "12345",
|
| 308 |
+
"text": "What is the weather like today?",
|
| 309 |
+
"user_metadata": {
|
| 310 |
+
"location": "New York",
|
| 311 |
+
"language": "en"
|
| 312 |
+
}
|
| 313 |
+
},
|
| 314 |
+
"response": {
|
| 315 |
+
"text": "The weather today in New York is sunny and 75°F. Would you like to see a detailed report or a forecast?",
|
| 316 |
+
"status": "success",
|
| 317 |
+
"response_time": "2025-10-13T08:10:00Z",
|
| 318 |
+
"duration": "2s",
|
| 319 |
+
"artifacts": [
|
| 320 |
+
{
|
| 321 |
+
"artifact_id": "artifact_1",
|
| 322 |
+
"file_id": "file_1",
|
| 323 |
+
"file_name": "weather_report.pdf",
|
| 324 |
+
"file_type": "pdf",
|
| 325 |
+
"file_size": 1024,
|
| 326 |
+
"file_url": "path_to_file",
|
| 327 |
+
"upload_timestamp": "2025-10-13T08:00:00Z",
|
| 328 |
+
"metadata": {
|
| 329 |
+
"created_by": "system",
|
| 330 |
+
"associated_query_id": "12345",
|
| 331 |
+
"associated_session_id": "session_001",
|
| 332 |
+
"artifact_timestamp": "2025-10-13T08:10:00Z"
|
| 333 |
+
}
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"artifact_id": "artifact_2",
|
| 337 |
+
"file_id": "file_2",
|
| 338 |
+
"file_name": "weather_forecast_image.jpg",
|
| 339 |
+
"file_type": "jpg",
|
| 340 |
+
"file_size": 2048,
|
| 341 |
+
"file_url": "path_to_image",
|
| 342 |
+
"upload_timestamp": "2025-10-13T08:05:00Z",
|
| 343 |
+
"metadata": {
|
| 344 |
+
"created_by": "system",
|
| 345 |
+
"associated_query_id": "12345",
|
| 346 |
+
"associated_session_id": "session_001",
|
| 347 |
+
"artifact_timestamp": "2025-10-13T08:10:00Z"
|
| 348 |
+
}
|
| 349 |
+
}
|
| 350 |
+
]
|
| 351 |
+
},
|
| 352 |
+
"metadata": {
|
| 353 |
+
"query_response_metadata": {
|
| 354 |
+
"response_time": "2025-10-13T08:10:00Z",
|
| 355 |
+
"response_status": "success",
|
| 356 |
+
"response_duration": "2s"
|
| 357 |
+
},
|
| 358 |
+
"file_metadata": [
|
| 359 |
+
{
|
| 360 |
+
"file_id": "file_1",
|
| 361 |
+
"file_name": "weather_report.pdf",
|
| 362 |
+
"file_size": 1024,
|
| 363 |
+
"upload_timestamp": "2025-10-13T08:00:00Z"
|
| 364 |
+
},
|
| 365 |
+
{
|
| 366 |
+
"file_id": "file_2",
|
| 367 |
+
"file_name": "weather_forecast_image.jpg",
|
| 368 |
+
"file_size": 2048,
|
| 369 |
+
"upload_timestamp": "2025-10-13T08:05:00Z"
|
| 370 |
+
}
|
| 371 |
+
]
|
| 372 |
+
}
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
if __name__ == "__main__":
|
| 376 |
+
# Step 1: Create the table
|
| 377 |
+
#deleting the table
|
| 378 |
+
# print("Step none: Deleting table...")
|
| 379 |
+
# delete_table_by_name("stored_convoId_data") # Deletes the 'stored_convoId_data' table
|
| 380 |
+
print("Table 'stored_convoId_data' deleted successfully or does not exist.")
|
| 381 |
+
|
| 382 |
+
print("Step 1: Creating table...")
|
| 383 |
+
if create_table():
|
| 384 |
+
# Step 2: Insert data with is_saved=False
|
| 385 |
+
print("\nStep 2: Inserting data with is_saved=False...")
|
| 386 |
+
insert_convo_data("12345", data, is_saved=False)
|
| 387 |
+
|
| 388 |
+
# Step 3: Insert another conversation with is_saved=True
|
| 389 |
+
print("\nStep 3: Inserting another conversation with is_saved=True...")
|
| 390 |
+
insert_convo_data("67890", data, is_saved=True)
|
| 391 |
+
|
| 392 |
+
# Step 4: Retrieve all data by convo_id
|
| 393 |
+
print("\n" + "="*60)
|
| 394 |
+
print("Step 4: Retrieving all data for convo_id='12345'...")
|
| 395 |
+
print("="*60)
|
| 396 |
+
convo_data = get_convo_data("12345")
|
| 397 |
+
if convo_data:
|
| 398 |
+
print(json.dumps(convo_data, indent=2))
|
| 399 |
+
|
| 400 |
+
# Step 5: Update is_saved status
|
| 401 |
+
print("\n" + "="*60)
|
| 402 |
+
print("Step 5: Updating is_saved status to True for convo_id='12345'...")
|
| 403 |
+
print("="*60)
|
| 404 |
+
update_saved_status("12345", is_saved=True)
|
| 405 |
+
|
| 406 |
+
# Step 6: Retrieve saved conversations for specific convo_id
|
| 407 |
+
print("\n" + "="*60)
|
| 408 |
+
print("Step 6: Retrieving saved conversations for convo_id='12345'...")
|
| 409 |
+
print("="*60)
|
| 410 |
+
saved_convo = get_convo_data_by_saved_status("12345", is_saved=True)
|
| 411 |
+
if saved_convo:
|
| 412 |
+
print(json.dumps(saved_convo, indent=2))
|
| 413 |
+
|
| 414 |
+
# Step 7: Retrieve all saved conversations
|
| 415 |
+
print("\n" + "="*60)
|
| 416 |
+
print("Step 7: Retrieving ALL saved conversations...")
|
| 417 |
+
print("="*60)
|
| 418 |
+
all_saved = get_all_saved_conversations()
|
| 419 |
+
if all_saved:
|
| 420 |
+
print(f"Found {len(all_saved)} saved conversation(s):")
|
| 421 |
+
print(json.dumps(all_saved, indent=2))
|
| 422 |
+
else:
|
| 423 |
+
print("Failed to create table. Exiting...")
|
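Because user_query and response are stored as JSONB, they can also be filtered inside PostgreSQL rather than in Python. The sketch below is only an illustration: it reuses connect_to_db and the column names from the script above, and relies on the standard PostgreSQL ->> operator; the "status" key is taken from the sample data, so adapt it to the real payloads.

# Sketch: filter conversations whose stored response reported success, using the JSONB ->> operator.
conn = connect_to_db()
if conn:
    cur = conn.cursor()
    cur.execute(
        """
        SELECT convo_id, created_at
        FROM "stored_convoId_data"
        WHERE response ->> 'status' = %s
        ORDER BY created_at DESC
        """,
        ("success",),
    )
    print(cur.fetchall())
    cur.close()
    conn.close()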
Dockerfile
ADDED
@@ -0,0 +1,74 @@
FROM python:3.10-slim

# Set working directory
WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements file
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY . .

# Create all necessary directories with proper permissions
RUN mkdir -p /app/tmp \
    /app/.config/matplotlib \
    /app/.cache \
    /app/.local && \
    chmod -R 777 /app/tmp /app/.config /app/.cache /app/.local

# Set comprehensive environment variables
ENV TMPDIR=/app/tmp \
    TEMP=/app/tmp \
    TMP=/app/tmp \
    MPLCONFIGDIR=/app/.config/matplotlib \
    XDG_CACHE_HOME=/app/.cache \
    XDG_CONFIG_HOME=/app/.config \
    HOME=/app \
    MPLBACKEND=Agg \
    PYTHONUNBUFFERED=1

# Expose the port FastAPI will run on
EXPOSE 7860

# Command to run the FastAPI application
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1", "--timeout-keep-alive", "120"]


# # Builder stage
# FROM python:3.10-slim AS builder
# WORKDIR /app
# COPY requirements.txt .
# RUN apt-get update && apt-get install -y gcc g++ \
#     && pip install --no-cache-dir -r requirements.txt \
#     && apt-get purge -y gcc g++ \
#     && apt-get autoremove -y \
#     && rm -rf /var/lib/apt/lists/*

# # Runtime stage
# FROM python:3.10-slim
# WORKDIR /app
# COPY --from=builder /usr/local/lib/python3.10/site-packages /usr/local/lib/python3.10/site-packages
# COPY . .
# # Create TMP directories
# RUN mkdir -p /app/tmp /app/.config/matplotlib /app/.cache /app/.local && \
#     chmod -R 777 /app/tmp /app/.config /app/.cache /app/.local
# ENV TMPDIR=/app/tmp \
#     TEMP=/app/tmp \
#     TMP=/app/tmp \
#     MPLCONFIGDIR=/app/.config/matplotlib \
#     XDG_CACHE_HOME=/app/.cache \
#     XDG_CONFIG_HOME=/app/.config \
#     HOME=/app \
#     MPLBACKEND=Agg \
#     PYTHONUNBUFFERED=1
# EXPOSE 7860
# CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1", "--timeout-keep-alive", "120"]
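A minimal way to exercise this image locally, under the assumption that the image tag is arbitrary and that runtime secrets such as REDIS_URL and the Postgres credentials are supplied through an env file rather than baked into the image:

docker build -t agentic-chatbot .
docker run -p 7860:7860 --env-file .env agentic-chatbot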
Redis/redis_agent_memory.py
ADDED
File without changes
Redis/rough.py
ADDED
@@ -0,0 +1,132 @@
import asyncio
import os
from agents import Agent, Runner
from agents.extensions.memory import RedisSession
import dotenv
import redis.asyncio as aioredis

dotenv.load_dotenv()
Redis_url = os.getenv("REDIS_URL")


# =============== transformers ================
from transformers import pipeline
import re
from datetime import datetime
import json
import requests

# Initialize summarization model
try:
    print("🚀 Loading BART model for summarization...")
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn", tokenizer="facebook/bart-large-cnn", framework="pt")
except Exception as e:
    print(f"Error loading BART model: {e}")
    summarizer = None

print("✅ Summarization pipeline ready!")


# ---------- Utility Functions ----------

async def get_sessions(user_login_id: str):
    """
    Get all session IDs for a given user (based on key prefix).
    """
    redis = await aioredis.from_url(Redis_url)
    pattern = f"{user_login_id}:*"
    keys = await redis.keys(pattern)
    sessions = [key.decode().replace(f"{user_login_id}:", "") for key in keys]
    await redis.close()
    return sessions


async def get_session_history(user_login_id: str, session_id: str):
    """
    Retrieve chat history for a given user's session.
    """
    try:
        session = RedisSession.from_url(
            session_id,
            url=Redis_url,
            key_prefix=f"{user_login_id}:",
        )
        if not await session.ping():
            raise Exception("Redis connection failed")

        items = await session.get_items()
        history = [
            {"role": msg.get("role", "unknown"), "content": msg.get("content", "")}
            for msg in items
        ]
        await session.close()
        return history

    except Exception as e:
        return {"error": str(e)}


async def delete_session(user_login_id: str, session_id: str):
    """
    Delete a specific session for a given user.
    """
    try:
        session = RedisSession.from_url(
            session_id,
            url=Redis_url,
            key_prefix=f"{user_login_id}:",
        )
        if not await session.ping():
            raise Exception("Redis connection failed")

        await session.clear_session()
        await session.close()
        return {"status": "success", "message": f"Session {session_id} deleted"}

    except Exception as e:
        return {"status": "error", "message": str(e)}


# ---------- Example Usage ----------

async def main_demo():
    user_id = "vatsav_user2"
    session_id = ":uuid_12345"

    print("Creating session...")
    session = RedisSession.from_url(
        session_id,
        url=Redis_url,
        key_prefix=f"{user_id}:",
    )

    agent = Agent(name="Assistant", instructions="Be concise.")
    await Runner.run(agent, "Hello!", session=session)
    await Runner.run(agent, "How are you?", session=session)
    await session.close()

    print("\n--- All Sessions ---")
    print(await get_sessions(user_id))
    print("Length of the sessions: ", len(await get_sessions(user_id)) or 0)

    print("\n--- Session History ---")
    history = await get_session_history(user_id, session_id)
    print("Length of the history: ", len(history) or 0)
    print(history)
    print("\nHistory ends =======================")
    for msg in history:
        print(f"{msg['role']}: {msg['content']}")

    print("\n--- Delete Session ---")
    print(await delete_session(user_id, session_id))


if __name__ == "__main__":
    asyncio.run(main_demo())
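The BART pipeline loaded at the top of this script is never applied in main_demo. A small sketch of how it could condense a retrieved history, mirroring the summarizer call that appears in Redis/sessions.py; it assumes the model loaded successfully and that history is the list-of-dicts shape returned by get_session_history.

# Sketch: summarize the user's turns with the pipeline loaded above (assumes summarizer is not None).
def summarize_history(history):
    if not summarizer or not isinstance(history, list) or not history:
        return None
    text = " ".join(m["content"] for m in history if m["role"] == "user")
    return summarizer(text, max_length=130, min_length=30, do_sample=False)[0]["summary_text"]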
Redis/sessions.py
ADDED
@@ -0,0 +1,121 @@
import os
import asyncio
import dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
import redis.asyncio as aioredis
from agents import Agent, Runner
from agents.extensions.memory import RedisSession
# from transformers import pipeline
from fastapi import APIRouter

dotenv.load_dotenv()
Redis_url = os.getenv("REDIS_URL")


Redis_session_router = APIRouter(prefix="/Redis", tags=["Redis_Agent_Memory_Management"])

# ============ Transformers ============ #
# try:
#     print("🚀 Loading summarization model (BART)...")
#     summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
#     print("✅ Model ready!")
# except Exception as e:
#     print(f"⚠️ Error loading model: {e}")
#     summarizer = None

# Model loading is disabled above, so default to None and skip summarization below.
summarizer = None


# ---------- Core Utility Functions ---------- #
async def get_sessions(user_login_id: str):
    redis = await aioredis.from_url(Redis_url)
    pattern = f"{user_login_id}:*"
    keys = await redis.keys(pattern)
    sessions = [key.decode().replace(f"{user_login_id}:", "") for key in keys]
    await redis.close()
    return sessions


async def get_session_history(user_login_id: str, session_id: str):
    """
    Retrieve chat history for a given user's session.
    """
    try:
        session = RedisSession.from_url(
            session_id,
            url=Redis_url,
            key_prefix=f"{user_login_id}:",
        )
        if not await session.ping():
            raise Exception("Redis connection failed")

        items = await session.get_items()
        history = [
            {"role": msg.get("role", "unknown"), "content": msg.get("content", "")}
            for msg in items
        ]
        await session.close()
        return history

    except Exception as e:
        return {"error": str(e)}

async def delete_session(user_login_id: str, session_id: str):
    session = RedisSession.from_url(session_id, url=Redis_url, key_prefix=f"{user_login_id}:")
    if not await session.ping():
        raise HTTPException(status_code=500, detail="Redis connection failed")
    await session.clear_session()
    await session.close()
    return {"status": "success", "message": f"Session {session_id} deleted"}


# ---------- API Endpoints ---------- #

@Redis_session_router.get("/sessions/{user_login_id}")
async def api_get_sessions(user_login_id: str):
    """List all sessions for a given user."""
    try:
        sessions = await get_sessions(user_login_id)
        return {"user": user_login_id, "sessions": sessions, "count": len(sessions)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@Redis_session_router.get("/sessions/{user_login_id}/{session_id}")
async def api_get_session_history(user_login_id: str, session_id: str):
    """Get the chat history of a session; a summary is included when a summarizer is available."""
    try:
        history = await get_session_history(user_login_id, session_id)
        if not history:
            raise HTTPException(status_code=404, detail="Session history not found")

        result = {"user": user_login_id, "session_id": session_id, "history": history}

        # Only summarize when the (currently disabled) summarization model has been loaded.
        if summarizer:
            text = " ".join([msg["content"] for msg in history if msg["role"] == "user"])
            summary = summarizer(text, max_length=130, min_length=30, do_sample=False)[0]["summary_text"]
            result["summary"] = summary
        return result

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@Redis_session_router.delete("/sessions/{user_login_id}/{session_id}")
async def api_delete_session(user_login_id: str, session_id: str):
    """Delete a session."""
    try:
        result = await delete_session(user_login_id, session_id)
        return JSONResponse(content=result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


# ---------- Demo route (optional) ---------- #
@Redis_session_router.get("/")
async def root():
    return {"message": "Redis Session Manager API running 🧠"}


# ---------- Local testing ---------- #
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run("main:Redis_session_router", host="0.0.0.0", port=8000, reload=True)
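A short client sketch against the three routes defined above. The route paths and the demo user/session ids come from this upload; the base URL (the app from the Dockerfile served on localhost:7860) is an assumption.

# Sketch: exercising the Redis session endpoints with requests (base URL is an assumption).
import requests

base = "http://localhost:7860/Redis"
print(requests.get(f"{base}/sessions/vatsav_user2").json())                  # list sessions
print(requests.get(f"{base}/sessions/vatsav_user2/:uuid_12345").json())      # session history
print(requests.delete(f"{base}/sessions/vatsav_user2/:uuid_12345").json())   # delete session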
Redis/sessions_new.py
ADDED
@@ -0,0 +1,203 @@
from fastapi import APIRouter, HTTPException

import json
from pydantic import BaseModel
from typing import Optional, List, Dict

import redis, os, uuid, json
from fastapi import HTTPException
from datetime import datetime
from dotenv import load_dotenv
from fastapi import APIRouter
import dotenv
dotenv.load_dotenv()


load_dotenv()
Redis_url = os.getenv("REDIS_URL")


_new_new = APIRouter(prefix="/redis_session_new", tags=["redis_session_new"])


class ChatRequest(BaseModel):
    user_login_id: str
    session_id: Optional[str] = None
    query: str
    org_id: Optional[str] = None
    metadata: Optional[Dict] = None

class ChatResponse(BaseModel):
    session_id: str
    user_message: str
    assistant_response: str
    is_new_session: bool
    session_title: str
    timestamp: str

class MessageResponse(BaseModel):
    message_id: str
    role: str
    content: str
    timestamp: str

class ChatHistoryResponse(BaseModel):
    session_id: str
    title: str
    created_at: str
    message_count: int
    messages: List[MessageResponse]

redis_session_route_new = APIRouter(prefix="/main_chatbot", tags=["Redis_session_main_chatbot"])


def get_redis_client():
    try:
        REDIS_URL = os.getenv("REDIS_URL")
        if REDIS_URL:
            return redis.from_url(REDIS_URL, decode_responses=True)
        return redis.StrictRedis(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", 6379)),
            password=os.getenv("REDIS_PASSWORD"),
            decode_responses=True
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Redis connection error: {e}")

redis_client = get_redis_client()

def create_session(user_login_id, org_id=None, metadata=None):
    session_id = str(uuid.uuid4())
    data = {
        "session_id": session_id,
        "user_login_id": user_login_id,
        "org_id": org_id,
        "created_at": datetime.now().isoformat(),
        "title": "New Chat",
        "message_count": 0,
        "metadata": metadata or {}
    }
    redis_client.setex(f"session:{user_login_id}:{session_id}", 86400, json.dumps(data))
    redis_client.setex(f"messages:{user_login_id}:{session_id}", 86400, json.dumps([]))
    return data

def get_session(user_login_id, session_id):
    key = f"session:{user_login_id}:{session_id}"
    data = redis_client.get(key)
    if not data:
        raise HTTPException(status_code=404, detail="Session not found")
    return json.loads(data)

def add_message(user_login_id, session_id, role, content):
    msg = {
        "message_id": str(uuid.uuid4()),
        "role": role,
        "content": content,
        "timestamp": datetime.now().isoformat()
    }
    key = f"messages:{user_login_id}:{session_id}"
    messages = json.loads(redis_client.get(key) or "[]")
    messages.append(msg)
    redis_client.setex(key, 86400, json.dumps(messages))
    session = get_session(user_login_id, session_id)
    session["message_count"] = len(messages)
    redis_client.setex(f"session:{user_login_id}:{session_id}", 86400, json.dumps(session))
    return msg["message_id"]

def get_message_history(user_login_id, session_id, limit=None):
    key = f"messages:{user_login_id}:{session_id}"
    data = redis_client.get(key)
    if not data:
        return []
    messages = json.loads(data)
    return messages[-limit:] if limit else messages

def generate_session_title(user_login_id, session_id):
    messages = get_message_history(user_login_id, session_id)
    first = next((m["content"] for m in messages if m["role"] == "user"), "")
    title = " ".join(first.split()[:6]) + ("..." if len(first.split()) > 6 else "")
    session = get_session(user_login_id, session_id)
    session["title"] = title or "New Chat"
    redis_client.setex(f"session:{user_login_id}:{session_id}", 86400, json.dumps(session))
    return title or "New Chat"

def update_session_title_if_needed(user_login_id, session_id, is_new):
    session = get_session(user_login_id, session_id)
    if is_new or session.get("title") == "New Chat":
        generate_session_title(user_login_id, session_id)

def format_conversation_context(messages, max_messages=5):
    if not messages:
        return ""
| 137 |
+
return "\n".join(f"{m['role']}: {m['content']}" for m in messages[-max_messages:])
|
| 138 |
+
|
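A minimal sketch of how the helpers above fit together (the user id, org id, and messages are made-up examples; Redis must be reachable through REDIS_URL or the host/port fallback):

    # Create a session, append two messages, and rebuild a short context window.
    session = create_session("user-123", org_id="org-1")      # writes session:user-123:<uuid> with a 24h TTL
    sid = session["session_id"]
    add_message("user-123", sid, "user", "Show me last month's admissions")
    add_message("user-123", sid, "assistant", "Here is the summary ...")
    recent = get_message_history("user-123", sid, limit=5)    # last 5 messages, oldest first
    print(format_conversation_context(recent))                # "user: ...\nassistant: ..."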
| 139 |
+
@redis_session_route_new.get("/history/{user_login_id}/{session_id}", response_model=ChatHistoryResponse)
|
| 140 |
+
async def get_chat_history(user_login_id: str, session_id: str):
|
| 141 |
+
session_data = get_session(user_login_id, session_id)
|
| 142 |
+
messages = get_message_history(user_login_id, session_id)
|
| 143 |
+
|
| 144 |
+
message_responses = [
|
| 145 |
+
MessageResponse(**msg) for msg in messages
|
| 146 |
+
]
|
| 147 |
+
|
| 148 |
+
return ChatHistoryResponse(
|
| 149 |
+
session_id=session_id,
|
| 150 |
+
title=session_data.get("title", "New Chat"),
|
| 151 |
+
created_at=session_data.get("created_at"),
|
| 152 |
+
message_count=len(messages),
|
| 153 |
+
messages=message_responses
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
@redis_session_route_new.get("/sessions/{user_login_id}")
|
| 157 |
+
async def get_user_sessions(user_login_id: str):
|
| 158 |
+
sessions = []
|
| 159 |
+
pattern = f"session:{user_login_id}:*"
|
| 160 |
+
for key in redis_client.scan_iter(match=pattern):
|
| 161 |
+
session_data = redis_client.get(key)
|
| 162 |
+
if session_data:
|
| 163 |
+
sessions.append(json.loads(session_data))
|
| 164 |
+
sessions.sort(key=lambda x: x.get("created_at", ""), reverse=True)
|
| 165 |
+
return {
|
| 166 |
+
"user_login_id": user_login_id,
|
| 167 |
+
"total_sessions": len(sessions),
|
| 168 |
+
"sessions": sessions
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
@redis_session_route_new.delete("/sessions/{user_login_id}/{session_id}")
|
| 172 |
+
async def delete_session(user_login_id: str, session_id: str):
|
| 173 |
+
get_session(user_login_id, session_id) # Validate existence
|
| 174 |
+
redis_client.delete(f"session:{user_login_id}:{session_id}")
|
| 175 |
+
redis_client.delete(f"messages:{user_login_id}:{session_id}")
|
| 176 |
+
return {"message": "Session deleted", "session_id": session_id}
|
| 177 |
+
|
| 178 |
+
@redis_session_route_new.get("/health")
|
| 179 |
+
async def health():
|
| 180 |
+
try:
|
| 181 |
+
redis_client.ping()
|
| 182 |
+
status = "connected"
|
| 183 |
+
total_sessions = len(list(redis_client.scan_iter(match="session:*")))
|
| 184 |
+
except Exception:
|
| 185 |
+
status = "disconnected"
|
| 186 |
+
total_sessions = 0
|
| 187 |
+
return {
|
| 188 |
+
"status": "ok",
|
| 189 |
+
"redis_status": status,
|
| 190 |
+
"total_sessions": total_sessions,
|
| 191 |
+
"session_ttl": "24 hours"
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
@redis_session_route_new.put("/sessions/{user_login_id}/{session_id}/title")
|
| 195 |
+
async def update_session_title(user_login_id: str, session_id: str, title: str):
|
| 196 |
+
key = f"session:{user_login_id}:{session_id}"
|
| 197 |
+
session_data = redis_client.get(key)
|
| 198 |
+
if not session_data:
|
| 199 |
+
raise HTTPException(status_code=404, detail="Session not found")
|
| 200 |
+
session = json.loads(session_data)
|
| 201 |
+
session["title"] = title
|
| 202 |
+
redis_client.setex(key, 86400, json.dumps(session))
|
| 203 |
+
return {"message": "Title updated"}
|
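A hedged client-side sketch for the /main_chatbot routes defined above, assuming the router is mounted on a FastAPI app served at http://localhost:8000 (host and port are examples, not part of this repo):

    import requests

    BASE = "http://localhost:8000/main_chatbot"
    user = "user-123"                                           # hypothetical user id

    sessions = requests.get(f"{BASE}/sessions/{user}").json()
    for s in sessions["sessions"]:
        print(s["session_id"], s["title"], s["message_count"])

    if sessions["sessions"]:
        sid = sessions["sessions"][0]["session_id"]
        history = requests.get(f"{BASE}/history/{user}/{sid}").json()
        print(history["title"], history["message_count"])
        # "title" is a query parameter on this endpoint, so it goes in params=
        requests.put(f"{BASE}/sessions/{user}/{sid}/title", params={"title": "Q4 sales questions"})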
Redis/sessions_old.py
ADDED
|
@@ -0,0 +1,598 @@
|
| 1 |
+
from fastapi import FastAPI, HTTPException, Query as QueryParam, Request
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
from typing import Optional, List, Dict, Any
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import redis
|
| 8 |
+
import uuid
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
from fastapi import APIRouter
|
| 12 |
+
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Redis_session_router_old = FastAPI(title="Redis Session Management API")
|
| 18 |
+
Redis_session_router_old = APIRouter(prefix="/Redis", tags=["Redis_Management_old"])
|
| 19 |
+
|
| 20 |
+
# ==================== CONFIGURATION ====================
|
| 21 |
+
|
| 22 |
+
# Redis Configuration
|
| 23 |
+
REDIS_URL = os.getenv("REDIS_URL")
|
| 24 |
+
REDIS_HOST = os.getenv("REDIS_HOST", "127.0.0.1")
|
| 25 |
+
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
|
| 26 |
+
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD")
|
| 27 |
+
|
| 28 |
+
# ==================== REDIS CLIENT INITIALIZATION ====================
|
| 29 |
+
|
| 30 |
+
def get_redis_client():
|
| 31 |
+
"""Initialize Redis client with fallback to local Redis"""
|
| 32 |
+
try:
|
| 33 |
+
if REDIS_URL:
|
| 34 |
+
# Use deployed Redis URL
|
| 35 |
+
redis_client = redis.from_url(
|
| 36 |
+
REDIS_URL,
|
| 37 |
+
decode_responses=True,
|
| 38 |
+
socket_connect_timeout=5,
|
| 39 |
+
socket_timeout=5
|
| 40 |
+
)
|
| 41 |
+
# Test connection
|
| 42 |
+
redis_client.ping()
|
| 43 |
+
print(f"✅ Connected to deployed Redis: {REDIS_URL}")
|
| 44 |
+
return redis_client
|
| 45 |
+
else:
|
| 46 |
+
# Use local Redis
|
| 47 |
+
redis_client = redis.StrictRedis(
|
| 48 |
+
host=REDIS_HOST,
|
| 49 |
+
port=REDIS_PORT,
|
| 50 |
+
password=REDIS_PASSWORD,
|
| 51 |
+
decode_responses=True,
|
| 52 |
+
socket_connect_timeout=5,
|
| 53 |
+
socket_timeout=5
|
| 54 |
+
)
|
| 55 |
+
# Test connection
|
| 56 |
+
redis_client.ping()
|
| 57 |
+
print(f"✅ Connected to local Redis: {REDIS_HOST}:{REDIS_PORT}")
|
| 58 |
+
return redis_client
|
| 59 |
+
except Exception as e:
|
| 60 |
+
print(f"❌ Redis connection failed: {e}")
|
| 61 |
+
raise HTTPException(status_code=500, detail=f"Redis connection failed: {str(e)}")
|
| 62 |
+
|
| 63 |
+
# Initialize Redis client
|
| 64 |
+
# redis_client = get_redis_client()
|
| 65 |
+
|
| 66 |
+
# ==================== PYDANTIC MODELS ====================
|
| 67 |
+
|
| 68 |
+
class SessionResponse(BaseModel):
|
| 69 |
+
session_id: str
|
| 70 |
+
userLoginId: int
|
| 71 |
+
orgId: int
|
| 72 |
+
created_at: str
|
| 73 |
+
status: str
|
| 74 |
+
title: Optional[str] = "New Chat"
|
| 75 |
+
|
| 76 |
+
class MessageResponse(BaseModel):
|
| 77 |
+
message_id: str
|
| 78 |
+
session_id: str
|
| 79 |
+
role: str # "user" or "assistant"
|
| 80 |
+
message: str
|
| 81 |
+
timestamp: str
|
| 82 |
+
|
| 83 |
+
class ChatHistoryResponse(BaseModel):
|
| 84 |
+
session_id: str
|
| 85 |
+
messages: List[MessageResponse]
|
| 86 |
+
total_messages: int
|
| 87 |
+
|
| 88 |
+
class UpdateSessionTitleRequest(BaseModel):
|
| 89 |
+
new_title: str
|
| 90 |
+
|
| 91 |
+
class CreateSessionRequest(BaseModel):
|
| 92 |
+
userLoginId: int
|
| 93 |
+
orgId: int
|
| 94 |
+
auth_token: str
|
| 95 |
+
|
| 96 |
+
class AddMessageRequest(BaseModel):
|
| 97 |
+
session_id: str
|
| 98 |
+
role: str
|
| 99 |
+
message: str
|
| 100 |
+
|
| 101 |
+
# ==================== SESSION MANAGEMENT FUNCTIONS ====================
|
| 102 |
+
|
| 103 |
+
def create_session(userLoginId: int, orgId: int, auth_token: str) -> dict:
|
| 104 |
+
"""Create a new chat session"""
|
| 105 |
+
session_id = str(uuid.uuid4())
|
| 106 |
+
session_data = {
|
| 107 |
+
"session_id": session_id,
|
| 108 |
+
"userLoginId": userLoginId,
|
| 109 |
+
"orgId": orgId,
|
| 110 |
+
"auth_token": auth_token,
|
| 111 |
+
"created_at": datetime.now().isoformat(),
|
| 112 |
+
"status": "active",
|
| 113 |
+
"title": "New Chat"
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
# Store session in Redis with 24 hour TTL
|
| 117 |
+
redis_client.setex(
|
| 118 |
+
f"session:{session_id}",
|
| 119 |
+
86400, # 24 hours
|
| 120 |
+
json.dumps(session_data)
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Initialize empty chat history
|
| 124 |
+
redis_client.setex(
|
| 125 |
+
f"chat:{session_id}",
|
| 126 |
+
86400, # 24 hours
|
| 127 |
+
json.dumps([])
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Initialize conversation memory
|
| 131 |
+
redis_client.setex(
|
| 132 |
+
f"memory:{session_id}",
|
| 133 |
+
86400, # 24 hours
|
| 134 |
+
json.dumps([])
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
return session_data
|
| 138 |
+
|
| 139 |
+
def get_session(session_id: str) -> dict:
|
| 140 |
+
"""Get session data from Redis"""
|
| 141 |
+
session_data = redis_client.get(f"session:{session_id}")
|
| 142 |
+
if not session_data:
|
| 143 |
+
raise HTTPException(status_code=404, detail="Session not found or expired")
|
| 144 |
+
return json.loads(session_data)
|
| 145 |
+
|
| 146 |
+
def add_message_to_session(session_id: str, role: str, message: str) -> str:
|
| 147 |
+
"""Add message to session chat history"""
|
| 148 |
+
message_id = str(uuid.uuid4())
|
| 149 |
+
message_data = {
|
| 150 |
+
"message_id": message_id,
|
| 151 |
+
"session_id": session_id,
|
| 152 |
+
"role": role,
|
| 153 |
+
"message": message,
|
| 154 |
+
"timestamp": datetime.now().isoformat()
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
# Get current chat history
|
| 158 |
+
chat_history = redis_client.get(f"chat:{session_id}")
|
| 159 |
+
if chat_history:
|
| 160 |
+
messages = json.loads(chat_history)
|
| 161 |
+
else:
|
| 162 |
+
messages = []
|
| 163 |
+
|
| 164 |
+
# Add new message
|
| 165 |
+
messages.append(message_data)
|
| 166 |
+
|
| 167 |
+
# Update chat history in Redis with extended TTL
|
| 168 |
+
redis_client.setex(
|
| 169 |
+
f"chat:{session_id}",
|
| 170 |
+
86400, # 24 hours
|
| 171 |
+
json.dumps(messages)
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
return message_id
|
| 175 |
+
|
| 176 |
+
def get_session_memory(session_id: str) -> List[Dict]:
|
| 177 |
+
"""Get conversation memory for session"""
|
| 178 |
+
memory_data = redis_client.get(f"memory:{session_id}")
|
| 179 |
+
if memory_data:
|
| 180 |
+
return json.loads(memory_data)
|
| 181 |
+
return []
|
| 182 |
+
|
| 183 |
+
def update_session_memory(session_id: str, messages: List[Dict]):
|
| 184 |
+
"""Update conversation memory for session"""
|
| 185 |
+
redis_client.setex(
|
| 186 |
+
f"memory:{session_id}",
|
| 187 |
+
86400, # 24 hours
|
| 188 |
+
json.dumps(messages)
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
def generate_session_title(session_id: str) -> str:
|
| 192 |
+
"""Generate a title for the session based on chat history"""
|
| 193 |
+
try:
|
| 194 |
+
# Check session
|
| 195 |
+
session_data = redis_client.get(f"session:{session_id}")
|
| 196 |
+
if session_data:
|
| 197 |
+
session = json.loads(session_data)
|
| 198 |
+
if "user_title" in session:
|
| 199 |
+
# Don't override user-defined titles
|
| 200 |
+
return session["user_title"]
|
| 201 |
+
|
| 202 |
+
# Get chat history
|
| 203 |
+
chat_data = redis_client.get(f"chat:{session_id}")
|
| 204 |
+
if not chat_data:
|
| 205 |
+
return "New Chat"
|
| 206 |
+
|
| 207 |
+
messages = json.loads(chat_data)
|
| 208 |
+
if not messages:
|
| 209 |
+
return "New Chat"
|
| 210 |
+
|
| 211 |
+
# Get first user message
|
| 212 |
+
first_user_message = next(
|
| 213 |
+
(msg["message"] for msg in messages if msg["role"] == "user"), None
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
if not first_user_message:
|
| 217 |
+
return "New Chat"
|
| 218 |
+
|
| 219 |
+
# Create a simple title from first message
|
| 220 |
+
words = first_user_message.split()[:6]
|
| 221 |
+
title = " ".join(words) + ("..." if len(first_user_message.split()) > 6 else "")
|
| 222 |
+
|
| 223 |
+
# Save to session
|
| 224 |
+
if session_data:
|
| 225 |
+
session["generated_title"] = title
|
| 226 |
+
if not session.get("user_title"):
|
| 227 |
+
session["title"] = title
|
| 228 |
+
redis_client.setex(f"session:{session_id}", 86400, json.dumps(session))
|
| 229 |
+
|
| 230 |
+
return title
|
| 231 |
+
|
| 232 |
+
except Exception as e:
|
| 233 |
+
print(f"Error in generate_session_title: {e}")
|
| 234 |
+
return "New Chat"
|
| 235 |
+
|
| 236 |
+
def update_session_title(session_id: str):
|
| 237 |
+
"""Update session title after first message"""
|
| 238 |
+
try:
|
| 239 |
+
# Get session data
|
| 240 |
+
session_data = redis_client.get(f"session:{session_id}")
|
| 241 |
+
if not session_data:
|
| 242 |
+
return
|
| 243 |
+
|
| 244 |
+
session = json.loads(session_data)
|
| 245 |
+
|
| 246 |
+
# Only update if current title is "New Chat"
|
| 247 |
+
if session.get("title", "New Chat") == "New Chat":
|
| 248 |
+
new_title = generate_session_title(session_id)
|
| 249 |
+
session["title"] = new_title
|
| 250 |
+
|
| 251 |
+
# Update session in Redis
|
| 252 |
+
redis_client.setex(
|
| 253 |
+
f"session:{session_id}",
|
| 254 |
+
86400, # 24 hours
|
| 255 |
+
json.dumps(session)
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
except Exception as e:
|
| 259 |
+
print(f"Error updating session title: {e}")
|
| 260 |
+
pass
|
| 261 |
+
|
| 262 |
+
def get_user_sessions(userLoginId: int) -> List[dict]:
|
| 263 |
+
"""Get all sessions for a user with generated titles"""
|
| 264 |
+
sessions = []
|
| 265 |
+
# Scan for all session keys
|
| 266 |
+
for key in redis_client.scan_iter(match="session:*"):
|
| 267 |
+
session_data = redis_client.get(key)
|
| 268 |
+
if session_data:
|
| 269 |
+
session = json.loads(session_data)
|
| 270 |
+
if session["userLoginId"] == userLoginId:
|
| 271 |
+
# Generate title based on chat history
|
| 272 |
+
session["title"] = generate_session_title(session["session_id"])
|
| 273 |
+
sessions.append(session)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# Sort sessions by created_at (most recent first)
|
| 277 |
+
sessions.sort(key=lambda x: x["created_at"], reverse=True)
|
| 278 |
+
|
| 279 |
+
return sessions
|
| 280 |
+
def get_user_sessions(userLoginId: int) -> List[dict]:
|
| 281 |
+
"""Get all sessions for a user with generated titles"""
|
| 282 |
+
sessions = []
|
| 283 |
+
# Scan for all session keys
|
| 284 |
+
for key in redis_client.scan_iter(match="session:*"):
|
| 285 |
+
session_data = redis_client.get(key)
|
| 286 |
+
if session_data:
|
| 287 |
+
session = json.loads(session_data)
|
| 288 |
+
if session["userLoginId"] == userLoginId:
|
| 289 |
+
# Generate title based on chat history
|
| 290 |
+
session["title"] = generate_session_title(session["session_id"])
|
| 291 |
+
sessions.append(session)
|
| 292 |
+
|
| 293 |
+
# Sort sessions by created_at (most recent first)
|
| 294 |
+
sessions.sort(key=lambda x: x["created_at"], reverse=True)
|
| 295 |
+
return sessions
|
| 296 |
+
|
| 297 |
+
def delete_session(session_id: str):
|
| 298 |
+
"""Delete session and associated data"""
|
| 299 |
+
# Delete session data
|
| 300 |
+
redis_client.delete(f"session:{session_id}")
|
| 301 |
+
# Delete chat history
|
| 302 |
+
redis_client.delete(f"chat:{session_id}")
|
| 303 |
+
# Delete memory
|
| 304 |
+
redis_client.delete(f"memory:{session_id}")
|
| 305 |
+
|
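A short sketch of the per-session key layout the helpers above maintain (example values only; note that the `redis_client = get_redis_client()` assignment is commented out earlier in this file and must be restored before any of these helpers run):

    # Each session owns three keys, all with a 24-hour TTL:
    #   session:<sid>  -> session metadata
    #   chat:<sid>     -> chat message list
    #   memory:<sid>   -> agent conversation memory
    sess = create_session(userLoginId=42, orgId=7, auth_token="demo-token")
    sid = sess["session_id"]
    add_message_to_session(sid, "user", "hello")
    update_session_memory(sid, get_session_memory(sid) + [{"role": "user", "content": "hello"}])
    delete_session(sid)          # removes all three keys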
| 306 |
+
# ==================== MIDDLEWARE ====================
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
# ==================== API ENDPOINTS ====================
|
| 311 |
+
|
| 312 |
+
@Redis_session_router_old.post("/sessions", response_model=SessionResponse)
|
| 313 |
+
def create_new_session(request: CreateSessionRequest):
|
| 314 |
+
"""Create a new chat session"""
|
| 315 |
+
try:
|
| 316 |
+
session_data = create_session(request.userLoginId, request.orgId, request.auth_token)
|
| 317 |
+
return SessionResponse(**session_data)
|
| 318 |
+
except Exception as e:
|
| 319 |
+
raise HTTPException(status_code=500, detail=f"Error creating session: {str(e)}")
|
| 320 |
+
|
| 321 |
+
# @Redis_session_router_old.get("/sessions")
|
| 322 |
+
# def list_user_sessions(userLoginId: int):
|
| 323 |
+
# """List all sessions for a user"""
|
| 324 |
+
# try:
|
| 325 |
+
# sessions = get_user_sessions(userLoginId)
|
| 326 |
+
# print(sessions)
|
| 327 |
+
# return {
|
| 328 |
+
# "userLoginId": userLoginId,
|
| 329 |
+
# "total_sessions": len(sessions),
|
| 330 |
+
# "sessions": sessions
|
| 331 |
+
# }
|
| 332 |
+
# except Exception as e:
|
| 333 |
+
# raise HTTPException(status_code=500, detail=f"Error fetching sessions: {str(e)}")
|
| 334 |
+
@Redis_session_router_old.get("/sessions")
|
| 335 |
+
def list_user_sessions(userLoginId: int):
|
| 336 |
+
"""List all sessions for a user"""
|
| 337 |
+
try:
|
| 338 |
+
sessions = get_user_sessions(userLoginId)
|
| 339 |
+
return {
|
| 340 |
+
"userLoginId": userLoginId,
|
| 341 |
+
"total_sessions": len(sessions),
|
| 342 |
+
"sessions": sessions
|
| 343 |
+
}
|
| 344 |
+
except Exception as e:
|
| 345 |
+
raise HTTPException(status_code=500, detail=f"Error fetching sessions: {str(e)}")
|
| 346 |
+
|
| 347 |
+
@Redis_session_router_old.get("/sessions/{session_id}")
|
| 348 |
+
def get_session_details(session_id: str):
|
| 349 |
+
"""Get details of a specific session"""
|
| 350 |
+
try:
|
| 351 |
+
session_data = get_session(session_id)
|
| 352 |
+
return session_data
|
| 353 |
+
except Exception as e:
|
| 354 |
+
raise HTTPException(status_code=404, detail=f"Session not found: {str(e)}")
|
| 355 |
+
|
| 356 |
+
@Redis_session_router_old.delete("/sessions/{session_id}")
|
| 357 |
+
def delete_user_session(session_id: str):
|
| 358 |
+
"""Delete/close a session"""
|
| 359 |
+
try:
|
| 360 |
+
# Verify session exists
|
| 361 |
+
get_session(session_id)
|
| 362 |
+
|
| 363 |
+
# Delete session
|
| 364 |
+
delete_session(session_id)
|
| 365 |
+
|
| 366 |
+
return {
|
| 367 |
+
"message": f"Session {session_id} deleted successfully",
|
| 368 |
+
"session_id": session_id
|
| 369 |
+
}
|
| 370 |
+
except Exception as e:
|
| 371 |
+
raise HTTPException(status_code=500, detail=f"Error deleting session: {str(e)}")
|
| 372 |
+
|
| 373 |
+
@Redis_session_router_old.get("/sessions/{session_id}/history", response_model=ChatHistoryResponse)
|
| 374 |
+
def get_session_history(
|
| 375 |
+
session_id: str,
|
| 376 |
+
n: int = QueryParam(50, description="Number of recent messages to return")
|
| 377 |
+
):
|
| 378 |
+
"""Get chat history for a session"""
|
| 379 |
+
try:
|
| 380 |
+
# Verify session exists
|
| 381 |
+
get_session(session_id)
|
| 382 |
+
|
| 383 |
+
# Get chat history
|
| 384 |
+
chat_data = redis_client.get(f"chat:{session_id}")
|
| 385 |
+
if not chat_data:
|
| 386 |
+
return ChatHistoryResponse(
|
| 387 |
+
session_id=session_id,
|
| 388 |
+
messages=[],
|
| 389 |
+
total_messages=0
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
messages = json.loads(chat_data)
|
| 393 |
+
|
| 394 |
+
# Get the last n messages (or all if less than n)
|
| 395 |
+
recent_messages = messages[-n:] if len(messages) > n else messages
|
| 396 |
+
|
| 397 |
+
# Convert to MessageResponse objects
|
| 398 |
+
message_responses = [MessageResponse(**msg) for msg in recent_messages]
|
| 399 |
+
|
| 400 |
+
return ChatHistoryResponse(
|
| 401 |
+
session_id=session_id,
|
| 402 |
+
messages=message_responses,
|
| 403 |
+
total_messages=len(messages)
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
except Exception as e:
|
| 407 |
+
raise HTTPException(status_code=500, detail=f"Error fetching chat history: {str(e)}")
|
| 408 |
+
|
| 409 |
+
@Redis_session_router_old.post("/sessions/{session_id}/messages")
|
| 410 |
+
def add_message(session_id: str, request: AddMessageRequest):
|
| 411 |
+
"""Add a message to a session"""
|
| 412 |
+
try:
|
| 413 |
+
# Verify session exists
|
| 414 |
+
get_session(session_id)
|
| 415 |
+
|
| 416 |
+
# Validate role
|
| 417 |
+
if request.role not in ["user", "assistant"]:
|
| 418 |
+
raise HTTPException(status_code=400, detail="Role must be 'user' or 'assistant'")
|
| 419 |
+
|
| 420 |
+
# Add message
|
| 421 |
+
message_id = add_message_to_session(session_id, request.role, request.message)
|
| 422 |
+
|
| 423 |
+
# Update title if it's the first user message
|
| 424 |
+
if request.role == "user":
|
| 425 |
+
update_session_title(session_id)
|
| 426 |
+
|
| 427 |
+
return {
|
| 428 |
+
"message_id": message_id,
|
| 429 |
+
"session_id": session_id,
|
| 430 |
+
"role": request.role,
|
| 431 |
+
"message": request.message,
|
| 432 |
+
"timestamp": datetime.now().isoformat()
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
except HTTPException:
|
| 436 |
+
raise
|
| 437 |
+
except Exception as e:
|
| 438 |
+
raise HTTPException(status_code=500, detail=f"Error adding message: {str(e)}")
|
| 439 |
+
|
| 440 |
+
@Redis_session_router_old.put("/sessions/{session_id}/title")
|
| 441 |
+
def update_session_title_endpoint(session_id: str, request: UpdateSessionTitleRequest):
|
| 442 |
+
"""Update the user-defined title of an existing session"""
|
| 443 |
+
try:
|
| 444 |
+
session_data = redis_client.get(f"session:{session_id}")
|
| 445 |
+
if not session_data:
|
| 446 |
+
raise HTTPException(status_code=404, detail="Session not found or expired")
|
| 447 |
+
|
| 448 |
+
session = json.loads(session_data)
|
| 449 |
+
|
| 450 |
+
new_title = request.new_title.strip()
|
| 451 |
+
if not new_title:
|
| 452 |
+
raise HTTPException(status_code=400, detail="New title cannot be empty")
|
| 453 |
+
if len(new_title) > 100:
|
| 454 |
+
raise HTTPException(status_code=400, detail="Title cannot exceed 100 characters")
|
| 455 |
+
|
| 456 |
+
old_title = session.get("title", "New Chat")
|
| 457 |
+
session["user_title"] = new_title
|
| 458 |
+
session["title"] = new_title
|
| 459 |
+
session["last_updated"] = datetime.now().isoformat()
|
| 460 |
+
|
| 461 |
+
redis_client.setex(f"session:{session_id}", 86400, json.dumps(session))
|
| 462 |
+
|
| 463 |
+
return {
|
| 464 |
+
"message": "Session title updated successfully",
|
| 465 |
+
"session_id": session_id,
|
| 466 |
+
"old_title": old_title,
|
| 467 |
+
"new_title": new_title
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
except HTTPException:
|
| 471 |
+
raise
|
| 472 |
+
except Exception as e:
|
| 473 |
+
raise HTTPException(status_code=500, detail=f"Error updating session title: {str(e)}")
|
| 474 |
+
|
| 475 |
+
@Redis_session_router_old.put("/sessions/{session_id}/refresh-title")
|
| 476 |
+
def refresh_session_title(session_id: str):
|
| 477 |
+
"""Manually refresh/regenerate session title"""
|
| 478 |
+
try:
|
| 479 |
+
# Verify session exists
|
| 480 |
+
session_data = get_session(session_id)
|
| 481 |
+
|
| 482 |
+
# Generate new title
|
| 483 |
+
new_title = generate_session_title(session_id)
|
| 484 |
+
|
| 485 |
+
# Update session
|
| 486 |
+
session_data["title"] = new_title
|
| 487 |
+
redis_client.setex(
|
| 488 |
+
f"session:{session_id}",
|
| 489 |
+
86400, # 24 hours
|
| 490 |
+
json.dumps(session_data)
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
return {
|
| 494 |
+
"session_id": session_id,
|
| 495 |
+
"new_title": new_title,
|
| 496 |
+
"message": "Session title updated successfully"
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
except Exception as e:
|
| 500 |
+
raise HTTPException(status_code=500, detail=f"Error updating session title: {str(e)}")
|
| 501 |
+
|
| 502 |
+
@Redis_session_router_old.get("/sessions/{session_id}/memory")
|
| 503 |
+
def get_session_memory_endpoint(session_id: str):
|
| 504 |
+
"""Get conversation memory for a session"""
|
| 505 |
+
try:
|
| 506 |
+
# Verify session exists
|
| 507 |
+
get_session(session_id)
|
| 508 |
+
|
| 509 |
+
memory = get_session_memory(session_id)
|
| 510 |
+
|
| 511 |
+
return {
|
| 512 |
+
"session_id": session_id,
|
| 513 |
+
"memory": memory,
|
| 514 |
+
"total_items": len(memory)
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
except Exception as e:
|
| 518 |
+
raise HTTPException(status_code=500, detail=f"Error fetching memory: {str(e)}")
|
| 519 |
+
|
| 520 |
+
# ==================== SYSTEM ENDPOINTS ====================
|
| 521 |
+
|
| 522 |
+
@Redis_session_router_old.get("/redis-info")
|
| 523 |
+
def redis_info():
|
| 524 |
+
"""Get Redis connection information"""
|
| 525 |
+
try:
|
| 526 |
+
info = redis_client.info()
|
| 527 |
+
return {
|
| 528 |
+
"redis_connected": True,
|
| 529 |
+
"redis_version": info.get("redis_version"),
|
| 530 |
+
"used_memory": info.get("used_memory_human"),
|
| 531 |
+
"connected_clients": info.get("connected_clients"),
|
| 532 |
+
"total_keys": redis_client.dbsize()
|
| 533 |
+
}
|
| 534 |
+
except Exception as e:
|
| 535 |
+
return {
|
| 536 |
+
"redis_connected": False,
|
| 537 |
+
"error": str(e)
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
@Redis_session_router_old.get("/health")
|
| 541 |
+
def health():
|
| 542 |
+
"""System health check"""
|
| 543 |
+
try:
|
| 544 |
+
redis_client.ping()
|
| 545 |
+
redis_status = "connected"
|
| 546 |
+
except Exception:
|
| 547 |
+
redis_status = "disconnected"
|
| 548 |
+
|
| 549 |
+
total_sessions = 0
|
| 550 |
+
if redis_status == "connected":
|
| 551 |
+
try:
|
| 552 |
+
total_sessions = len(list(redis_client.scan_iter(match="session:*")))
|
| 553 |
+
except Exception:
|
| 554 |
+
pass
|
| 555 |
+
|
| 556 |
+
return {
|
| 557 |
+
"status": "ok",
|
| 558 |
+
"redis_status": redis_status,
|
| 559 |
+
"session_management": "enabled",
|
| 560 |
+
"total_sessions": total_sessions,
|
| 561 |
+
"ttl": "24 hours"
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
@Redis_session_router_old.get("/")
|
| 565 |
+
def root():
|
| 566 |
+
"""Root endpoint with API information"""
|
| 567 |
+
return {
|
| 568 |
+
"service": "Redis Session Management API",
|
| 569 |
+
"version": "1.0.0",
|
| 570 |
+
"endpoints": {
|
| 571 |
+
"sessions": {
|
| 572 |
+
"POST /sessions": "Create new session",
|
| 573 |
+
"GET /sessions?userLoginId={id}": "List user sessions",
|
| 574 |
+
"GET /sessions/{id}": "Get session details",
|
| 575 |
+
"DELETE /sessions/{id}": "Delete session",
|
| 576 |
+
"GET /sessions/{id}/history": "Get chat history",
|
| 577 |
+
"POST /sessions/{id}/messages": "Add message to session",
|
| 578 |
+
"PUT /sessions/{id}/title": "Update session title",
|
| 579 |
+
"PUT /sessions/{id}/refresh-title": "Refresh session title",
|
| 580 |
+
"GET /sessions/{id}/memory": "Get session memory"
|
| 581 |
+
},
|
| 582 |
+
"system": {
|
| 583 |
+
"GET /health": "Health check",
|
| 584 |
+
"GET /redis-info": "Redis connection info"
|
| 585 |
+
}
|
| 586 |
+
}
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
# ==================== RUN SERVER ====================
|
| 590 |
+
|
| 591 |
+
# if __name__ == "__main__":
|
| 592 |
+
# import uvicorn
|
| 593 |
+
# try:
|
| 594 |
+
# uvicorn.run(Redis_session_router_old, host="0.0.0.0", port=8000)
|
| 595 |
+
# except KeyboardInterrupt:
|
| 596 |
+
# print("\n🛑 Server stopped gracefully")
|
| 597 |
+
# except Exception as e:
|
| 598 |
+
# print(f"❌ Server error: {e}")
|
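An end-to-end request sketch for the /Redis endpoints above, assuming the router is mounted on an app at http://localhost:8000 and the Redis client initialization has been restored (all payload values are examples):

    import requests

    BASE = "http://localhost:8000/Redis"
    sess = requests.post(f"{BASE}/sessions", json={
        "userLoginId": 42, "orgId": 7, "auth_token": "demo-token",
    }).json()

    sid = sess["session_id"]
    requests.post(f"{BASE}/sessions/{sid}/messages",
                  json={"session_id": sid, "role": "user", "message": "hello"})
    history = requests.get(f"{BASE}/sessions/{sid}/history", params={"n": 20}).json()
    print(history["total_messages"])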
Routes/generate_report.py
ADDED
|
@@ -0,0 +1,71 @@
|
| 1 |
+
from fastapi import APIRouter, HTTPException
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
from Routes.helpers.report_generation_helpers import return_html_report, html_to_pdf_via_api
|
| 5 |
+
|
| 6 |
+
# Router initialization
|
| 7 |
+
Report_Generation_Router = APIRouter(prefix="/report_generation", tags=["Report_Generation"])
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# ---------- MODELS ----------
|
| 11 |
+
class PageOptions(BaseModel):
|
| 12 |
+
include_cover_page: bool
|
| 13 |
+
include_header_footer: bool
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ReportRequest(BaseModel):
|
| 17 |
+
format_type: str # "pdf" or "html"
|
| 18 |
+
reportname: str
|
| 19 |
+
include_citations: bool
|
| 20 |
+
success: bool
|
| 21 |
+
list_of_queries: List[str]
|
| 22 |
+
theme: str
|
| 23 |
+
pageoptions: PageOptions
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ReportResponse(BaseModel):
|
| 27 |
+
report_generation_status: str
|
| 28 |
+
report_url: Optional[str] = None # for PDF storage or link
|
| 29 |
+
html_content: Optional[str] = None # for direct HTML return
|
| 30 |
+
Generated_report_details: Optional[str] = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ---------- ROUTES ----------
|
| 34 |
+
|
| 35 |
+
@Report_Generation_Router.post("/generate_report", response_model=ReportResponse)
|
| 36 |
+
async def generate_html_report(report_request: ReportRequest):
|
| 37 |
+
"""
|
| 38 |
+
Generates a static HTML report using Gemini AI.
|
| 39 |
+
"""
|
| 40 |
+
try:
|
| 41 |
+
html_content = return_html_report()
|
| 42 |
+
return ReportResponse(
|
| 43 |
+
report_generation_status="success",
|
| 44 |
+
html_content=html_content,
|
| 45 |
+
GeGenerated_report_details="HTML report generated successfully."
|
| 46 |
+
)
|
| 47 |
+
except Exception as e:
|
| 48 |
+
raise HTTPException(status_code=500, detail=f"HTML report generation failed: {str(e)}")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@Report_Generation_Router.post("/save_generated_report", response_model=ReportResponse)
|
| 53 |
+
async def save_generated_report(report_request: ReportRequest):
|
| 54 |
+
"""
|
| 55 |
+
Generates HTML and converts it to a PDF file via external API.
|
| 56 |
+
"""
|
| 57 |
+
try:
|
| 58 |
+
html_content = return_html_report()
|
| 59 |
+
generated_report = html_to_pdf_via_api(html_content)
|
| 60 |
+
|
| 61 |
+
# Extract PDF base64 or URL from the response
|
| 62 |
+
pdf_result_info = generated_report.get("pdf_base64", "PDF generated successfully")
|
| 63 |
+
|
| 64 |
+
return ReportResponse(
|
| 65 |
+
report_generation_status="success",
|
| 66 |
+
report_url=None, # you can replace with cloud URL if uploaded
|
| 67 |
+
Generated_report_details=str(pdf_result_info)
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
except Exception as e:
|
| 71 |
+
raise HTTPException(status_code=500, detail=f"PDF report generation failed: {str(e)}")
|
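An illustrative request body for POST /report_generation/generate_report that satisfies the ReportRequest model above (every value is an example; the host and port are placeholders):

    import requests

    payload = {
        "format_type": "html",
        "reportname": "Q3 patient outcomes",
        "include_citations": True,
        "success": True,
        "list_of_queries": ["average length of stay by department"],
        "theme": "dark",
        "pageoptions": {"include_cover_page": True, "include_header_footer": False},
    }
    resp = requests.post("http://localhost:8000/report_generation/generate_report", json=payload)
    print(resp.json()["report_generation_status"])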
Routes/helpers/autovis_tool.py
ADDED
|
@@ -0,0 +1,132 @@
|
| 1 |
+
|
| 2 |
+
# from autoviz import AutoViz_Class
|
| 3 |
+
# import os
|
| 4 |
+
# import pandas as pd
|
| 5 |
+
# import uuid
|
| 6 |
+
|
| 7 |
+
# def run_autoviz(
|
| 8 |
+
# filename,
|
| 9 |
+
# sep=",",
|
| 10 |
+
# depVar="",
|
| 11 |
+
# dfte=None,
|
| 12 |
+
# header=0,
|
| 13 |
+
# verbose=2,
|
| 14 |
+
# lowess=False,
|
| 15 |
+
# chart_format="html",
|
| 16 |
+
# max_rows_analyzed=150000,
|
| 17 |
+
# max_cols_analyzed=30,
|
| 18 |
+
# save_plot_dir=None
|
| 19 |
+
# ):
|
| 20 |
+
# # Generate unique directory using uuid
|
| 21 |
+
# vis_in = str(uuid.uuid4())
|
| 22 |
+
# save_plot_dir = f"./{chart_format}_{vis_in}"
|
| 23 |
+
# os.makedirs(save_plot_dir, exist_ok=True)
|
| 24 |
+
|
| 25 |
+
# # Read the CSV file to check if the path is correct
|
| 26 |
+
# df = pd.read_csv(filename)
|
| 27 |
+
# # print(df.head())
|
| 28 |
+
|
| 29 |
+
# # Run AutoViz
|
| 30 |
+
# AV = AutoViz_Class()
|
| 31 |
+
# print("Running AutoViz...")
|
| 32 |
+
# dft = AV.AutoViz(
|
| 33 |
+
# filename,
|
| 34 |
+
# sep=sep,
|
| 35 |
+
# depVar=depVar,
|
| 36 |
+
# dfte=dfte,
|
| 37 |
+
# header=header,
|
| 38 |
+
# verbose=verbose,
|
| 39 |
+
# lowess=lowess,
|
| 40 |
+
# chart_format=chart_format,
|
| 41 |
+
# max_rows_analyzed=max_rows_analyzed,
|
| 42 |
+
# max_cols_analyzed=max_cols_analyzed,
|
| 43 |
+
# save_plot_dir=save_plot_dir
|
| 44 |
+
# )
|
| 45 |
+
# print(dft)
|
| 46 |
+
# print(f"Plots saved in: {save_plot_dir}")
|
| 47 |
+
# return dft
|
| 48 |
+
|
| 49 |
+
# # Example usage:
|
| 50 |
+
# run_autoviz(
|
| 51 |
+
# filename=r"C:\Users\Dell\Documents\MR-AI\openai_agents\healthcare-data-30.csv"
|
| 52 |
+
# )
|
| 53 |
+
|
| 54 |
+
from fastapi import FastAPI, Request
|
| 55 |
+
from pydantic import BaseModel
|
| 56 |
+
from autoviz import AutoViz_Class
|
| 57 |
+
import os
|
| 58 |
+
import pandas as pd
|
| 59 |
+
import uuid
|
| 60 |
+
from fastapi import APIRouter
|
| 61 |
+
|
| 62 |
+
Autoviz_router = APIRouter(prefix="/autoviz", tags=["autoviz"])
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class AutoVizParams(BaseModel):
|
| 71 |
+
filename: str
|
| 72 |
+
sep: str = ","
|
| 73 |
+
depVar: str = ""
|
| 74 |
+
header: int = 0
|
| 75 |
+
verbose: int = 2
|
| 76 |
+
lowess: bool = False
|
| 77 |
+
chart_format: str = "html"
|
| 78 |
+
max_rows_analyzed: int = 150000
|
| 79 |
+
max_cols_analyzed: int = 30
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
global save_plot_dir
|
| 83 |
+
def run_autoviz(
|
| 84 |
+
filename,
|
| 85 |
+
sep=",",
|
| 86 |
+
depVar="",
|
| 87 |
+
dfte=None,
|
| 88 |
+
header=0,
|
| 89 |
+
verbose=2,
|
| 90 |
+
lowess=False,
|
| 91 |
+
chart_format="html",
|
| 92 |
+
max_rows_analyzed=150000,
|
| 93 |
+
max_cols_analyzed=30,
|
| 94 |
+
save_plot_dir=None
|
| 95 |
+
):
|
| 96 |
+
vis_in = str(uuid.uuid4())
|
| 97 |
+
save_plot_dir = f"./{chart_format}_{vis_in}"
|
| 98 |
+
os.makedirs(save_plot_dir, exist_ok=True)
|
| 99 |
+
df = pd.read_csv(filename)
|
| 100 |
+
AV = AutoViz_Class()
|
| 101 |
+
dft = AV.AutoViz(
|
| 102 |
+
filename,
|
| 103 |
+
sep=sep,
|
| 104 |
+
depVar=depVar,
|
| 105 |
+
dfte=dfte,
|
| 106 |
+
header=header,
|
| 107 |
+
verbose=verbose,
|
| 108 |
+
lowess=lowess,
|
| 109 |
+
chart_format=chart_format,
|
| 110 |
+
max_rows_analyzed=max_rows_analyzed,
|
| 111 |
+
max_cols_analyzed=max_cols_analyzed,
|
| 112 |
+
save_plot_dir=save_plot_dir
|
| 113 |
+
)
|
| 114 |
+
print("save_plot_dir", save_plot_dir)
|
| 115 |
+
return {"message": "AutoViz run complete", "plots_dir": save_plot_dir}
|
| 116 |
+
|
| 117 |
+
@Autoviz_router.post("/run_autoviz")
|
| 118 |
+
async def autoviz_api(params: AutoVizParams):
|
| 119 |
+
result = run_autoviz(
|
| 120 |
+
filename=params.filename,
|
| 121 |
+
sep=params.sep,
|
| 122 |
+
depVar=params.depVar,
|
| 123 |
+
dfte=None,
|
| 124 |
+
header=params.header,
|
| 125 |
+
verbose=params.verbose,
|
| 126 |
+
lowess=params.lowess,
|
| 127 |
+
chart_format=params.chart_format,
|
| 128 |
+
max_rows_analyzed=params.max_rows_analyzed,
|
| 129 |
+
max_cols_analyzed=params.max_cols_analyzed,
|
| 130 |
+
save_plot_dir=None
|
| 131 |
+
)
|
| 132 |
+
return result
|
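A hedged example call to the /autoviz/run_autoviz endpoint above; the CSV path is a placeholder and must point to a file readable by the server process, and the base URL is an example:

    import requests

    params = {
        "filename": "healthcare-data-30.csv",      # placeholder path on the server
        "chart_format": "html",
        "max_rows_analyzed": 150000,
    }
    result = requests.post("http://localhost:8000/autoviz/run_autoviz", json=params).json()
    print(result["plots_dir"])                      # directory where AutoViz wrote the charts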
Routes/helpers/df_to_vis.py
ADDED
|
@@ -0,0 +1,255 @@
|
| 1 |
+
import re
|
| 2 |
+
import uuid
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import requests
|
| 5 |
+
import tempfile
|
| 6 |
+
import os
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from langchain_openai import ChatOpenAI
|
| 9 |
+
from langchain.schema import SystemMessage, HumanMessage
|
| 10 |
+
from nbformat import v4, write
|
| 11 |
+
from nbclient import NotebookClient
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import duckdb
|
| 15 |
+
project_root = Path(__file__).resolve().parents[2]
|
| 16 |
+
if str(project_root) not in sys.path:
|
| 17 |
+
sys.path.insert(0, str(project_root))
|
| 18 |
+
from nbclient.exceptions import CellExecutionError
|
| 19 |
+
from s3.read_files import read_csv_from_s3
|
| 20 |
+
|
| 21 |
+
# Global flag
|
| 22 |
+
IS_VIS = False
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def generate_plotly_visualization(
|
| 26 |
+
result_df: pd.DataFrame,
|
| 27 |
+
user_prompt: str,
|
| 28 |
+
is_vis: bool = IS_VIS,
|
| 29 |
+
upload_to_s3: bool = True,
|
| 30 |
+
s3_path: str = "vatsav/artifacts/"
|
| 31 |
+
) -> dict:
|
| 32 |
+
"""
|
| 33 |
+
Takes a pandas DataFrame and generates a Plotly visualization via LLM,
|
| 34 |
+
saves it as a temporary HTML file, and optionally uploads to S3.
|
| 35 |
+
The temporary file is automatically cleaned up after upload.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
result_df (pd.DataFrame): Input DataFrame to visualize.
|
| 39 |
+
user_prompt (str): User prompt for visualization instructions.
|
| 40 |
+
is_vis (bool): Flag to enable/disable visualization generation. Default is global IS_VIS.
|
| 41 |
+
upload_to_s3 (bool): Flag to enable/disable S3 upload. Default is True.
|
| 42 |
+
s3_path (str): S3 path for upload. Default is "vatsav/artifacts/".
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
dict: Metadata dictionary with viz_id, timestamp, s3_url, and status.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
viz_id = str(uuid.uuid4())
|
| 49 |
+
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 50 |
+
|
| 51 |
+
# If visualization is disabled, return early
|
| 52 |
+
if not is_vis:
|
| 53 |
+
metadata = {
|
| 54 |
+
"viz_id": viz_id,
|
| 55 |
+
"timestamp": timestamp,
|
| 56 |
+
"result": result_df.to_dict(),
|
| 57 |
+
"status": "success",
|
| 58 |
+
"message": "Visualization disabled (is_vis=False)",
|
| 59 |
+
"user_prompt": user_prompt
|
| 60 |
+
}
|
| 61 |
+
print("user_prompt:", user_prompt)
|
| 62 |
+
print(f"ℹ️ Visualization skipped (is_vis=False)")
|
| 63 |
+
return metadata
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
# 1. Prepare prompt for Plotly visualization assistant
|
| 67 |
+
plotly_agent_prompt = f"""
|
| 68 |
+
You are a cutting-edge Python visualization expert specializing in generating highly aesthetic and functional charts using the Plotly library.
|
| 69 |
+
|
| 70 |
+
Your task is to analyze the provided pandas DataFrame and generate a visualization that adheres to a smart, minimalist, and futuristic aesthetic.
|
| 71 |
+
|
| 72 |
+
User request: "{user_prompt}"
|
| 73 |
+
|
| 74 |
+
Follow the user request closely — especially regarding chart type or any explicit instructions.
|
| 75 |
+
|
| 76 |
+
- Intelligent Data Analysis: Automatically analyze the DataFrame's columns to identify the optimal axis assignments (categorical/time-series for X, numeric/metric for Y). State the reason for the chosen assignment.
|
| 77 |
+
- Smart Chart Selection: Determine and state the single most effective chart type (Bar, Line, Scatter, Histogram, Pie, etc.) that best represents the relationship.
|
| 78 |
+
- Generate Plotly Python Code: Write concise, ready-to-run Python code using plotly.express or plotly.graph_objects.
|
| 79 |
+
|
| 80 |
+
Aesthetic Requirements (Futuristic & Smart)
|
| 81 |
+
- Theme: Apply a dark, high-tech, futuristic template with `template='plotly_dark'`.
|
| 82 |
+
- Color Palette: Use high-contrast, digital-inspired colors [#4477A0, #FF9933, #B266CC, #33B2B2, #66CC66].
|
| 83 |
+
- Labels & Title: Include a professional title and descriptive axes.
|
| 84 |
+
- Annotation: For bar or pie charts, display metric values directly.
|
| 85 |
+
|
| 86 |
+
Required Output Format
|
| 87 |
+
- Suggested chart type: [Chart Type]
|
| 88 |
+
- Analysis and Rationale: [Why chart type and columns were chosen]
|
| 89 |
+
- Python Plotly code:
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
# 2. Initialize ChatOpenAI
|
| 93 |
+
chat = ChatOpenAI(model="gpt-4o-mini", temperature=0.7)
|
| 94 |
+
|
| 95 |
+
# 3. Prepare messages with system prompt and DataFrame as string
|
| 96 |
+
messages = [
|
| 97 |
+
SystemMessage(content=plotly_agent_prompt),
|
| 98 |
+
HumanMessage(content=result_df.to_string())
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
# 4. Get LLM response with Plotly code
|
| 102 |
+
response = chat.invoke(messages) # Use invoke instead of __call__
|
| 103 |
+
response_content = response.content
|
| 104 |
+
|
| 105 |
+
# 5. Extract Python Plotly code from response (between ```python ... ```)
|
| 106 |
+
code_match = re.search(r"```python\n(.*?)```", response_content, re.DOTALL)
|
| 107 |
+
plotly_code = code_match.group(1).strip() if code_match else response_content.strip()
|
| 108 |
+
|
| 109 |
+
# 6. Execute the Plotly code in a notebook client and save the plot HTML
|
| 110 |
+
|
| 111 |
+
# Create a new notebook object
|
| 112 |
+
nb = v4.new_notebook()
|
| 113 |
+
nb.cells.append(v4.new_code_cell(plotly_code))
|
| 114 |
+
client = NotebookClient(nb, timeout=60, kernel_name="python3")
|
| 115 |
+
|
| 116 |
+
try:
|
| 117 |
+
client.execute()
|
| 118 |
+
except CellExecutionError as e:
|
| 119 |
+
raise RuntimeError(f"Error executing Plotly code: {e}")
|
| 120 |
+
|
| 121 |
+
# Save notebook for debugging (optional)
|
| 122 |
+
with open("temp_notebook.ipynb", "w", encoding="utf-8") as f:
|
| 123 |
+
write(nb, f)
|
| 124 |
+
|
| 125 |
+
# 7. Extract figure object by executing code in local scope
|
| 126 |
+
local_scope = {}
|
| 127 |
+
exec(plotly_code, {}, local_scope)
|
| 128 |
+
fig = local_scope.get("fig")
|
| 129 |
+
|
| 130 |
+
if fig is None:
|
| 131 |
+
raise RuntimeError("No Plotly figure named 'fig' found in the generated code.")
|
| 132 |
+
|
| 133 |
+
# 8. Create a temporary file for the HTML and write to it
|
| 134 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.html', prefix='plotly_viz_', delete=False, encoding='utf-8') as temp_file:
|
| 135 |
+
temp_path = temp_file.name
|
| 136 |
+
|
| 137 |
+
# Write HTML to the temp file path (Plotly handles encoding internally)
|
| 138 |
+
fig.write_html(temp_path)
|
| 139 |
+
|
| 140 |
+
print(f"✅ Plotly visualization saved to temporary file")
|
| 141 |
+
|
| 142 |
+
try:
|
| 143 |
+
metadata = {
|
| 144 |
+
"viz_id": viz_id,
|
| 145 |
+
"timestamp": timestamp,
|
| 146 |
+
"status": "success"
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
# 9. Upload to S3 if enabled
|
| 150 |
+
if upload_to_s3:
|
| 151 |
+
try:
|
| 152 |
+
s3_url = upload_file_to_s3(temp_path, s3_path)
|
| 153 |
+
metadata["s3_url"] = s3_url
|
| 154 |
+
metadata["s3_status"] = "uploaded"
|
| 155 |
+
print(f"✅ File uploaded to S3: {s3_url}")
|
| 156 |
+
except Exception as e:
|
| 157 |
+
metadata["s3_status"] = "failed"
|
| 158 |
+
metadata["s3_error"] = str(e)
|
| 159 |
+
print(f"❌ S3 upload failed: {str(e)}")
|
| 160 |
+
|
| 161 |
+
print(f"✅ Process completed (ID: {viz_id})")
|
| 162 |
+
|
| 163 |
+
return metadata
|
| 164 |
+
|
| 165 |
+
finally:
|
| 166 |
+
# 10. Clean up temporary file
|
| 167 |
+
if os.path.exists(temp_path):
|
| 168 |
+
os.remove(temp_path)
|
| 169 |
+
print(f"🗑️ Temporary file cleaned up")
|
| 170 |
+
|
| 171 |
+
except Exception as e:
|
| 172 |
+
return {
|
| 173 |
+
"viz_id": viz_id,
|
| 174 |
+
"timestamp": timestamp,
|
| 175 |
+
"error": str(e),
|
| 176 |
+
"status": "failed"
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def upload_file_to_s3(file_path: str, s3_path: str = "vatsav/artifacts/") -> str:
|
| 181 |
+
"""
|
| 182 |
+
Uploads a file to S3 using the HuggingFace Space API.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
file_path (str): Local path to the file to upload.
|
| 186 |
+
s3_path (str): S3 path where the file should be uploaded.
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
str: S3 URL of the uploaded file.
|
| 190 |
+
|
| 191 |
+
Raises:
|
| 192 |
+
Exception: If upload fails.
|
| 193 |
+
"""
|
| 194 |
+
api_url = f"https://srivatsavdamaraju-mvp-2-0-deploy-all-apis.hf.space/s3/upload/?path={s3_path}"
|
| 195 |
+
|
| 196 |
+
# Determine content type based on file extension
|
| 197 |
+
if file_path.endswith('.html'):
|
| 198 |
+
content_type = 'text/html'
|
| 199 |
+
elif file_path.endswith('.json'):
|
| 200 |
+
content_type = 'application/json'
|
| 201 |
+
else:
|
| 202 |
+
content_type = 'application/octet-stream'
|
| 203 |
+
|
| 204 |
+
# Open and upload the file
|
| 205 |
+
with open(file_path, 'rb') as f:
|
| 206 |
+
files = {
|
| 207 |
+
'file': (Path(file_path).name, f, content_type)
|
| 208 |
+
}
|
| 209 |
+
headers = {
|
| 210 |
+
'accept': 'application/json'
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
response = requests.post(api_url, headers=headers, files=files)
|
| 214 |
+
|
| 215 |
+
if response.status_code == 200:
|
| 216 |
+
response_data = response.json()
|
| 217 |
+
# Assuming the API returns the S3 URL in the response
|
| 218 |
+
s3_url = response_data.get('url') or response_data.get('file_url') or f"{s3_path}{Path(file_path).name}"
|
| 219 |
+
return s3_url
|
| 220 |
+
else:
|
| 221 |
+
raise Exception(f"Upload failed with status {response.status_code}: {response.text}")
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# # Example usage:
|
| 225 |
+
# data = {
|
| 226 |
+
# 'smoking_status': ['formerly smoked', 'smokes', 'never smoked', 'Unknown'],
|
| 227 |
+
# 'stroke_percentage': [66.04, 51.22, 50.00, 39.17]
|
| 228 |
+
# }
|
| 229 |
+
|
| 230 |
+
# # Create the DataFrame
|
| 231 |
+
# df = pd.DataFrame(data)
|
| 232 |
+
|
| 233 |
+
# # Without visualization
|
| 234 |
+
# # result = generate_plotly_visualization(df, user_prompt="", is_vis=False)
|
| 235 |
+
# # print(result)
|
| 236 |
+
|
| 237 |
+
# # With visualization and S3 upload (temp file auto-cleaned)
|
| 238 |
+
# user_query = "show me a pie chart of smoking status vs stroke percentage"
|
| 239 |
+
# result = generate_plotly_visualization(
|
| 240 |
+
# df,
|
| 241 |
+
# user_prompt=user_query,
|
| 242 |
+
# is_vis=True,
|
| 243 |
+
# upload_to_s3=True,
|
| 244 |
+
# s3_path="vatsav/artifacts/tests/"
|
| 245 |
+
# )
|
| 246 |
+
# print(result)
|
| 247 |
+
|
| 248 |
+
# With visualization but no S3 upload (temp file auto-cleaned)
|
| 249 |
+
# result = generate_plotly_visualization(
|
| 250 |
+
# df,
|
| 251 |
+
# user_prompt=user_query,
|
| 252 |
+
# is_vis=True,
|
| 253 |
+
# upload_to_s3=False
|
| 254 |
+
# )
|
| 255 |
+
# print(result)
|
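A minimal sketch of consuming the metadata returned by generate_plotly_visualization (the DataFrame and prompt are toy examples; running with is_vis=True assumes an OpenAI API key and a local Jupyter kernel are available):

    import pandas as pd

    df = pd.DataFrame({"department": ["ER", "ICU"], "avg_stay_days": [2.1, 6.4]})
    meta = generate_plotly_visualization(
        df,
        user_prompt="bar chart of average stay by department",
        is_vis=True,
        upload_to_s3=False,
    )
    if meta["status"] == "success":
        print(meta["viz_id"], meta.get("s3_url", "not uploaded"))
    else:
        print("visualization failed:", meta.get("error"))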
Routes/helpers/duck_db_agent.py
ADDED
|
@@ -0,0 +1,321 @@
|
import asyncio
import duckdb
import sys
from pathlib import Path
from datetime import datetime
from typing import Annotated, Optional, List
from functools import lru_cache

import pandas as pd
from pydantic import BaseModel, Field, ConfigDict
from agents import Agent, Runner, function_tool
import os
import dotenv

# =====================================
# 🔹 CONFIGURATION
# =====================================
dotenv.load_dotenv()

# Ensure project root is on sys.path
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Validate environment
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY not found in environment variables")
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

# Local import after path setup
from s3.read_files import read_csv_from_s3

# Constants
S3_DATASET_PATH = "vatsav/csv/mini_health-care_dataset.csv"
ALLOWED_SQL_KEYWORDS = ("select", "with")
MAX_RESULT_ROWS = 10000


# =====================================
# 🔹 PYDANTIC MODELS (Strict Schema Compatible)
# =====================================
class SQLResult(BaseModel):
    """Result from SQL query execution."""
    model_config = ConfigDict(extra='forbid')

    sql_query: str = Field(..., description="The executed SQL query")
    result_summary: str = Field(..., description="Human-readable summary of results")
    columns: List[str] = Field(..., description="Column names from result")
    row_count: int = Field(..., description="Number of rows returned")
    timestamp: str = Field(..., description="Execution timestamp")
    execution_time_ms: float = Field(..., description="Query execution time in milliseconds")


class PlotResult(BaseModel):
    """Result from plot generation."""
    model_config = ConfigDict(extra='forbid')

    viz_id: str = Field(..., description="Unique visualization identifier")
    file_path: str = Field(..., description="Path to saved visualization file")
    status: str = Field(..., description="Status of visualization generation")
    chart_type: str = Field(..., description="Type of chart created")
    timestamp: str = Field(..., description="Generation timestamp")


# =====================================
# 🔹 DATA LOADING WITH CACHING
# =====================================
@lru_cache(maxsize=1)
def get_cached_dataset(s3_path: str) -> pd.DataFrame:
    """Load dataset from S3 with caching."""
    print(f"📥 Loading dataset from S3: {s3_path}")
    df = read_csv_from_s3(s3_path)
    print(f"✅ Dataset loaded: {len(df)} rows, {len(df.columns)} columns")
    return df


# Store the last query result globally for use by plotly_tool
_last_query_result = None


# =====================================
# 🔹 SQL TOOL — performs data analytics
# =====================================
@function_tool
def sql_tool(sql_query: Annotated[str, "SQL query to execute on healthcare dataset"]) -> SQLResult:
    """
    Execute SQL query on healthcare dataset (table: CSV1) and return results with metadata.

    Supported operations: SELECT, WITH (CTE)
    Available table: CSV1 (healthcare dataset)

    Example queries:
    - SELECT diagnosis, COUNT(*) as count FROM CSV1 GROUP BY diagnosis
    - SELECT * FROM CSV1 WHERE age > 50 LIMIT 10
    """
    global _last_query_result

    start_time = datetime.now()
    timestamp = start_time.strftime("%Y-%m-%d_%H-%M-%S")

    try:
        # Validate query
        normalized_query = sql_query.strip().lower()
        if not normalized_query.startswith(ALLOWED_SQL_KEYWORDS):
            raise ValueError(
                f"Only {', '.join(k.upper() for k in ALLOWED_SQL_KEYWORDS)} queries are allowed."
            )

        # Load dataset (cached)
        df = get_cached_dataset(S3_DATASET_PATH)

        # Execute query using DuckDB
        conn = duckdb.connect(":memory:")
        conn.register("CSV1", df)

        result_df = conn.execute(sql_query).df()
        conn.close()

        # Validate result size
        if len(result_df) > MAX_RESULT_ROWS:
            result_df = result_df.head(MAX_RESULT_ROWS)

        # Calculate execution time
        execution_time = (datetime.now() - start_time).total_seconds() * 1000

        # Store result for potential visualization
        _last_query_result = result_df.copy()

        # Create human-readable summary
        summary_lines = [
            f"Query returned {len(result_df)} rows with {len(result_df.columns)} columns: {', '.join(result_df.columns.tolist())}"
        ]

        # Add preview of results
        if not result_df.empty:
            summary_lines.append("\nResults preview:")
            # Format as table string
            preview = result_df.head(10).to_string(index=False, max_rows=10)
            summary_lines.append(preview)

        result_summary = "\n".join(summary_lines)

        print(f"✅ SQL executed: {len(result_df)} rows in {execution_time:.2f}ms")

        return SQLResult(
            sql_query=sql_query,
            result_summary=result_summary,
            columns=result_df.columns.tolist(),
            row_count=len(result_df),
            timestamp=timestamp,
            execution_time_ms=round(execution_time, 2)
        )

    except ValueError as ve:
        raise ValueError(f"SQL validation error: {str(ve)}")
    except Exception as e:
        raise RuntimeError(f"SQL execution error: {str(e)}")


# =====================================
# 🔹 PLOTLY TOOL — generates visualization
# =====================================
@function_tool
def plotly_tool(
    chart_type: Annotated[str, "Type of chart: bar, line, scatter, or pie"] = "bar",
    x_column: Annotated[Optional[str], "Column name for x-axis"] = None,
    y_column: Annotated[Optional[str], "Column name for y-axis"] = None,
    title: Annotated[Optional[str], "Custom chart title"] = None
) -> PlotResult:
    """
    Generate interactive Plotly visualization from the last SQL query result.
    Must be called after sql_tool.
    """
    global _last_query_result

    import plotly.express as px
    import plotly.graph_objects as go
    from uuid import uuid4

    if _last_query_result is None:
        raise ValueError("No SQL query result available. Please run sql_tool first.")

    viz_id = str(uuid4())[:8]
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_dir = PROJECT_ROOT / "outputs" / "visualizations"
    output_dir.mkdir(parents=True, exist_ok=True)

    file_path = output_dir / f"viz_{viz_id}_{timestamp}.html"

    try:
        df = _last_query_result.copy()

        if df.empty:
            raise ValueError("Cannot create visualization from empty result set")

        # Auto-select x and y columns if not provided
        if x_column is None:
            x_column = df.columns[0]
        if y_column is None:
            y_column = df.columns[1] if len(df.columns) > 1 else df.columns[0]

        # Validate columns exist
        if x_column not in df.columns:
            raise ValueError(f"Column '{x_column}' not found. Available: {list(df.columns)}")
        if y_column not in df.columns:
            raise ValueError(f"Column '{y_column}' not found. Available: {list(df.columns)}")

        # Generate chart title
        if title is None:
            title = f"{y_column} by {x_column}"

        # Generate appropriate chart
        chart_mapping = {
            "bar": px.bar,
            "line": px.line,
            "scatter": px.scatter,
            "pie": px.pie
        }

        plot_func = chart_mapping.get(chart_type.lower(), px.bar)

        if chart_type.lower() == "pie":
            fig = plot_func(df, names=x_column, values=y_column, title=title)
        else:
            fig = plot_func(df, x=x_column, y=y_column, title=title)

        # Enhanced styling
        fig.update_layout(
            template="plotly_white",
            height=600,
            showlegend=True,
            font=dict(size=12),
            title_font_size=16
        )

        fig.write_html(str(file_path))
        print(f"📊 Visualization created: {chart_type} chart saved to {file_path.name}")

        return PlotResult(
            viz_id=viz_id,
            file_path=str(file_path),
            status="success",
            chart_type=chart_type,
            timestamp=timestamp
        )

    except Exception as e:
        raise RuntimeError(f"Plotly error: {str(e)}")


# =====================================
# 🔹 AGENT DEFINITION
# =====================================
agent = Agent(
    name="Healthcare Data Analyst",
    instructions=(
        "You are a healthcare data analyst. Analyze queries and generate SQL + visualizations.\n\n"
        "Available dataset: CSV1 (healthcare patient data)\n\n"
        "Workflow:\n"
        "1. Generate SQL query using sql_tool to analyze CSV1\n"
        "2. Review the result_summary to understand the data\n"
        "3. If visualization is needed, use plotly_tool (it uses the last SQL result)\n"
        "4. Provide clear insights with context\n\n"
        "SQL Best Practices:\n"
        "- Use descriptive aliases\n"
        "- Add ORDER BY for readability\n"
        "- Use LIMIT for large datasets\n\n"
        "Visualization Guidelines:\n"
        "- bar: counts, comparisons\n"
        "- line: trends, time series\n"
        "- pie: proportions (5-7 categories max)\n"
        "- scatter: correlations\n\n"
        "Always interpret results for non-technical users."
    ),
    tools=[sql_tool],
)


# =====================================
# 🔹 MAIN RUNNER
# =====================================
async def run_analysis(user_query: str, verbose: bool = True) -> str:
    """Run the agent with the given query and return results."""
    if verbose:
        print(f"\n{'='*70}")
        print(f"📋 Query: {user_query}")
        print(f"{'='*70}\n")

    try:
        result = await Runner.run(agent, input=user_query)
        return result.final_output
    except Exception as e:
        print(f"❌ Error: {str(e)}")
        raise


async def main():
    """Main entry point."""
    user_query = (
        "Show how many patients exist in each diagnosis category, "
        "filtered for married individuals. Create a bar chart to visualize this."
    )

    try:
        output = await run_analysis(user_query)

        # Clean output formatting
        print("\n" + "="*70)
        print("📊 ANALYSIS COMPLETE")
        print("="*70)
        print(output)
        print("="*70 + "\n")

    except Exception as e:
        print(f"\n❌ Analysis failed: {e}\n")
        raise


if __name__ == "__main__":
    asyncio.run(main())
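The sql_tool above leans on DuckDB's ability to query an in-memory pandas DataFrame directly: the DataFrame is registered under the table name CSV1 and the validated SELECT/WITH query runs against it. A minimal standalone sketch of that pattern follows; the toy DataFrame and query are illustrative only and are not part of the uploaded file.

import duckdb
import pandas as pd

# Toy stand-in for the S3 healthcare dataset (illustrative data)
df = pd.DataFrame({
    "diagnosis": ["flu", "flu", "asthma"],
    "age": [34, 61, 47],
})

conn = duckdb.connect(":memory:")   # throwaway in-memory database
conn.register("CSV1", df)           # expose the DataFrame as table CSV1
result = conn.execute(
    "SELECT diagnosis, COUNT(*) AS count FROM CSV1 GROUP BY diagnosis ORDER BY count DESC"
).df()                              # fetch the result back as a DataFrame
conn.close()
print(result)

The same register/execute/close sequence is what sql_tool performs on the cached S3 dataset before summarising the result for the agent.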
Routes/helpers/main_agent_best_as_of_now.py
ADDED
@@ -0,0 +1,700 @@
import duckdb
import sys
from pathlib import Path
from datetime import datetime
from typing import Annotated, Optional, List
import re
import json
import pandas as pd
from pydantic import BaseModel, Field, ConfigDict
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse
import os
import dotenv
import tempfile
from langchain_openai import ChatOpenAI
from langchain.schema import SystemMessage, HumanMessage
import requests
import asyncio
import sqlparse

# Project root setup
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Assuming these are custom modules in your project
from agents import Agent, Runner, function_tool
from s3.read_files import read_csv_from_s3

# =====================================
# 🔹 CONFIGURATION
# =====================================
dotenv.load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY not found in environment variables")
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

# Constants
ALLOWED_SQL_KEYWORDS = ("select", "with")
MAX_RESULT_ROWS = 10000

# Global storage for datasets and context
_registered_datasets = {}
_dataset_context = {}
last_sql_query: str = ""
chartjs_path = []  # Global variable to store Chart.js JSON config path or S3 URL

def is_safe_query(sql_query: str) -> bool:
    parsed = sqlparse.parse(sql_query)[0]
    forbidden_keywords = {"DROP", "DELETE", "UPDATE", "INSERT"}
    tokens = [token.value.upper() for token in parsed.tokens if token.is_keyword]
    return not any(keyword in forbidden_keywords for keyword in tokens)

# =====================================
# 🔹 S3 UPLOAD FUNCTION
# =====================================

# =====================================
# 🔹 UTILITY FUNCTIONS
# =====================================
def save_convo_id_to_pgdb(convo_id: str):
    """Save conversation ID to PostgreSQL database (stub function)."""
    print(f"Saving conversation ID to PostgreSQL DB: {convo_id}")
    pass

# =====================================
# 🔹 PYDANTIC MODELS
# =====================================
class AnalysisRequest(BaseModel):
    """Request model for data analysis."""
    file_paths: List[str] = Field(..., description="List of S3 file paths to analyze")
    query: str = Field(..., description="Natural language query for analysis")

    model_config = ConfigDict(extra='forbid')

class SQLResult(BaseModel):
    """Result from SQL query execution."""
    model_config = ConfigDict(extra='forbid')

    sql_query: str = Field(..., description="The executed SQL query")
    result_summary: str = Field(..., description="Human-readable summary of results")
    columns: List[str] = Field(..., description="Column names from result")
    row_count: int = Field(..., description="Number of rows returned")
    execution_time_ms: float = Field(..., description="Query execution time in milliseconds")

class AnalysisResponse(BaseModel):
    """Response from analysis endpoint."""
    model_config = ConfigDict(extra='forbid')

    query: str = Field(..., description="Original user query")
    datasets_loaded: List[str] = Field(..., description="List of loaded datasets")
    analysis_result: str = Field(..., description="Analysis insights")
    sql_executed: str = Field(..., description="SQL query that was executed")
    execution_time_ms: float = Field(..., description="Total execution time")

# =====================================
# 🔹 DATASET MANAGEMENT & CONTEXT ENGINEERING
# =====================================
def generate_dataset_context(df: pd.DataFrame, table_name: str) -> str:
    """Generate comprehensive context for a dataset."""
    context_parts = []

    context_parts.append(f"Table: {table_name}")
    context_parts.append(f"Rows: {len(df)}, Columns: {len(df.columns)}")
    context_parts.append(f"Column Names: {', '.join(df.columns.tolist())}")

    context_parts.append("\nData Types:")
    for col, dtype in df.dtypes.items():
        context_parts.append(f"  - {col}: {dtype}")

    context_parts.append("\nSample Data (first 3 rows):")
    sample_str = df.head(3).to_string(index=False, max_cols=10)
    context_parts.append(sample_str)

    numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
    if numeric_cols:
        context_parts.append("\nNumeric Column Statistics:")
        desc = df[numeric_cols].describe().transpose()
        for col in numeric_cols[:5]:
            if col in desc.index:
                context_parts.append(f"  - {col}: mean={desc.loc[col, 'mean']:.2f}, "
                                     f"min={desc.loc[col, 'min']:.2f}, max={desc.loc[col, 'max']:.2f}")

    categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
    if categorical_cols:
        context_parts.append("\nCategorical Columns:")
        for col in categorical_cols[:5]:
            unique_count = df[col].nunique()
            context_parts.append(f"  - {col}: {unique_count} unique values")
            if unique_count <= 5:
                values = df[col].value_counts().head(5).to_dict()
                context_parts.append(f"    Values: {values}")

    return "\n".join(context_parts)

def load_and_register_datasets(file_paths: List[str]) -> str:
    """Load datasets from S3 and register them in DuckDB format."""
    global _registered_datasets, _dataset_context

    _registered_datasets.clear()
    _dataset_context.clear()

    context_parts = ["Available datasets for SQL queries:\n"]

    for idx, file_path in enumerate(file_paths, start=1):
        table_name = f"CSV{idx}"

        try:
            print(f"📥 Loading {file_path} as {table_name}")
            df = read_csv_from_s3(file_path)

            _registered_datasets[table_name] = df

            dataset_context = generate_dataset_context(df, table_name)
            _dataset_context[table_name] = dataset_context
            context_parts.append(f"\n{dataset_context}\n")
            context_parts.append("="*70)

            print(f"✅ Loaded {table_name}: {len(df)} rows, {len(df.columns)} columns")

        except Exception as e:
            print(f"❌ Error loading {file_path}: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Failed to load {file_path}: {str(e)}")

    full_context = "\n".join(context_parts)
    return full_context

# =====================================
# 🔹 CHART.JS VISUALIZATION FUNCTION
# =====================================
def generate_chartjs_visualization(result_df: pd.DataFrame, user_prompt: str) -> dict:
    """
    Generate Chart.js JSON config from DataFrame using LLM.

    Args:
        result_df: pandas DataFrame to visualize
        user_prompt: User's visualization request

    Returns:
        dict: Chart.js configuration JSON
    """
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.3)

    prompt = f"""You are an expert data visualization analyst. Analyze the DataFrame and user query to generate the BEST Chart.js visualization.

User Query: "{user_prompt}"

Your Task:
1. ANALYZE the DataFrame structure:
   - Identify column types (categorical, numerical, temporal)
   - Detect relationships between columns
   - Understand data distribution and patterns

2. DETERMINE the optimal visualization:
   - For comparisons → bar chart
   - For trends over time → line chart
   - For proportions/parts of whole → pie/doughnut
   - For distributions → scatter/bubble
   - For multivariate comparisons → radar/polarArea
   - Match chart type to data characteristics AND user intent

3. GENERATE Chart.js JSON with:
   - Correct axis mappings (categorical → labels, numerical → data)
   - Meaningful chart title derived from the data context
   - Appropriate colors for distinction
   - Professional styling options

Available chart types: bar, line, pie, doughnut, polarArea, radar, scatter, bubble

Color palette: #FF6384, #36A2EB, #FFCE56, #4BC0C0, #9966FF, #FF9F40

**CRITICAL**: Output ONLY valid JSON with all property names enclosed in double quotes. Do NOT include any text, comments, or explanations outside the JSON. Ensure the JSON is complete and properly formatted. All property names MUST be enclosed in double quotes.

Example of valid output:
{{
  "type": "bar",
  "data": {{
    "labels": ["A", "B", "C"],
    "datasets": [{{
      "label": "Sample Data",
      "data": [10, 20, 30],
      "backgroundColor": ["#FF6384", "#36A2EB", "#FFCE56"],
      "borderColor": ["#FF6384", "#36A2EB", "#FFCE56"],
      "borderWidth": 2
    }}]
  }},
  "options": {{
    "responsive": true,
    "maintainAspectRatio": false,
    "plugins": {{
      "title": {{
        "display": true,
        "text": "Sample Bar Chart",
        "font": {{ "size": 16 }}
      }},
      "legend": {{
        "display": true,
        "position": "top"
      }}
    }},
    "scales": {{
      "y": {{
        "beginAtZero": true,
        "title": {{ "display": true, "text": "Values" }}
      }},
      "x": {{
        "title": {{ "display": true, "text": "Categories" }}
      }}
    }}
  }}
}}

Output ONLY the JSON content:
{{
  "type": "best_chart_type",
  "data": {{
    "labels": ["derived", "from", "dataframe"],
    "datasets": [{{
      "label": "Meaningful Label",
      "data": [actual, values, from, df],
      "backgroundColor": ["#FF6384", "#36A2EB", "#FFCE56"],
      "borderColor": ["#FF6384", "#36A2EB", "#FFCE56"],
      "borderWidth": 2
    }}]
  }},
  "options": {{
    "responsive": true,
    "maintainAspectRatio": false,
    "plugins": {{
      "title": {{
        "display": true,
        "text": "Descriptive Title Based on Data",
        "font": {{ "size": 16 }}
      }},
      "legend": {{
        "display": true,
        "position": "top"
      }}
    }},
    "scales": {{
      "y": {{
        "beginAtZero": true,
        "title": {{ "display": true, "text": "Y Axis Label" }}
      }},
      "x": {{
        "title": {{ "display": true, "text": "X Axis Label" }}
      }}
    }}
  }}
}}
"""

    messages = [
        SystemMessage(content=prompt),
        HumanMessage(content=result_df.to_string())
    ]

    response = llm.invoke(messages)
    print(f"Raw LLM Response:\n{response.content}\n")

    json_match = re.search(r"\{.*\}", response.content, re.DOTALL)
    if not json_match:
        raise ValueError(f"No valid JSON block found in LLM response:\n{response.content}")

    json_string = json_match.group(0).strip()
    print(f"Extracted JSON String:\n{json_string}\n")

    try:
        config = json.loads(json_string)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON format: {str(e)}\nJSON String:\n{json_string}")

    print(f"Generated Chart.js Config: {json.dumps(config, indent=2)}")

    if "type" not in config or "data" not in config:
        raise ValueError("Invalid Chart.js config")

    return config

# =====================================
# 🔹 VECTOR AGENT TOOL
# =====================================
def query_vector_agent_calling(user_query: str, collection_name: str) -> str:
    """
    Call the vector agent API to get relevant context for the given user query and collection name.
    """
    base_url = "https://srivatsavdamaraju-mvp-2-0-deploy-all-apis.hf.space/qdrant/search"
    url = f"{base_url}?collection_name={collection_name}&mode=dense"

    headers = {
        "accept": "application/json",
        "Content-Type": "application/json",
    }

    payload = {
        "query": user_query,
        "top_k": 2
    }

    try:
        response = requests.post(url, headers=headers, json=payload)
        response.raise_for_status()

        data = response.json()
        results = data.get("results") or data.get("result") or data
        if not results:
            return "No relevant context found."

        contexts = []
        for item in results:
            text = item.get("text") or item.get("payload", {}).get("text")
            if text:
                contexts.append(text)

        return "\n\n".join(contexts) if contexts else str(data)

    except requests.RequestException as e:
        print(f"Error calling vector agent API: {e}")
        return "Error retrieving context."

@function_tool
async def vector_agent_tool(
    user_query: Annotated[str, "The user's natural language question"],
    collection_name: Annotated[str, "The name of the collection to search in"]
) -> str:
    """
    Tool to call the vector agent API and retrieve relevant context.

    Args:
        user_query: The user's natural language question.
        collection_name: The name of the collection to search in.

    Returns:
        A string containing the retrieved context.
    """
    return query_vector_agent_calling(user_query, collection_name)

# =====================================
# 🔹 SQL TOOL
# =====================================
@function_tool
async def sql_tool(
    sql_query: Annotated[str, "SQL query to execute on registered datasets"],
    user_query: Annotated[str, "Original user question for visualization context"] = ""
) -> SQLResult:
    """
    Execute SQL query on registered datasets and return results.

    Available tables: CSV1, CSV2, CSV3, etc. (based on loaded files)
    Supported operations: SELECT, WITH (CTE)

    Args:
        sql_query: The SQL query to execute
        user_query: The original user question (REQUIRED for visualization context)

    Example queries:
    - SELECT * FROM CSV1 WHERE age > 50 LIMIT 10
    - SELECT a.name, b.value FROM CSV1 a JOIN CSV2 b ON a.id = b.id

    Example usage:
    - sql_tool(
        sql_query="SELECT * FROM CSV1",
        user_query="What is the average age?"
      )
    """
    global _registered_datasets, last_sql_query

    last_sql_query = sql_query

    start_time = datetime.now()
    print(f"\n{'='*70}")
    print(f"🔍 Executing SQL Query:\n{sql_query}\n")
    print(f"📝 User Query Context: {user_query}\n")
    print(f"{'='*70}\n")

    try:
        normalized_query = sql_query.strip().lower()
        if not normalized_query.startswith(ALLOWED_SQL_KEYWORDS):
            raise ValueError(
                f"Only {', '.join(k.upper() for k in ALLOWED_SQL_KEYWORDS)} queries are allowed."
            )

        if not _registered_datasets:
            raise ValueError("No datasets loaded. Please provide file_paths first.")

        conn = duckdb.connect(":memory:")

        for table_name, df in _registered_datasets.items():
            conn.register(table_name, df)

        result_df = conn.execute(sql_query).df()
        conn.close()

        if len(result_df) > MAX_RESULT_ROWS:
            result_df = result_df.head(MAX_RESULT_ROWS)

        execution_time = (datetime.now() - start_time).total_seconds() * 1000

        summary_lines = [
            f"Query returned {len(result_df)} rows with {len(result_df.columns)} columns: {', '.join(result_df.columns.tolist())}"
        ]

        if not result_df.empty:
            summary_lines.append("\nResults preview (first 10 rows):")
            preview = result_df.head(10).to_string(index=False, max_rows=10)
            summary_lines.append(preview)

        result_summary = "\n".join(summary_lines)

        print(f"✅ SQL executed: {len(result_df)} rows in {execution_time:.2f}ms\n")

        return SQLResult(
            sql_query=sql_query,
            result_summary=result_summary,
            columns=result_df.columns.tolist(),
            row_count=len(result_df),
            execution_time_ms=round(execution_time, 2),
        )

    except ValueError as ve:
        raise ValueError(f"SQL validation error: {str(ve)}")
    except Exception as e:
        raise RuntimeError(f"SQL execution error: {str(e)}")

# =====================================
# 🔹 VISUALIZATION TASK
# =====================================
async def generate_visualization_task(query: str) -> Optional[dict]:
    """Async task to generate Chart.js visualization using last_sql_query and upload to S3."""
    global _registered_datasets, last_sql_query, chartjs_path

    if not last_sql_query or not _registered_datasets:
        print("⚠️ No SQL query executed or no datasets loaded, skipping visualization.")
        return None

    print(f"🎨 Generating Chart.js visualization...")
    try:
        conn = duckdb.connect(":memory:")
        for table_name, df in _registered_datasets.items():
            conn.register(table_name, df)
        result_df = conn.execute(last_sql_query).df()
        conn.close()

        visualization_prompt = query.strip() if query and query.strip() else "Visualize the query results"
        print(f"🎯 User prompt for visualization: '{visualization_prompt}'")

        chartjs_config = generate_chartjs_visualization(
            result_df=result_df,
            user_prompt=visualization_prompt
        )
        print(f"✅ Chart.js visualization generated.")
        print(f"Chart.js Config: {json.dumps(chartjs_config, indent=2)}")

        # Save Chart.js config to temporary file and upload to S3
        chartjs_path.clear()
        chartjs_path.append(json.dumps(chartjs_config, indent=2))  # Store locally in the global list

        return chartjs_config

    except Exception as viz_error:
        print(f"⚠️ Visualization generation failed: {str(viz_error)}")
        print(f"📊 Continuing with results without visualization...")
        return None

# =====================================
# 🔹 MAIN ANALYSIS FUNCTION
# =====================================
async def main(query: str, file_paths: List[str] = None) -> dict:
    """Main function to run data analysis with proper agent initialization."""

    if file_paths is None:
        file_paths = []

    print(f"\n{'='*70}")
    print(f"📋 Analysis Request")
    print(f"{'='*70}")
    print(f"Files: {', '.join(file_paths)}")
    print(f"Query: {query}")
    print(f"{'='*70}\n")

    dataset_context = load_and_register_datasets(file_paths)

    print(f"\n{'='*70}")
    print("📊 Dataset Context Summary")
    print(f"{'='*70}")
    print(dataset_context[:500] + "..." if len(dataset_context) > 500 else dataset_context)
    print(f"{'='*70}\n")

    general_purpose_agent = Agent(
        name="general_purpose_agent",
        instructions=(
            "You are a friendly AI assistant with high expertise in data analysis, SQL, and visualization. "
            "Only answer if the query is not related to data analysis."
        ),
        model="gpt-4o-mini",
        tools=[]
    )

    DataAnalyst_agent = Agent(
        name="DataAnalyst_agent",
        instructions=(
            f"""You are a professional data analyst with expertise in SQL and data analysis.

{dataset_context}

Your workflow:
1. Analyze the user's question carefully
2. Review the available datasets and their schemas above
3. **Generate appropriate SQL query using sql_tool - ALWAYS pass the user_query parameter**

CRITICAL - Tool Usage:
- Call sql_tool with BOTH parameters:
  * sql_query: Your SQL SELECT statement
  * user_query: The EXACT original user question (this creates the visualization)

Example correct usage:
sql_tool(
    sql_query="SELECT age, COUNT(*) as count FROM CSV1 GROUP BY age ORDER BY age",
    user_query="Show me the age distribution of patients"
)

- Use table names: CSV1, CSV2, CSV3, etc.
- You can JOIN tables if needed
- Use descriptive column aliases
- Add ORDER BY for readability
- Use LIMIT for large result sets

4. Interpret the results and provide clear, detailed insights with specific numbers and percentages

SQL Best Practices:
- Always check which columns exist in which tables
- Use table aliases for clarity (e.g., CSV1 AS c1)
- Filter data appropriately based on the question
- Aggregate data when asked for counts, sums, averages
- Sort results logically

Important:
- Only use columns that exist in the datasets above
- Reference the correct table names (CSV1, CSV2, etc.)
- **MANDATORY: Always pass both sql_query AND user_query to sql_tool**
- The user_query parameter is used to generate intelligent visualizations
- Always provide complete analysis with specific findings
- Include tables, numbers, and percentages in your response
- Provide insights in plain language for non-technical users
"""
        ),
        model="gpt-4o-mini",
        tools=[sql_tool],
    )

    orchestrator_agent = Agent(
        name="orchestrator_agent",
        instructions=(
            "You receive the user's query and route it to the DataAnalyst agent for execution. "
            "Always use the analyze_dataset tool to process data analysis queries. "
            "Pass the complete user query to the DataAnalyst agent."
        ),
        tools=[
            DataAnalyst_agent.as_tool(
                tool_name="analyze_dataset",
                tool_description="Analyze dataset, generate SQL, and visualize results."
            ),
            general_purpose_agent.as_tool(
                tool_name="general_purpose_tool",
                tool_description="Act like a human friendly AI assistant for non-data analysis queries."
            )
        ],
        model="gpt-4o-mini",
    )

    print(f"\n{'='*70}")
    print("🚀 Running Orchestrator Agent")
    print(f"{'='*70}\n")

    orchestrator_result = await Runner.run(orchestrator_agent, query)

    print(f"\n{'='*70}")
    print("📝 Orchestrator Output")
    print(f"{'='*70}")

    analysis_text = ""
    for item in orchestrator_result.new_items:
        if hasattr(item, 'output') and item.output:
            analysis_text = item.output
            print(f"Found analysis output: {len(analysis_text)} characters")
        elif hasattr(item, 'raw_item') and hasattr(item.raw_item, 'content'):
            for content_item in item.raw_item.content:
                if hasattr(content_item, 'text'):
                    analysis_text = content_item.text
                    print(f"Found message output: {len(analysis_text)} characters")

    print(f"{'='*70}\n")

    synthesizer_agent = Agent(
        name="synthesizer_agent",
        instructions=(
            """You are a final report writer. Your job is to present the complete data analysis results.

CRITICAL RULES:
1. Present the COMPLETE analysis from the DataAnalyst agent
2. Include ALL tables, numbers, statistics, and findings
3. Maintain all the detailed insights and percentages
4. Use clear formatting with headers, tables, and bullet points
5. DO NOT summarize or truncate the analysis
6. DO NOT add generic closing statements like "feel free to ask"
7. Make the report professional and easy to read

Your output should be the full, detailed analysis report with all findings intact.
"""
        ),
        model="gpt-4o-mini",
    )

    print(f"\n{'='*70}")
    print("🎯 Running Synthesizer and Visualization in Parallel")
    print(f"{'='*70}\n")

    synthesizer_task = asyncio.create_task(
        Runner.run(
            synthesizer_agent,
            f"Present this complete data analysis:\n\n{analysis_text}" if analysis_text else orchestrator_result.to_input_list()
        )
    )
    visualization_task = asyncio.create_task(generate_visualization_task(query))

    synthesizer_result, visualization_result = await asyncio.gather(synthesizer_task, visualization_task, return_exceptions=True)

    if isinstance(synthesizer_result, Exception):
        print(f"⚠️ Synthesizer task failed: {str(synthesizer_result)}")
        synthesizer_result = type('Result', (), {'final_output': 'Synthesizer task failed'})()

    if isinstance(visualization_result, Exception):
        print(f"⚠️ Visualization task failed: {str(visualization_result)}")
        visualization_result = None

    print(f"\n{'='*70}")
    print("✅ FINAL ANSWER")
    print(f"{'='*70}")
    print(synthesizer_result.final_output)
    print(f"{'='*70}\n")

    print(f"chartjs_config_path: {chartjs_path}")
    print(f"{'='*70}\n")

    final_api_output = {
        "convoId": "1234",
        "type": "assistant_message",
        "hidden_code": True if last_sql_query else False,
        "hidden_text": last_sql_query,
        "ai_message": synthesizer_result.final_output,
        "artifacts": chartjs_path
    }
    # print(f"Final JSON Output: {json.dumps(final_api_output, indent=2)}")
    final_api_output_json = json.dumps(final_api_output, indent=2)
    return final_api_output_json

# =====================================
# 🔹 STANDALONE TESTING
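Alongside the SELECT/WITH prefix check that sql_tool actually enforces, the module defines is_safe_query, which uses sqlparse to reject statements whose top-level keywords include DROP, DELETE, UPDATE, or INSERT. A small standalone sketch of that helper's behaviour follows; the sample queries are illustrative only, and the function body simply mirrors the helper defined above.

import sqlparse

def is_safe_query(sql_query: str) -> bool:
    # Same logic as the helper in the file above: scan top-level keywords
    # of the first parsed statement only.
    parsed = sqlparse.parse(sql_query)[0]
    forbidden_keywords = {"DROP", "DELETE", "UPDATE", "INSERT"}
    tokens = [token.value.upper() for token in parsed.tokens if token.is_keyword]
    return not any(keyword in forbidden_keywords for keyword in tokens)

print(is_safe_query("SELECT age, COUNT(*) FROM CSV1 GROUP BY age"))  # True
print(is_safe_query("DROP TABLE CSV1"))                              # False

Note that in this version is_safe_query is defined but the guard applied inside sql_tool is the simpler startswith check against ALLOWED_SQL_KEYWORDS.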
Routes/helpers/main_agent_best_as_of_now.py-update.txt
ADDED
@@ -0,0 +1,1110 @@
| 1 |
+
import duckdb
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import Annotated, Optional, List
|
| 6 |
+
import re
|
| 7 |
+
import json
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from pydantic import BaseModel, Field, ConfigDict
|
| 10 |
+
from fastapi import APIRouter, HTTPException
|
| 11 |
+
from fastapi.responses import JSONResponse
|
| 12 |
+
import os
|
| 13 |
+
import dotenv
|
| 14 |
+
import tempfile
|
| 15 |
+
from langchain_openai import ChatOpenAI
|
| 16 |
+
from langchain.schema import SystemMessage, HumanMessage
|
| 17 |
+
import requests
|
| 18 |
+
import asyncio
|
| 19 |
+
# Project root setup
|
| 20 |
+
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
| 21 |
+
if str(PROJECT_ROOT) not in sys.path:
|
| 22 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 23 |
+
|
| 24 |
+
# Assuming these are custom modules in your project
|
| 25 |
+
from agents import Agent, Runner, function_tool
|
| 26 |
+
from s3.read_files import read_csv_from_s3
|
| 27 |
+
|
| 28 |
+
# =====================================
|
| 29 |
+
# 🔹 CONFIGURATION
|
| 30 |
+
# =====================================
|
| 31 |
+
dotenv.load_dotenv()
|
| 32 |
+
|
| 33 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 34 |
+
if not OPENAI_API_KEY:
|
| 35 |
+
raise ValueError("OPENAI_API_KEY not found in environment variables")
|
| 36 |
+
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
|
| 37 |
+
|
| 38 |
+
# Constants
|
| 39 |
+
ALLOWED_SQL_KEYWORDS = ("select", "with")
|
| 40 |
+
MAX_RESULT_ROWS = 10000
|
| 41 |
+
|
| 42 |
+
# Global storage for datasets and context
|
| 43 |
+
_registered_datasets = {}
|
| 44 |
+
_dataset_context = {}
|
| 45 |
+
last_sql_query: str = ""
|
| 46 |
+
vega_lite_path = [] # Global variable to store Vega-Lite JSON config path
|
| 47 |
+
|
| 48 |
+
import sqlparse
|
| 49 |
+
|
| 50 |
+
def is_safe_query(sql_query: str) -> bool:
|
| 51 |
+
parsed = sqlparse.parse(sql_query)[0]
|
| 52 |
+
forbidden_keywords = {"DROP", "DELETE", "UPDATE", "INSERT"}
|
| 53 |
+
tokens = [token.value.upper() for token in parsed.tokens if token.is_keyword]
|
| 54 |
+
return not any(keyword in forbidden_keywords for keyword in tokens)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# =====================================
|
| 58 |
+
# 🔹 S3 UPLOAD FUNCTION
|
| 59 |
+
# =====================================
|
| 60 |
+
def upload_file_to_s3(file_path: str, s3_path: str = "vatsav/artifacts/") -> str:
|
| 61 |
+
"""
|
| 62 |
+
Uploads a file to S3 using the HuggingFace Space API.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
file_path (str): Local path to the file to upload.
|
| 66 |
+
s3_path (str): S3 path where the file should be uploaded.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
str: S3 URL of the uploaded file.
|
| 70 |
+
|
| 71 |
+
Raises:
|
| 72 |
+
Exception: If upload fails.
|
| 73 |
+
"""
|
| 74 |
+
api_url = f"https://srivatsavdamaraju-mvp-2-0-deploy-all-apis.hf.space/s3/upload/?path={s3_path}"
|
| 75 |
+
|
| 76 |
+
# Determine content type based on file extension
|
| 77 |
+
if file_path.endswith('.html'):
|
| 78 |
+
content_type = 'text/html'
|
| 79 |
+
elif file_path.endswith('.json'):
|
| 80 |
+
content_type = 'application/json'
|
| 81 |
+
else:
|
| 82 |
+
content_type = 'application/octet-stream'
|
| 83 |
+
|
| 84 |
+
# Open and upload the file
|
| 85 |
+
with open(file_path, 'rb') as f:
|
| 86 |
+
files = {
|
| 87 |
+
'file': (Path(file_path).name, f, content_type)
|
| 88 |
+
}
|
| 89 |
+
headers = {
|
| 90 |
+
'accept': 'application/json'
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
response = requests.post(api_url, headers=headers, files=files)
|
| 94 |
+
|
| 95 |
+
if response.status_code == 200:
|
| 96 |
+
response_data = response.json()
|
| 97 |
+
s3_url = response_data.get('url') or response_data.get('file_url') or f"{s3_path}{Path(file_path).name}"
|
| 98 |
+
return s3_url
|
| 99 |
+
else:
|
| 100 |
+
raise Exception(f"Upload failed with status {response.status_code}: {response.text}")
|
| 101 |
+
|
| 102 |
+
# ... [Previous code for dataset management, pydantic models, vector agent tool, SQL tool, and generate_chartjs_visualization remains unchanged]
|
| 103 |
+
|
| 104 |
+
# =====================================
|
| 105 |
+
# 🔹 UTILITY FUNCTIONS
|
| 106 |
+
# =====================================
|
| 107 |
+
def save_convo_id_to_pgdb(convo_id: str):
|
| 108 |
+
"""Save conversation ID to PostgreSQL database (stub function)."""
|
| 109 |
+
print(f"Saving conversation ID to PostgreSQL DB: {convo_id}")
|
| 110 |
+
pass
|
| 111 |
+
|
| 112 |
+
# =====================================
|
| 113 |
+
# 🔹 PYDANTIC MODELS
|
| 114 |
+
# =====================================
|
| 115 |
+
class AnalysisRequest(BaseModel):
|
| 116 |
+
"""Request model for data analysis."""
|
| 117 |
+
file_paths: List[str] = Field(..., description="List of S3 file paths to analyze")
|
| 118 |
+
query: str = Field(..., description="Natural language query for analysis")
|
| 119 |
+
|
| 120 |
+
model_config = ConfigDict(extra='forbid')
|
| 121 |
+
|
| 122 |
+
class SQLResult(BaseModel):
|
| 123 |
+
"""Result from SQL query execution."""
|
| 124 |
+
model_config = ConfigDict(extra='forbid')
|
| 125 |
+
|
| 126 |
+
sql_query: str = Field(..., description="The executed SQL query")
|
| 127 |
+
result_summary: str = Field(..., description="Human-readable summary of results")
|
| 128 |
+
columns: List[str] = Field(..., description="Column names from result")
|
| 129 |
+
row_count: int = Field(..., description="Number of rows returned")
|
| 130 |
+
execution_time_ms: float = Field(..., description="Query execution time in milliseconds")
|
| 131 |
+
|
| 132 |
+
class AnalysisResponse(BaseModel):
|
| 133 |
+
"""Response from analysis endpoint."""
|
| 134 |
+
model_config = ConfigDict(extra='forbid')
|
| 135 |
+
|
| 136 |
+
query: str = Field(..., description="Original user query")
|
| 137 |
+
datasets_loaded: List[str] = Field(..., description="List of loaded datasets")
|
| 138 |
+
analysis_result: str = Field(..., description="Analysis insights")
|
| 139 |
+
sql_executed: str = Field(..., description="SQL query that was executed")
|
| 140 |
+
execution_time_ms: float = Field(..., description="Total execution time")
|
| 141 |
+
|
| 142 |
+
# =====================================
|
| 143 |
+
# 🔹 DATASET MANAGEMENT & CONTEXT ENGINEERING
|
| 144 |
+
# =====================================
|
| 145 |
+
def generate_dataset_context(df: pd.DataFrame, table_name: str) -> str:
|
| 146 |
+
"""Generate comprehensive context for a dataset."""
|
| 147 |
+
context_parts = []
|
| 148 |
+
|
| 149 |
+
context_parts.append(f"Table: {table_name}")
|
| 150 |
+
context_parts.append(f"Rows: {len(df)}, Columns: {len(df.columns)}")
|
| 151 |
+
context_parts.append(f"Column Names: {', '.join(df.columns.tolist())}")
|
| 152 |
+
|
| 153 |
+
context_parts.append("\nData Types:")
|
| 154 |
+
for col, dtype in df.dtypes.items():
|
| 155 |
+
context_parts.append(f" - {col}: {dtype}")
|
| 156 |
+
|
| 157 |
+
context_parts.append("\nSample Data (first 3 rows):")
|
| 158 |
+
sample_str = df.head(3).to_string(index=False, max_cols=10)
|
| 159 |
+
context_parts.append(sample_str)
|
| 160 |
+
|
| 161 |
+
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
|
| 162 |
+
if numeric_cols:
|
| 163 |
+
context_parts.append("\nNumeric Column Statistics:")
|
| 164 |
+
desc = df[numeric_cols].describe().transpose()
|
| 165 |
+
for col in numeric_cols[:5]:
|
| 166 |
+
if col in desc.index:
|
| 167 |
+
context_parts.append(f" - {col}: mean={desc.loc[col, 'mean']:.2f}, "
|
| 168 |
+
f"min={desc.loc[col, 'min']:.2f}, max={desc.loc[col, 'max']:.2f}")
|
| 169 |
+
|
| 170 |
+
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
|
| 171 |
+
if categorical_cols:
|
| 172 |
+
context_parts.append("\nCategorical Columns:")
|
| 173 |
+
for col in categorical_cols[:5]:
|
| 174 |
+
unique_count = df[col].nunique()
|
| 175 |
+
context_parts.append(f" - {col}: {unique_count} unique values")
|
| 176 |
+
if unique_count <= 5:
|
| 177 |
+
values = df[col].value_counts().head(5).to_dict()
|
| 178 |
+
context_parts.append(f" Values: {values}")
|
| 179 |
+
|
| 180 |
+
return "\n".join(context_parts)
|
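# Illustrative example of the context this produces for a small frame:
# demo_df = pd.DataFrame({"age": [34, 51, 29], "city": ["NY", "LA", "NY"]})
# print(generate_dataset_context(demo_df, "CSV1"))
# -> "Table: CSV1", row/column counts, dtypes, a 3-row sample, and basic stats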
| 181 |
+
|
| 182 |
+
def load_and_register_datasets(file_paths: List[str]) -> str:
|
| 183 |
+
"""Load datasets from S3 and register them in DuckDB format."""
|
| 184 |
+
global _registered_datasets, _dataset_context
|
| 185 |
+
|
| 186 |
+
_registered_datasets.clear()
|
| 187 |
+
_dataset_context.clear()
|
| 188 |
+
|
| 189 |
+
context_parts = ["Available datasets for SQL queries:\n"]
|
| 190 |
+
|
| 191 |
+
for idx, file_path in enumerate(file_paths, start=1):
|
| 192 |
+
table_name = f"CSV{idx}"
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
print(f"📥 Loading {file_path} as {table_name}")
|
| 196 |
+
df = read_csv_from_s3(file_path)
|
| 197 |
+
|
| 198 |
+
_registered_datasets[table_name] = df
|
| 199 |
+
|
| 200 |
+
dataset_context = generate_dataset_context(df, table_name)
|
| 201 |
+
_dataset_context[table_name] = dataset_context
|
| 202 |
+
context_parts.append(f"\n{dataset_context}\n")
|
| 203 |
+
context_parts.append("="*70)
|
| 204 |
+
|
| 205 |
+
print(f"✅ Loaded {table_name}: {len(df)} rows, {len(df.columns)} columns")
|
| 206 |
+
|
| 207 |
+
except Exception as e:
|
| 208 |
+
print(f"❌ Error loading {file_path}: {str(e)}")
|
| 209 |
+
raise HTTPException(status_code=400, detail=f"Failed to load {file_path}: {str(e)}")
|
| 210 |
+
|
| 211 |
+
full_context = "\n".join(context_parts)
|
| 212 |
+
return full_context
|
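# Usage sketch (assumes read_csv_from_s3 is available in this module's scope;
# the path below is a placeholder):
# context = load_and_register_datasets(["bucket/datasets/patients.csv"])
# print(list(_registered_datasets.keys()))  # e.g. ['CSV1']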
| 213 |
+
|
| 214 |
+
# =====================================
|
| 215 |
+
# 🔹 CHART.JS VISUALIZATION FUNCTION
|
| 216 |
+
# =====================================
|
| 217 |
+
# def generate_chartjs_visualization(result_df: pd.DataFrame, user_prompt: str) -> dict:
|
| 218 |
+
# """
|
| 219 |
+
# Generate Chart.js JSON config from DataFrame using LLM.
|
| 220 |
+
|
| 221 |
+
# Args:
|
| 222 |
+
# result_df: pandas DataFrame to visualize
|
| 223 |
+
# user_prompt: User's visualization request
|
| 224 |
+
|
| 225 |
+
# Returns:
|
| 226 |
+
# dict: Chart.js configuration JSON
|
| 227 |
+
# """
|
| 228 |
+
# # Initialize LLM
|
| 229 |
+
# llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.3)
|
| 230 |
+
|
| 231 |
+
# # LLM Prompt
|
| 232 |
+
# prompt = f"""You are an expert data visualization analyst. Analyze the DataFrame and user query to generate the BEST Chart.js visualization.
|
| 233 |
+
|
| 234 |
+
# User Query: "{user_prompt}"
|
| 235 |
+
|
| 236 |
+
# Your Task:
|
| 237 |
+
# 1. ANALYZE the DataFrame structure:
|
| 238 |
+
# - Identify column types (categorical, numerical, temporal)
|
| 239 |
+
# - Detect relationships between columns
|
| 240 |
+
# - Understand data distribution and patterns
|
| 241 |
+
|
| 242 |
+
# 2. DETERMINE the optimal visualization:
|
| 243 |
+
# - For comparisons → bar chart
|
| 244 |
+
# - For trends over time → line chart
|
| 245 |
+
# - For proportions/parts of whole → pie/doughnut
|
| 246 |
+
# - For distributions → scatter/bubble
|
| 247 |
+
# - For multivariate comparisons → radar/polarArea
|
| 248 |
+
# - Match chart type to data characteristics AND user intent
|
| 249 |
+
|
| 250 |
+
# 3. GENERATE Chart.js JSON with:
|
| 251 |
+
# - Correct axis mappings (categorical → labels, numerical → data)
|
| 252 |
+
# - Meaningful chart title derived from the data context
|
| 253 |
+
# - Appropriate colors for distinction
|
| 254 |
+
# - Professional styling options
|
| 255 |
+
|
| 256 |
+
# Available chart types: bar, line, pie, doughnut, polarArea, radar, scatter, bubble
|
| 257 |
+
|
| 258 |
+
# Color palette: #FF6384, #36A2EB, #FFCE56, #4BC0C0, #9966FF, #FF9F40
|
| 259 |
+
|
| 260 |
+
# Output ONLY valid Chart.js JSON:
|
| 261 |
+
# ```json
|
| 262 |
+
# {{
|
| 263 |
+
# "type": "best_chart_type",
|
| 264 |
+
# "data": {{
|
| 265 |
+
# "labels": ["derived", "from", "dataframe"],
|
| 266 |
+
# "datasets": [{{
|
| 267 |
+
# "label": "Meaningful Label",
|
| 268 |
+
# "data": [actual, values, from, df],
|
| 269 |
+
# "backgroundColor": ["#FF6384", "#36A2EB", ...],
|
| 270 |
+
# "borderColor": ["#FF6384", "#36A2EB", ...],
|
| 271 |
+
# "borderWidth": 2
|
| 272 |
+
# }}]
|
| 273 |
+
# }},
|
| 274 |
+
# "options": {{
|
| 275 |
+
# "responsive": true,
|
| 276 |
+
# "maintainAspectRatio": false,
|
| 277 |
+
# "plugins": {{
|
| 278 |
+
# "title": {{
|
| 279 |
+
# "display": true,
|
| 280 |
+
# "text": "Descriptive Title Based on Data",
|
| 281 |
+
# "font": {{ "size": 16 }}
|
| 282 |
+
# }},
|
| 283 |
+
# "legend": {{
|
| 284 |
+
# "display": true,
|
| 285 |
+
# "position": "top"
|
| 286 |
+
# }}
|
| 287 |
+
# }},
|
| 288 |
+
# "scales": {{
|
| 289 |
+
# "y": {{
|
| 290 |
+
# "beginAtZero": true,
|
| 291 |
+
# "title": {{ "display": true, "text": "Y Axis Label" }}
|
| 292 |
+
# }},
|
| 293 |
+
# "x": {{
|
| 294 |
+
# "title": {{ "display": true, "text": "X Axis Label" }}
|
| 295 |
+
# }}
|
| 296 |
+
# }}
|
| 297 |
+
# }}
|
| 298 |
+
# }}
|
| 299 |
+
# ```"""
|
| 300 |
+
|
| 301 |
+
# # Get response
|
| 302 |
+
# messages = [
|
| 303 |
+
# SystemMessage(content=prompt),
|
| 304 |
+
# HumanMessage(content=result_df.to_string())
|
| 305 |
+
# ]
|
| 306 |
+
|
| 307 |
+
# response = llm.invoke(messages)
|
| 308 |
+
|
| 309 |
+
# # Extract JSON
|
| 310 |
+
# json_match = re.search(r"```json\n(.*?)```", response.content, re.DOTALL)
|
| 311 |
+
# if not json_match:
|
| 312 |
+
# json_match = re.search(r"\{.*\}", response.content, re.DOTALL)
|
| 313 |
+
|
| 314 |
+
# if not json_match:
|
| 315 |
+
# raise ValueError("No JSON found in LLM response")
|
| 316 |
+
|
| 317 |
+
# config = json.loads(json_match.group(1 if "```json" in response.content else 0))
|
| 318 |
+
# vega_lite_path.append(config)
|
| 319 |
+
# print(f"Generated Chart.js Config: {json.dumps(config, indent=2)}")
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
# # Validate
|
| 323 |
+
# if "type" not in config or "data" not in config:
|
| 324 |
+
# raise ValueError("Invalid Chart.js config")
|
| 325 |
+
|
| 326 |
+
# return config
|
| 327 |
+
def generate_vegalite_visualization(result_df: pd.DataFrame, user_prompt: str) -> dict:
|
| 328 |
+
"""
|
| 329 |
+
Generate Vega-Lite JSON spec from DataFrame using LLM.
|
| 330 |
+
|
| 331 |
+
Args:
|
| 332 |
+
result_df: pandas DataFrame to visualize
|
| 333 |
+
user_prompt: User's visualization request
|
| 334 |
+
|
| 335 |
+
Returns:
|
| 336 |
+
dict: Vega-Lite specification JSON
|
| 337 |
+
"""
|
| 338 |
+
# Initialize the LLM (using OpenAI's GPT-4o-mini as per the original)
|
| 339 |
+
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.3)
|
| 340 |
+
|
| 341 |
+
# Construct the prompt, incorporating user query and DataFrame structure
|
| 342 |
+
prompt = f"""You are an expert data visualization analyst. Analyze the DataFrame and user query to generate the BEST Vega-Lite visualization.
|
| 343 |
+
|
| 344 |
+
User Query: "{user_prompt}"
|
| 345 |
+
|
| 346 |
+
Your Task:
|
| 347 |
+
1. ANALYZE the DataFrame structure:
|
| 348 |
+
- Identify column types (categorical → nominal, numerical → quantitative, temporal)
|
| 349 |
+
- Detect relationships between columns
|
| 350 |
+
- Understand data distribution and patterns
|
| 351 |
+
|
| 352 |
+
2. DETERMINE the optimal visualization:
|
| 353 |
+
- For comparisons → bar chart
|
| 354 |
+
- For trends over time → line chart
|
| 355 |
+
- For proportions/parts of whole → pie chart
|
| 356 |
+
- For distributions → scatter or point
|
| 357 |
+
- For multivariate comparisons → layered or concatenated views
|
| 358 |
+
- Match chart type to data characteristics AND user intent
|
| 359 |
+
|
| 360 |
+
3. GENERATE Vega-Lite JSON with:
|
| 361 |
+
- Correct encoding mappings (e.g., x for categorical/nominal, y for numerical/quantitative, color for distinction)
|
| 362 |
+
- Data as inline JSON from the DataFrame
|
| 363 |
+
- Meaningful chart title derived from the user query and data context
|
| 364 |
+
- Appropriate colors for distinction using the provided palette
|
| 365 |
+
- Professional styling options, including tooltips for interactivity
|
| 366 |
+
- Reasonable chart dimensions (width: 400, height: 300)
|
| 367 |
+
|
| 368 |
+
Available mark types: bar, line, area, pie, point (for scatter), circle (for bubble)
|
| 369 |
+
|
| 370 |
+
Color palette: ["#FF6384", "#36A2EB", "#FFCE56", "#4BC0C0", "#9966FF", "#FF9F40"]
|
| 371 |
+
|
| 372 |
+
**CRITICAL**: Output ONLY valid JSON with all property names enclosed in double quotes. Do NOT include any text, comments, or explanations outside the JSON. Ensure the JSON is complete and properly formatted for Vega-Lite v5 schema. All property names MUST be enclosed in double quotes.
|
| 373 |
+
|
| 374 |
+
Example of valid output:
|
| 375 |
+
{{
|
| 376 |
+
"$schema": "https://vega.github.io/schema/vega-lite/v5.json",
|
| 377 |
+
"data": {{
|
| 378 |
+
"values": [
|
| 379 |
+
{{"category": "A", "value": 10}},
|
| 380 |
+
{{"category": "B", "value": 20}},
|
| 381 |
+
{{"category": "C", "value": 30}}
|
| 382 |
+
]
|
| 383 |
+
}},
|
| 384 |
+
"mark": "bar",
|
| 385 |
+
"encoding": {{
|
| 386 |
+
"x": {{"field": "category", "type": "nominal", "title": "Categories"}},
|
| 387 |
+
"y": {{"field": "value", "type": "quantitative", "title": "Values"}},
|
| 388 |
+
"color": {{"field": "category", "type": "nominal", "scale": {{"range": ["#FF6384", "#36A2EB", "#FFCE56"]}}}},
|
| 389 |
+
"tooltip": [
|
| 390 |
+
{{"field": "category", "type": "nominal", "title": "Category"}},
|
| 391 |
+
{{"field": "value", "type": "quantitative", "title": "Value"}}
|
| 392 |
+
]
|
| 393 |
+
}},
|
| 394 |
+
"width": 400,
|
| 395 |
+
"height": 300,
|
| 396 |
+
"title": "Sample Bar Chart"
|
| 397 |
+
}}
|
| 398 |
+
|
| 399 |
+
Output ONLY the JSON content:
|
| 400 |
+
{{
|
| 401 |
+
"$schema": "https://vega.github.io/schema/vega-lite/v5.json",
|
| 402 |
+
"data": {{
|
| 403 |
+
"values": []
|
| 404 |
+
}},
|
| 405 |
+
"mark": "best_mark_type",
|
| 406 |
+
"encoding": {{
|
| 407 |
+
"x": {{"field": "derived_field", "type": "nominal", "title": "X Axis Label"}},
|
| 408 |
+
"y": {{"field": "derived_field", "type": "quantitative", "title": "Y Axis Label"}},
|
| 409 |
+
"color": {{"field": "derived_field", "type": "nominal", "scale": {{"range": ["#FF6384", "#36A2EB", "#FFCE56"]}}}},
|
| 410 |
+
"tooltip": [
|
| 411 |
+
{{"field": "derived_field", "type": "nominal", "title": "Label"}},
|
| 412 |
+
{{"field": "derived_field", "type": "quantitative", "title": "Value"}}
|
| 413 |
+
]
|
| 414 |
+
}},
|
| 415 |
+
"width": 400,
|
| 416 |
+
"height": 300,
|
| 417 |
+
"title": "Descriptive Title Based on Data"
|
| 418 |
+
}}
|
| 419 |
+
"""
|
| 420 |
+
|
| 421 |
+
# Convert DataFrame to string for LLM input
|
| 422 |
+
df_string = result_df.to_string()
|
| 423 |
+
|
| 424 |
+
# Create messages for LLM
|
| 425 |
+
messages = [
|
| 426 |
+
SystemMessage(content=prompt),
|
| 427 |
+
HumanMessage(content=df_string)
|
| 428 |
+
]
|
| 429 |
+
|
| 430 |
+
# Invoke LLM to generate the Vega-Lite spec
|
| 431 |
+
try:
|
| 432 |
+
response = llm.invoke(messages)
|
| 433 |
+
print(f"Raw LLM Response:\n{response.content}\n")
|
| 434 |
+
except Exception as e:
|
| 435 |
+
raise ValueError(f"LLM invocation failed: {str(e)}")
|
| 436 |
+
|
| 437 |
+
# Extract JSON block from LLM response
|
| 438 |
+
json_match = re.search(r"\{.*\}", response.content, re.DOTALL)
|
| 439 |
+
if not json_match:
|
| 440 |
+
raise ValueError(f"No valid JSON block found in LLM response:\n{response.content}")
|
| 441 |
+
|
| 442 |
+
json_string = json_match.group(0).strip()
|
| 443 |
+
print(f"Extracted JSON String:\n{json_string}\n")
|
| 444 |
+
|
| 445 |
+
# Parse JSON string into a Python dict
|
| 446 |
+
try:
|
| 447 |
+
config = json.loads(json_string)
|
| 448 |
+
except json.JSONDecodeError as e:
|
| 449 |
+
raise ValueError(f"Invalid JSON format: {str(e)}\nJSON String:\n{json_string}")
|
| 450 |
+
|
| 451 |
+
# Log the generated config for debugging
|
| 452 |
+
print(f"Generated Vega-Lite Spec:\n{json.dumps(config, indent=2)}")
|
| 453 |
+
vega_lite_path.append(config)
|
| 454 |
+
|
| 455 |
+
# Validate essential Vega-Lite properties
|
| 456 |
+
print(f"Generated Vega-Lite Config: {json.dumps(config, indent=2)}")
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
# Validate
|
| 460 |
+
if "type" not in config or "data" not in config:
|
| 461 |
+
raise ValueError("Invalid Vega-Lite config")
|
| 462 |
+
|
| 463 |
+
return config
|
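# Usage sketch (illustrative): the returned dict is a Vega-Lite v5 spec that a
# frontend can render directly (e.g. with vega-embed).
# demo = pd.DataFrame({"category": ["A", "B"], "value": [10, 20]})
# spec = generate_vegalite_visualization(demo, "Compare the two categories")
# print(spec["mark"], spec["data"]["values"][:1])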
| 464 |
+
# =====================================
|
| 465 |
+
# 🔹 VECTOR AGENT TOOL
|
| 466 |
+
# =====================================
|
| 467 |
+
def query_vector_agent_calling(user_query: str, collection_name: str) -> str:
|
| 468 |
+
"""
|
| 469 |
+
Call the vector agent API to get relevant context for the given user query and collection name.
|
| 470 |
+
"""
|
| 471 |
+
base_url = "https://srivatsavdamaraju-mvp-2-0-deploy-all-apis.hf.space/qdrant/search"
|
| 472 |
+
url = f"{base_url}?collection_name={collection_name}&mode=dense"
|
| 473 |
+
|
| 474 |
+
headers = {
|
| 475 |
+
"accept": "application/json",
|
| 476 |
+
"Content-Type": "application/json",
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
payload = {
|
| 480 |
+
"query": user_query,
|
| 481 |
+
"top_k": 2
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
try:
|
| 485 |
+
response = requests.post(url, headers=headers, json=payload)
|
| 486 |
+
response.raise_for_status()
|
| 487 |
+
|
| 488 |
+
data = response.json()
|
| 489 |
+
results = data.get("results") or data.get("result") or data
|
| 490 |
+
if not results:
|
| 491 |
+
return "No relevant context found."
|
| 492 |
+
|
| 493 |
+
contexts = []
|
| 494 |
+
for item in results:
|
| 495 |
+
text = item.get("text") or item.get("payload", {}).get("text")
|
| 496 |
+
if text:
|
| 497 |
+
contexts.append(text)
|
| 498 |
+
|
| 499 |
+
return "\n\n".join(contexts) if contexts else str(data)
|
| 500 |
+
|
| 501 |
+
except requests.RequestException as e:
|
| 502 |
+
print(f"Error calling vector agent API: {e}")
|
| 503 |
+
return "Error retrieving context."
|
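# Usage sketch (collection name is a placeholder, not a known collection):
# context = query_vector_agent_calling("What are the key risk factors?", "demo_collection")
# print(context[:200])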
| 504 |
+
|
| 505 |
+
@function_tool
|
| 506 |
+
async def vector_agent_tool(
|
| 507 |
+
user_query: Annotated[str, "The user's natural language question"],
|
| 508 |
+
collection_name: Annotated[str, "The name of the collection to search in"]
|
| 509 |
+
) -> str:
|
| 510 |
+
"""
|
| 511 |
+
Tool to call the vector agent API and retrieve relevant context.
|
| 512 |
+
|
| 513 |
+
Args:
|
| 514 |
+
user_query: The user's natural language question.
|
| 515 |
+
collection_name: The name of the collection to search in.
|
| 516 |
+
|
| 517 |
+
Returns:
|
| 518 |
+
A string containing the retrieved context.
|
| 519 |
+
"""
|
| 520 |
+
return query_vector_agent_calling(user_query, collection_name)
|
| 521 |
+
|
| 522 |
+
# =====================================
|
| 523 |
+
# 🔹 SQL TOOL
|
| 524 |
+
# =====================================
|
| 525 |
+
@function_tool
|
| 526 |
+
async def sql_tool(
|
| 527 |
+
sql_query: Annotated[str, "SQL query to execute on registered datasets"],
|
| 528 |
+
user_query: Annotated[str, "Original user question for visualization context"] = ""
|
| 529 |
+
) -> SQLResult:
|
| 530 |
+
"""
|
| 531 |
+
Execute SQL query on registered datasets and return results.
|
| 532 |
+
|
| 533 |
+
Available tables: CSV1, CSV2, CSV3, etc. (based on loaded files)
|
| 534 |
+
Supported operations: SELECT, WITH (CTE)
|
| 535 |
+
|
| 536 |
+
Args:
|
| 537 |
+
sql_query: The SQL query to execute
|
| 538 |
+
user_query: The original user question (REQUIRED for visualization context)
|
| 539 |
+
|
| 540 |
+
Example queries:
|
| 541 |
+
- SELECT * FROM CSV1 WHERE age > 50 LIMIT 10
|
| 542 |
+
- SELECT a.name, b.value FROM CSV1 a JOIN CSV2 b ON a.id = b.id
|
| 543 |
+
|
| 544 |
+
Example usage:
|
| 545 |
+
- sql_tool(
|
| 546 |
+
sql_query="SELECT * FROM CSV1",
|
| 547 |
+
user_query="What is the average age?"
|
| 548 |
+
)
|
| 549 |
+
"""
|
| 550 |
+
global _registered_datasets, last_sql_query
|
| 551 |
+
|
| 552 |
+
last_sql_query = sql_query
|
| 553 |
+
|
| 554 |
+
start_time = datetime.now()
|
| 555 |
+
print(f"\n{'='*70}")
|
| 556 |
+
print(f"🔍 Executing SQL Query:\n{sql_query}\n")
|
| 557 |
+
print(f"📝 User Query Context: {user_query}\n")
|
| 558 |
+
print(f"{'='*70}\n")
|
| 559 |
+
|
| 560 |
+
try:
|
| 561 |
+
# Validate query
|
| 562 |
+
normalized_query = sql_query.strip().lower()
|
| 563 |
+
if not normalized_query.startswith(ALLOWED_SQL_KEYWORDS):
|
| 564 |
+
raise ValueError(
|
| 565 |
+
f"Only {', '.join(k.upper() for k in ALLOWED_SQL_KEYWORDS)} queries are allowed."
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
if not _registered_datasets:
|
| 569 |
+
raise ValueError("No datasets loaded. Please provide file_paths first.")
|
| 570 |
+
|
| 571 |
+
# Execute query using DuckDB
|
| 572 |
+
conn = duckdb.connect(":memory:")
|
| 573 |
+
|
| 574 |
+
for table_name, df in _registered_datasets.items():
|
| 575 |
+
conn.register(table_name, df)
|
| 576 |
+
|
| 577 |
+
result_df = conn.execute(sql_query).df()
|
| 578 |
+
conn.close()
|
| 579 |
+
|
| 580 |
+
if len(result_df) > MAX_RESULT_ROWS:
|
| 581 |
+
result_df = result_df.head(MAX_RESULT_ROWS)
|
| 582 |
+
|
| 583 |
+
execution_time = (datetime.now() - start_time).total_seconds() * 1000
|
| 584 |
+
|
| 585 |
+
summary_lines = [
|
| 586 |
+
f"Query returned {len(result_df)} rows with {len(result_df.columns)} columns: {', '.join(result_df.columns.tolist())}"
|
| 587 |
+
]
|
| 588 |
+
|
| 589 |
+
if not result_df.empty:
|
| 590 |
+
summary_lines.append("\nResults preview (first 10 rows):")
|
| 591 |
+
preview = result_df.head(10).to_string(index=False, max_rows=10)
|
| 592 |
+
summary_lines.append(preview)
|
| 593 |
+
|
| 594 |
+
result_summary = "\n".join(summary_lines)
|
| 595 |
+
|
| 596 |
+
print(f"✅ SQL executed: {len(result_df)} rows in {execution_time:.2f}ms\n")
|
| 597 |
+
|
| 598 |
+
return SQLResult(
|
| 599 |
+
sql_query=sql_query,
|
| 600 |
+
result_summary=result_summary,
|
| 601 |
+
columns=result_df.columns.tolist(),
|
| 602 |
+
row_count=len(result_df),
|
| 603 |
+
execution_time_ms=round(execution_time, 2),
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
except ValueError as ve:
|
| 607 |
+
raise ValueError(f"SQL validation error: {str(ve)}")
|
| 608 |
+
except Exception as e:
|
| 609 |
+
raise RuntimeError(f"SQL execution error: {str(e)}")
|
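# Illustrative call shape the DataAnalyst agent is prompted to use (at runtime
# the @function_tool wrapper is what actually invokes this coroutine):
# sql_tool(
#     sql_query="SELECT age, COUNT(*) AS count FROM CSV1 GROUP BY age ORDER BY age",
#     user_query="Show me the age distribution of patients",
# )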
| 610 |
+
|
| 611 |
+
# =====================================
|
| 612 |
+
# 🔹 VISUALIZATION TASK
|
| 613 |
+
# =====================================
|
| 614 |
+
async def generate_visualization_task(query: str) -> Optional[dict]:
|
| 615 |
+
"""Async task to generate Chart.js visualization using last_sql_query and upload to S3."""
|
| 616 |
+
global _registered_datasets, last_sql_query, vega_lite_path
|
| 617 |
+
|
| 618 |
+
if not last_sql_query or not _registered_datasets:
|
| 619 |
+
print("⚠️ No SQL query executed or no datasets loaded, skipping visualization.")
|
| 620 |
+
return None
|
| 621 |
+
|
| 622 |
+
print(f"🎨 Generating Chart.js visualization...")
|
| 623 |
+
try:
|
| 624 |
+
# Re-execute the last SQL query to get the result DataFrame
|
| 625 |
+
conn = duckdb.connect(":memory:")
|
| 626 |
+
for table_name, df in _registered_datasets.items():
|
| 627 |
+
conn.register(table_name, df)
|
| 628 |
+
result_df = conn.execute(last_sql_query).df()
|
| 629 |
+
conn.close()
|
| 630 |
+
|
| 631 |
+
# Ensure user_query is a non-empty string
|
| 632 |
+
visualization_prompt = query.strip() if query and query.strip() else "Visualize the query results"
|
| 633 |
+
|
| 634 |
+
print(f"🎯 User prompt for visualization: '{visualization_prompt}'")
|
| 635 |
+
print(f"🎯 User prompt type: {type(visualization_prompt)}")
|
| 636 |
+
print("+"*60)
|
| 637 |
+
|
| 638 |
+
# Generate Vega-Lite visualization
|
| 639 |
+
vega_lite_config = generate_vegalite_visualization(
|
| 640 |
+
result_df=result_df,
|
| 641 |
+
user_prompt=visualization_prompt
|
| 642 |
+
)
|
| 643 |
+
print(f"✅ Vega-Lite visualization generated.")
|
| 644 |
+
print("+"*60)
|
| 645 |
+
print(f"Vega-Lite Config: {json.dumps(vega_lite_config, indent=2)}")
|
| 646 |
+
|
| 647 |
+
return vega_lite_config
|
| 648 |
+
|
| 649 |
+
except ValueError as ve:
|
| 650 |
+
raise ValueError(f"Visualization error: {str(ve)}")
|
| 651 |
+
except Exception as e:
|
| 652 |
+
raise RuntimeError(f"Visualization error: {str(e)}")
|
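# Note (sketch): this task re-executes last_sql_query against freshly registered
# DuckDB tables, so it only yields a chart after sql_tool has run at least once.
# asyncio.run(generate_visualization_task("Show the age distribution"))  # illustrative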
| 653 |
+
# =====================================
|
| 654 |
+
# 🔹 MAIN ANALYSIS FUNCTION
|
| 655 |
+
# =====================================
|
| 656 |
+
# async def main(query: str, file_paths: List[str] = None) -> str:
|
| 657 |
+
# """Main function to run data analysis with proper agent initialization."""
|
| 658 |
+
|
| 659 |
+
# if file_paths is None:
|
| 660 |
+
# file_paths = []
|
| 661 |
+
|
| 662 |
+
# print(f"\n{'='*70}")
|
| 663 |
+
# print(f"📋 Analysis Request")
|
| 664 |
+
# print(f"{'='*70}")
|
| 665 |
+
# print(f"Files: {', '.join(file_paths)}")
|
| 666 |
+
# print(f"Query: {query}")
|
| 667 |
+
# print(f"{'='*70}\n")
|
| 668 |
+
|
| 669 |
+
# dataset_context = load_and_register_datasets(file_paths)
|
| 670 |
+
|
| 671 |
+
# print(f"\n{'='*70}")
|
| 672 |
+
# print("📊 Dataset Context Summary")
|
| 673 |
+
# print(f"{'='*70}")
|
| 674 |
+
# print(dataset_context[:500] + "..." if len(dataset_context) > 500 else dataset_context)
|
| 675 |
+
# print(f"{'='*70}\n")
|
| 676 |
+
|
| 677 |
+
# general_purpose_agent = Agent(
|
| 678 |
+
# name="general_purpose_agent",
|
| 679 |
+
# instructions=(
|
| 680 |
+
# "You are a friendly AI assistant with high expertise in data analysis, SQL, and visualization. "
|
| 681 |
+
# "Only answer if the query is not related to data analysis."
|
| 682 |
+
# ),
|
| 683 |
+
# model="gpt-4o-mini",
|
| 684 |
+
# tools=[]
|
| 685 |
+
# )
|
| 686 |
+
|
| 687 |
+
# DataAnalyst_agent = Agent(
|
| 688 |
+
# name="DataAnalyst_agent",
|
| 689 |
+
# instructions=(
|
| 690 |
+
# f"""You are a professional data analyst with expertise in SQL and data analysis.
|
| 691 |
+
|
| 692 |
+
# {dataset_context}
|
| 693 |
+
|
| 694 |
+
# Your workflow:
|
| 695 |
+
# 1. Analyze the user's question carefully
|
| 696 |
+
# 2. Review the available datasets and their schemas above
|
| 697 |
+
# 3. **Generate appropriate SQL query using sql_tool - ALWAYS pass the user_query parameter**
|
| 698 |
+
|
| 699 |
+
# CRITICAL - Tool Usage:
|
| 700 |
+
# - Call sql_tool with BOTH parameters:
|
| 701 |
+
# * sql_query: Your SQL SELECT statement
|
| 702 |
+
# * user_query: The EXACT original user question (this creates the visualization)
|
| 703 |
+
|
| 704 |
+
# Example correct usage:
|
| 705 |
+
# sql_tool(
|
| 706 |
+
# sql_query="SELECT age, COUNT(*) as count FROM CSV1 GROUP BY age ORDER BY age",
|
| 707 |
+
# user_query="Show me the age distribution of patients"
|
| 708 |
+
# )
|
| 709 |
+
|
| 710 |
+
# - Use table names: CSV1, CSV2, CSV3, etc.
|
| 711 |
+
# - You can JOIN tables if needed
|
| 712 |
+
# - Use descriptive column aliases
|
| 713 |
+
# - Add ORDER BY for readability
|
| 714 |
+
# - Use LIMIT for large result sets
|
| 715 |
+
|
| 716 |
+
# 4. Interpret the results and provide clear, detailed insights with specific numbers and percentages
|
| 717 |
+
|
| 718 |
+
# SQL Best Practices:
|
| 719 |
+
# - Always check which columns exist in which tables
|
| 720 |
+
# - Use table aliases for clarity (e.g., CSV1 AS c1)
|
| 721 |
+
# - Filter data appropriately based on the question
|
| 722 |
+
# - Aggregate data when asked for counts, sums, averages
|
| 723 |
+
# - Sort results logically
|
| 724 |
+
|
| 725 |
+
# Important:
|
| 726 |
+
# - Only use columns that exist in the datasets above
|
| 727 |
+
# - Reference the correct table names (CSV1, CSV2, etc.)
|
| 728 |
+
# - **MANDATORY: Always pass both sql_query AND user_query to sql_tool**
|
| 729 |
+
# - The user_query parameter is used to generate intelligent visualizations
|
| 730 |
+
# - Always provide complete analysis with specific findings
|
| 731 |
+
# - Include tables, numbers, and percentages in your response
|
| 732 |
+
# - Provide insights in plain language for non-technical users
|
| 733 |
+
# """
|
| 734 |
+
# ),
|
| 735 |
+
# model="gpt-4o-mini",
|
| 736 |
+
# tools=[sql_tool],
|
| 737 |
+
# )
|
| 738 |
+
|
| 739 |
+
# orchestrator_agent = Agent(
|
| 740 |
+
# name="orchestrator_agent",
|
| 741 |
+
# instructions=(
|
| 742 |
+
# "You receive the user's query and route it to the DataAnalyst agent for execution. "
|
| 743 |
+
# "Always use the analyze_dataset tool to process data analysis queries. "
|
| 744 |
+
# "Pass the complete user query to the DataAnalyst agent."
|
| 745 |
+
# ),
|
| 746 |
+
# tools=[
|
| 747 |
+
# DataAnalyst_agent.as_tool(
|
| 748 |
+
# tool_name="analyze_dataset",
|
| 749 |
+
# tool_description="Analyze dataset, generate SQL, and visualize results."
|
| 750 |
+
# ),
|
| 751 |
+
# general_purpose_agent.as_tool(
|
| 752 |
+
# tool_name="general_purpose_tool",
|
| 753 |
+
# tool_description="Act like a human friendly AI assistant for non-data analysis queries."
|
| 754 |
+
# )
|
| 755 |
+
# ],
|
| 756 |
+
# model="gpt-4o-mini",
|
| 757 |
+
# )
|
| 758 |
+
|
| 759 |
+
# print(f"\n{'='*70}")
|
| 760 |
+
# print("🚀 Running Orchestrator Agent")
|
| 761 |
+
# print(f"{'='*70}\n")
|
| 762 |
+
|
| 763 |
+
# orchestrator_result = await Runner.run(orchestrator_agent, query)
|
| 764 |
+
|
| 765 |
+
# print(f"\n{'='*70}")
|
| 766 |
+
# print("📝 Orchestrator Output")
|
| 767 |
+
# print(f"{'='*70}")
|
| 768 |
+
|
| 769 |
+
# analysis_text = ""
|
| 770 |
+
# for item in orchestrator_result.new_items:
|
| 771 |
+
# if hasattr(item, 'output') and item.output:
|
| 772 |
+
# analysis_text = item.output
|
| 773 |
+
# print(f"Found analysis output: {len(analysis_text)} characters")
|
| 774 |
+
# elif hasattr(item, 'raw_item') and hasattr(item.raw_item, 'content'):
|
| 775 |
+
# for content_item in item.raw_item.content:
|
| 776 |
+
# if hasattr(content_item, 'text'):
|
| 777 |
+
# analysis_text = content_item.text
|
| 778 |
+
# print(f"Found message output: {len(analysis_text)} characters")
|
| 779 |
+
|
| 780 |
+
# print(f"{'='*70}\n")
|
| 781 |
+
|
| 782 |
+
# synthesizer_agent = Agent(
|
| 783 |
+
# name="synthesizer_agent",
|
| 784 |
+
# instructions=(
|
| 785 |
+
# """You are a final report writer. Your job is to present the complete data analysis results.
|
| 786 |
+
|
| 787 |
+
# CRITICAL RULES:
|
| 788 |
+
# 1. Present the COMPLETE analysis from the DataAnalyst agent
|
| 789 |
+
# 2. Include ALL tables, numbers, statistics, and findings
|
| 790 |
+
# 3. Maintain all the detailed insights and percentages
|
| 791 |
+
# 4. Use clear formatting with headers, tables, and bullet points
|
| 792 |
+
# 5. DO NOT summarize or truncate the analysis
|
| 793 |
+
# 6. DO NOT add generic closing statements like "feel free to ask"
|
| 794 |
+
# 7. Make the report professional and easy to read
|
| 795 |
+
|
| 796 |
+
# Your output should be the full, detailed analysis report with all findings intact.
|
| 797 |
+
# """
|
| 798 |
+
# ),
|
| 799 |
+
# model="gpt-4o-mini",
|
| 800 |
+
# )
|
| 801 |
+
|
| 802 |
+
# print(f"\n{'='*70}")
|
| 803 |
+
# print("🎯 Running Synthesizer and Visualization in Parallel")
|
| 804 |
+
# print(f"{'='*70}\n")
|
| 805 |
+
|
| 806 |
+
# synthesizer_task = asyncio.create_task(
|
| 807 |
+
# Runner.run(
|
| 808 |
+
# synthesizer_agent,
|
| 809 |
+
# f"Present this complete data analysis:\n\n{analysis_text}" if analysis_text else orchestrator_result.to_input_list()
|
| 810 |
+
# )
|
| 811 |
+
# )
|
| 812 |
+
# visualization_task = asyncio.create_task(generate_visualization_task(query))
|
| 813 |
+
|
| 814 |
+
# synthesizer_result, visualization_result = await asyncio.gather(synthesizer_task, visualization_task, return_exceptions=True)
|
| 815 |
+
|
| 816 |
+
# if isinstance(synthesizer_result, Exception):
|
| 817 |
+
# print(f"⚠️ Synthesizer task failed: {str(synthesizer_result)}")
|
| 818 |
+
# synthesizer_result = type('Result', (), {'final_output': 'Synthesizer task failed'})()
|
| 819 |
+
|
| 820 |
+
# if isinstance(visualization_result, Exception):
|
| 821 |
+
# print(f"⚠️ Visualization task failed: {str(visualization_result)}")
|
| 822 |
+
# visualization_result = None
|
| 823 |
+
|
| 824 |
+
# print(f"\n{'='*70}")
|
| 825 |
+
# print("✅ FINAL ANSWER")
|
| 826 |
+
# print(f"{'='*70}")
|
| 827 |
+
# print(synthesizer_result.final_output)
|
| 828 |
+
# print(f"{'='*70}\n")
|
| 829 |
+
|
| 830 |
+
# print(f"chartjs_config_path: {vega_lite_path}")
|
| 831 |
+
# print(f"{'='*70}\n")
|
| 832 |
+
|
| 833 |
+
# final_api_output = synthesizer_result.final_output + f"\n\nChart.js Configuration saved at: {vega_lite_path[-1] if vega_lite_path else 'N/A'}"
|
| 834 |
+
# hidden_code = True if last_sql_query else False
|
| 835 |
+
# final_json_api_output = {
|
| 836 |
+
# "convoId": "1234",
|
| 837 |
+
# "type": "assistant_message",
|
| 838 |
+
# "hidden_code": hidden_code,
|
| 839 |
+
# "hidden_text": last_sql_query,
|
| 840 |
+
# "ai_message": synthesizer_result.final_output,
|
| 841 |
+
# "artifacts": vega_lite_path
|
| 842 |
+
# }
|
| 843 |
+
# print(f"Final JSON Output: {final_json_api_output}")
|
| 844 |
+
# return json.dumps(final_json_api_output)
|
| 845 |
+
# return final_json_api_output
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
async def main(query: str, file_paths: List[str] = None) -> str:
|
| 851 |
+
"""Main function to run data analysis with proper agent initialization and structured output."""
|
| 852 |
+
|
| 853 |
+
if file_paths is None:
|
| 854 |
+
file_paths = []
|
| 855 |
+
|
| 856 |
+
print(f"\n{'='*70}")
|
| 857 |
+
print(f"📋 Analysis Request")
|
| 858 |
+
print(f"{'='*70}")
|
| 859 |
+
print(f"Files: {', '.join(file_paths)}")
|
| 860 |
+
print(f"Query: {query}")
|
| 861 |
+
print(f"{'='*70}\n")
|
| 862 |
+
|
| 863 |
+
dataset_context = load_and_register_datasets(file_paths)
|
| 864 |
+
|
| 865 |
+
print(f"\n{'='*70}")
|
| 866 |
+
print("📊 Dataset Context Summary")
|
| 867 |
+
print(f"{'='*70}")
|
| 868 |
+
print(dataset_context[:500] + "..." if len(dataset_context) > 500 else dataset_context)
|
| 869 |
+
print(f"{'='*70}\n")
|
| 870 |
+
|
| 871 |
+
general_purpose_agent = Agent(
|
| 872 |
+
name="general_purpose_agent",
|
| 873 |
+
instructions=(
|
| 874 |
+
"You are a friendly AI assistant with high expertise in data analysis, SQL, and visualization. "
|
| 875 |
+
"Only answer if the query is not related to data analysis."
|
| 876 |
+
),
|
| 877 |
+
model="gpt-4o-mini",
|
| 878 |
+
tools=[]
|
| 879 |
+
)
|
| 880 |
+
|
| 881 |
+
DataAnalyst_agent = Agent(
|
| 882 |
+
name="DataAnalyst_agent",
|
| 883 |
+
instructions=(
|
| 884 |
+
f"""You are a professional data analyst with expertise in SQL and data analysis.
|
| 885 |
+
|
| 886 |
+
{dataset_context}
|
| 887 |
+
|
| 888 |
+
Your workflow:
|
| 889 |
+
1. Analyze the user's question carefully
|
| 890 |
+
2. Review the available datasets and their schemas above
|
| 891 |
+
3. **Generate appropriate SQL query using sql_tool - ALWAYS pass the user_query parameter**
|
| 892 |
+
|
| 893 |
+
CRITICAL - Tool Usage:
|
| 894 |
+
- Call sql_tool with BOTH parameters:
|
| 895 |
+
* sql_query: Your SQL SELECT statement
|
| 896 |
+
* user_query: The EXACT original user question (this creates the visualization)
|
| 897 |
+
|
| 898 |
+
Example correct usage:
|
| 899 |
+
sql_tool(
|
| 900 |
+
sql_query="SELECT age, COUNT(*) as count FROM CSV1 GROUP BY age ORDER BY age",
|
| 901 |
+
user_query="Show me the age distribution of patients"
|
| 902 |
+
)
|
| 903 |
+
|
| 904 |
+
- Use table names: CSV1, CSV2, CSV3, etc.
|
| 905 |
+
- You can JOIN tables if needed
|
| 906 |
+
- Use descriptive column aliases
|
| 907 |
+
- Add ORDER BY for readability
|
| 908 |
+
- Use LIMIT for large result sets
|
| 909 |
+
|
| 910 |
+
4. Interpret the results and provide clear, detailed insights with specific numbers and percentages
|
| 911 |
+
|
| 912 |
+
SQL Best Practices:
|
| 913 |
+
- Always check which columns exist in which tables
|
| 914 |
+
- Use table aliases for clarity (e.g., CSV1 AS c1)
|
| 915 |
+
- Filter data appropriately based on the question
|
| 916 |
+
- Aggregate data when asked for counts, sums, averages
|
| 917 |
+
- Sort results logically
|
| 918 |
+
|
| 919 |
+
Important:
|
| 920 |
+
- Only use columns that exist in the datasets above
|
| 921 |
+
- Reference the correct table names (CSV1, CSV2, etc.)
|
| 922 |
+
- **MANDATORY: Always pass both sql_query AND user_query to sql_tool**
|
| 923 |
+
- The user_query parameter is used to generate intelligent visualizations
|
| 924 |
+
- Always provide complete analysis with specific findings
|
| 925 |
+
- Include tables, numbers, and percentages in your response
|
| 926 |
+
- Provide insights in plain language for non-technical users
|
| 927 |
+
"""
|
| 928 |
+
),
|
| 929 |
+
model="gpt-4o-mini",
|
| 930 |
+
tools=[sql_tool],
|
| 931 |
+
)
|
| 932 |
+
|
| 933 |
+
orchestrator_agent = Agent(
|
| 934 |
+
name="orchestrator_agent",
|
| 935 |
+
instructions=(
|
| 936 |
+
"You receive the user's query and route it to the DataAnalyst agent for execution. "
|
| 937 |
+
"Always use the analyze_dataset tool to process data analysis queries. "
|
| 938 |
+
"Pass the complete user query to the DataAnalyst agent."
|
| 939 |
+
),
|
| 940 |
+
tools=[
|
| 941 |
+
DataAnalyst_agent.as_tool(
|
| 942 |
+
tool_name="analyze_dataset",
|
| 943 |
+
tool_description="Analyze dataset, generate SQL, and visualize results."
|
| 944 |
+
),
|
| 945 |
+
general_purpose_agent.as_tool(
|
| 946 |
+
tool_name="general_purpose_tool",
|
| 947 |
+
tool_description="Act like a human friendly AI assistant for non-data analysis queries."
|
| 948 |
+
)
|
| 949 |
+
],
|
| 950 |
+
model="gpt-4o-mini",
|
| 951 |
+
)
|
| 952 |
+
|
| 953 |
+
print(f"\n{'='*70}")
|
| 954 |
+
print("🚀 Running Orchestrator Agent")
|
| 955 |
+
print(f"{'='*70}\n")
|
| 956 |
+
|
| 957 |
+
orchestrator_result = await Runner.run(orchestrator_agent, query)
|
| 958 |
+
|
| 959 |
+
print(f"\n{'='*70}")
|
| 960 |
+
print("📝 Orchestrator Output")
|
| 961 |
+
print(f"{'='*70}")
|
| 962 |
+
|
| 963 |
+
analysis_text = ""
|
| 964 |
+
for item in orchestrator_result.new_items:
|
| 965 |
+
if hasattr(item, 'output') and item.output:
|
| 966 |
+
analysis_text = item.output
|
| 967 |
+
print(f"Found analysis output: {len(analysis_text)} characters")
|
| 968 |
+
elif hasattr(item, 'raw_item') and hasattr(item.raw_item, 'content'):
|
| 969 |
+
for content_item in item.raw_item.content:
|
| 970 |
+
if hasattr(content_item, 'text'):
|
| 971 |
+
analysis_text = content_item.text
|
| 972 |
+
print(f"Found message output: {len(analysis_text)} characters")
|
| 973 |
+
|
| 974 |
+
print(f"{'='*70}\n")
|
| 975 |
+
|
| 976 |
+
synthesizer_agent = Agent(
|
| 977 |
+
name="synthesizer_agent",
|
| 978 |
+
instructions=(
|
| 979 |
+
f"""You are a final report writer. Your job is to present the complete data analysis results and suggest follow-up queries.
|
| 980 |
+
|
| 981 |
+
CRITICAL RULES:
|
| 982 |
+
1. Present the COMPLETE analysis from the DataAnalyst agent
|
| 983 |
+
2. Include ALL tables, numbers, statistics, and findings
|
| 984 |
+
3. Maintain all the detailed insights and percentages
|
| 985 |
+
4. Use clear formatting with headers, tables, and bullet points
|
| 986 |
+
5. DO NOT summarize or truncate the analysis
|
| 987 |
+
6. DO NOT add generic closing statements like "feel free to ask"
|
| 988 |
+
7. Generate a 'Suggested Next Queries' section with at least 4 follow-up queries
|
| 989 |
+
- Each query should include:
|
| 990 |
+
- query: A natural language query
|
| 991 |
+
- description: A brief explanation of what the query explores
|
| 992 |
+
- recommended_chart_type: A suitable chart type (bar, line, pie, point, area, circle)
|
| 993 |
+
8. Make the report professional and easy to read
|
| 994 |
+
|
| 995 |
+
Dataset Context for Reference:
|
| 996 |
+
{dataset_context}
|
| 997 |
+
|
| 998 |
+
Your output should be a professional, detailed analysis report with all findings intact, followed by suggested next queries.
|
| 999 |
+
"""
|
| 1000 |
+
),
|
| 1001 |
+
model="gpt-4o-mini",
|
| 1002 |
+
)
|
| 1003 |
+
|
| 1004 |
+
print(f"\n{'='*70}")
|
| 1005 |
+
print("🎯 Running Synthesizer and Visualization in Parallel")
|
| 1006 |
+
print(f"{'='*70}\n")
|
| 1007 |
+
|
| 1008 |
+
synthesizer_task = asyncio.create_task(
|
| 1009 |
+
Runner.run(
|
| 1010 |
+
synthesizer_agent,
|
| 1011 |
+
f"Present this complete data analysis:\n\n{analysis_text}" if analysis_text else orchestrator_result.to_input_list()
|
| 1012 |
+
)
|
| 1013 |
+
)
|
| 1014 |
+
visualization_task = asyncio.create_task(generate_visualization_task(query))
|
| 1015 |
+
|
| 1016 |
+
synthesizer_result, visualization_result = await asyncio.gather(synthesizer_task, visualization_task, return_exceptions=True)
|
| 1017 |
+
|
| 1018 |
+
if isinstance(synthesizer_result, Exception):
|
| 1019 |
+
print(f"⚠️ Synthesizer task failed: {str(synthesizer_result)}")
|
| 1020 |
+
synthesizer_result = type('Result', (), {'final_output': 'Synthesizer task failed'})()
|
| 1021 |
+
|
| 1022 |
+
if isinstance(visualization_result, Exception):
|
| 1023 |
+
print(f"⚠️ Visualization task failed: {str(visualization_result)}")
|
| 1024 |
+
visualization_result = None
|
| 1025 |
+
|
| 1026 |
+
# Extract suggested next queries from synthesizer output
|
| 1027 |
+
# next_queries = []
|
| 1028 |
+
# if synthesizer_result.final_output:
|
| 1029 |
+
# # Use regex to extract the Suggested Next Queries section
|
| 1030 |
+
# next_queries_match = re.search(r"## Suggested Next Queries\n\n(.*?)(?:\n\n|$)", synthesizer_result.final_output, re.DOTALL)
|
| 1031 |
+
# if next_queries_match:
|
| 1032 |
+
# queries_text = next_queries_match.group(1).strip()
|
| 1033 |
+
# # Split by list items and parse
|
| 1034 |
+
# query_items = queries_text.split('\n- ')
|
| 1035 |
+
# for item in query_items:
|
| 1036 |
+
# if not item.strip():
|
| 1037 |
+
# continue
|
| 1038 |
+
# # Extract query, description, and chart type using regex
|
| 1039 |
+
# query_match = re.search(r"Query: (.*?); Description: (.*?); Recommended Chart Type: (\w+)", item.strip())
|
| 1040 |
+
# if query_match:
|
| 1041 |
+
# next_queries.append({
|
| 1042 |
+
# "query": query_match.group(1).strip(),
|
| 1043 |
+
# "description": query_match.group(2).strip(),
|
| 1044 |
+
# "recommended_chart_type": query_match.group(3).strip()
|
| 1045 |
+
# })
|
| 1046 |
+
|
| 1047 |
+
# Extract suggested next queries from synthesizer output
|
| 1048 |
+
|
| 1049 |
+
next_queries = []
|
| 1050 |
+
if synthesizer_result.final_output:
|
| 1051 |
+
# Capture everything under "## Suggested Next Queries"
|
| 1052 |
+
match = re.search(
|
| 1053 |
+
r"## Suggested Next Queries\s*(.*?)(?=\n## |\Z)",
|
| 1054 |
+
synthesizer_result.final_output,
|
| 1055 |
+
re.DOTALL
|
| 1056 |
+
)
|
| 1057 |
+
|
| 1058 |
+
if match:
|
| 1059 |
+
section = match.group(1).strip()
|
| 1060 |
+
|
| 1061 |
+
# Split by either "### Query" or numbered "1. **Query:**" formats
|
| 1062 |
+
query_blocks = re.split(r"(?:###\s*Query\s*\d+|^\d+\.\s+\*\*Query:\*\*)", section, flags=re.MULTILINE)
|
| 1063 |
+
|
| 1064 |
+
for block in query_blocks[1:]: # skip header/non-query text
|
| 1065 |
+
block = block.strip()
|
| 1066 |
+
if not block:
|
| 1067 |
+
continue
|
| 1068 |
+
|
| 1069 |
+
# Extract fields (support both markdown and colon formats)
|
| 1070 |
+
query_match = re.search(r"\*\*Query[:\*]*\s*\**\s*(.*)", block)
|
| 1071 |
+
desc_match = re.search(r"-\s*\*\*Description[:\*]*\s*\**\s*(.*)", block)
|
| 1072 |
+
chart_match = re.search(r"-\s*\*\*Recommended Chart Type[:\*]*\s*\**\s*(.*)", block)
|
| 1073 |
+
|
| 1074 |
+
# Also handle colon-only versions (like in your example)
|
| 1075 |
+
if not query_match:
|
| 1076 |
+
query_match = re.search(r"[:\-]\s*\*\*Query[:\*]*\s*\**\s*(.*)", block)
|
| 1077 |
+
if not desc_match:
|
| 1078 |
+
desc_match = re.search(r"[:\-]\s*\*\*description\*\*[:\-]*\s*(.*)", block, re.IGNORECASE)
|
| 1079 |
+
if not chart_match:
|
| 1080 |
+
chart_match = re.search(r"[:\-]\s*\*\*recommended_chart_type\*\*[:\-]*\s*(.*)", block, re.IGNORECASE)
|
| 1081 |
+
|
| 1082 |
+
next_queries.append({
|
| 1083 |
+
"query": query_match.group(1).strip() if query_match else None,
|
| 1084 |
+
"description": desc_match.group(1).strip() if desc_match else None,
|
| 1085 |
+
"recommended_chart_type": chart_match.group(1).strip() if chart_match else None
|
| 1086 |
+
})
|
| 1087 |
+
|
| 1088 |
+
# Step 5: (Optional) Pretty print to confirm
|
| 1089 |
+
|
| 1090 |
+
# Prepare final JSON output
|
| 1091 |
+
final_json_api_output = {
|
| 1092 |
+
"convoId": "1234", # Hardcoded as per example
|
| 1093 |
+
"type": "assistant_message",
|
| 1094 |
+
"hidden_code": bool(last_sql_query),
|
| 1095 |
+
"hidden_text": last_sql_query if last_sql_query else "",
|
| 1096 |
+
"ai_message": synthesizer_result.final_output,
|
| 1097 |
+
"artifacts": vega_lite_path,
|
| 1098 |
+
"next_queries": next_queries
|
| 1099 |
+
}
|
| 1100 |
+
|
| 1101 |
+
print(f"\n{'='*70}")
|
| 1102 |
+
print("✅ FINAL ANSWER")
|
| 1103 |
+
print(f"{'='*70}")
|
| 1104 |
+
print(json.dumps(final_json_api_output, indent=2))
|
| 1105 |
+
print(f"{'='*70}\n")
|
| 1106 |
+
|
| 1107 |
+
return json.dumps(final_json_api_output)
|
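# Standalone usage sketch (file path is a placeholder):
# if __name__ == "__main__":
#     out = asyncio.run(main(
#         query="What is the average age of patients?",
#         file_paths=["bucket/datasets/patients.csv"],
#     ))
#     print(json.loads(out)["ai_message"])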
| 1108 |
+
# =====================================
|
| 1109 |
+
# 🔹 STANDALONE TESTING
|
| 1110 |
+
|
Routes/helpers/main_agent_helpers.py
ADDED
|
@@ -0,0 +1,255 @@
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
|
| 3 |
+
from langchain_openai import ChatOpenAI
|
| 4 |
+
import os
|
| 5 |
+
import dotenv
|
| 6 |
+
from agents import Agent, Runner, function_tool, trace
|
| 7 |
+
from agents.extensions.memory import RedisSession
|
| 8 |
+
import asyncio
|
| 9 |
+
from .autovis_tool import run_autoviz
|
| 10 |
+
|
| 11 |
+
# Load environment variables
|
| 12 |
+
dotenv.load_dotenv()
|
| 13 |
+
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
|
| 14 |
+
|
| 15 |
+
#redis_url
|
| 16 |
+
Redis_url = os.getenv("REDIS_URL")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# # Load CSV
|
| 20 |
+
# csv_path = r"C:\Users\Dell\Documents\MR-AI\openai_agents\healthcare-data-30.csv"
|
| 21 |
+
# df = pd.read_csv(csv_path)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
data = {
|
| 25 |
+
'Employee_ID': [101, 102, 103, 104, 105, 106],
|
| 26 |
+
'Name': ['Alice Johnson', 'Bob Smith', 'Charlie Brown', 'Diana Prince', 'Eve Adams', 'Frank Miller'],
|
| 27 |
+
'Department': ['Sales', 'IT', 'Marketing', 'Sales', 'IT', 'Finance'],
|
| 28 |
+
'Hire_Date': pd.to_datetime(['2021-05-15', '2020-11-01', '2022-01-20', '2021-08-10', '2023-03-01', '2020-07-25']),
|
| 29 |
+
'Salary': [70000, 95000, 60000, 75000, 88000, 110000],
|
| 30 |
+
'Performance_Score': [4.2, 4.8, 3.5, 4.0, 4.5, 3.9],
|
| 31 |
+
'Is_Manager': [False, True, False, False, True, True]
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
df = pd.DataFrame(data)
|
| 35 |
+
# Initialize LLM
|
| 36 |
+
llm = ChatOpenAI(temperature=0)
|
| 37 |
+
|
| 38 |
+
# Create Pandas agent
|
| 39 |
+
pandas_agent = create_pandas_dataframe_agent(llm, df, verbose=True, allow_dangerous_code=True)
|
| 40 |
+
|
| 41 |
+
# Wrap it as a tool
|
| 42 |
+
@function_tool
|
| 43 |
+
def df_agent(query: str):
|
| 44 |
+
"""Query the healthcare dataset using natural language."""
|
| 45 |
+
response = pandas_agent.invoke(query)
|
| 46 |
+
return response
|
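# Usage sketch (illustrative): the agents runtime calls the wrapped tool, but the
# underlying pandas agent can be exercised directly for quick checks:
# print(pandas_agent.invoke("Which department has the highest average salary?"))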
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@function_tool
|
| 50 |
+
|
| 51 |
+
def autoviz_tool(filename: str):
|
| 52 |
+
"""Run AutoViz on the given CSV file and return the plot directory."""
|
| 53 |
+
response = run_autoviz(filename=filename, chart_format="html")  # csv_path was never defined here; take the CSV path as a parameter instead
|
| 54 |
+
print("+"*60)
|
| 55 |
+
return response
|
| 56 |
+
print("+"*60)
|
| 57 |
+
# autoviz_tool()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
DataAnalyst_agent = Agent(
|
| 61 |
+
name="DataAnalyst_agent",
|
| 62 |
+
instructions=(
|
| 63 |
+
"You are a data analyst given a dataframe tool and autoviz tool. "
|
| 64 |
+
"Always use these tools to answer any question about the data. "
|
| 65 |
+
|
| 66 |
+
),
|
| 67 |
+
model="gpt-4o-mini",
|
| 68 |
+
tools=[df_agent],
|
| 69 |
+
)
|
| 70 |
+
Business_Intelligence_Agent = Agent(
|
| 71 |
+
name="Business Intelligence Agent",
|
| 72 |
+
instructions="You are a business intelligence agent and you are given a dataframe tool and you need to analyse it and give some insights on it",
|
| 73 |
+
model="gpt-4o-mini",
|
| 74 |
+
tools=[df_agent],
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
Data_Scientist_Agent = Agent(
|
| 78 |
+
name="Data Scientist Agent",
|
| 79 |
+
instructions="You are a data scientist and you are given a dataframe tool and you need to analyse it and give some insights on it",
|
| 80 |
+
model="gpt-4o-mini",
|
| 81 |
+
tools=[df_agent],
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
autoviz_agent = Agent(
|
| 85 |
+
name="autoviz_agent",
|
| 86 |
+
instructions="You are a data scientist and you use onlt autoviz tool and you need to vizualise data",
|
| 87 |
+
model="gpt-4o-mini",
|
| 88 |
+
tools=[autoviz_tool],
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#Rephraser nlp agent
|
| 94 |
+
rephraser_agent = Agent(
|
| 95 |
+
name="rephraser_agent",
|
| 96 |
+
instructions="You rephrase the user's message to make it more natural and fluent",
|
| 97 |
+
handoff_description="An english to rephraser",
|
| 98 |
+
model="gpt-4o-mini"
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# --- Define Orchestrator Agent ---
|
| 102 |
+
# orchestrator_agent = Agent(
|
| 103 |
+
# name="orchestrator_agent",
|
| 104 |
+
# instructions=(
|
| 105 |
+
# "You are a data analysis orchestrator. You receive a user's query and use the df_agent "
|
| 106 |
+
# "tool to analyze the healthcare dataset and provide an insightful response."
|
| 107 |
+
# "You never analyze the data on your own — always use df_agent."
|
| 108 |
+
# ),
|
| 109 |
+
# tools=[
|
| 110 |
+
# DataAnalyst_agent.as_tool(
|
| 111 |
+
# tool_name="analyze_dataframe",
|
| 112 |
+
# tool_description="Analyze the healthcare dataset based on user queries."
|
| 113 |
+
# ),
|
| 114 |
+
# rephraser_agent.as_tool(
|
| 115 |
+
# tool_name="rephrase",
|
| 116 |
+
# tool_description="Rephrase the user's message to make it more natural and fluent",
|
| 117 |
+
# ),
|
| 118 |
+
# Business_Intelligence_Agent.as_tool(
|
| 119 |
+
# tool_name="Business_Intelligence_Agent",
|
| 120 |
+
# tool_description="Analyse the dataframe and give some insights on it and will use for reports and visualizations",
|
| 121 |
+
# ),
|
| 122 |
+
# Data_Scientist_Agent.as_tool(
|
| 123 |
+
# tool_name="DataScientist_Agent",
|
| 124 |
+
# tool_description="Analyse the dataframe and give some insights on it and will use for reports and visualizations",
|
| 125 |
+
# ),
|
| 126 |
+
# ],
|
| 127 |
+
# )
|
| 128 |
+
|
| 129 |
+
orchestrator_agent = Agent(
|
| 130 |
+
name="orchestrator_agent",
|
| 131 |
+
instructions=(
|
| 132 |
+
"You are a data analysis orchestrator. You receive a user's query and use the df_agent "
|
| 133 |
+
"tool to analyze the healthcare dataset and provide an insightful response."
|
| 134 |
+
"You never analyze the data on your own — always use df_agent."
|
| 135 |
+
"you always use autoviz_agent to only visualize the data.if user asked about plots , or vizualization then use autoviz_agent "
|
| 136 |
+
),
|
| 137 |
+
tools=[
|
| 138 |
+
DataAnalyst_agent.as_tool(
|
| 139 |
+
tool_name="analyze_dataframe",
|
| 140 |
+
tool_description="Analyze the healthcare dataset based on user queries."
|
| 141 |
+
),
|
| 142 |
+
rephraser_agent.as_tool(
|
| 143 |
+
tool_name="rephrase",
|
| 144 |
+
tool_description="Rephrase the user's message to make it more natural and fluent",
|
| 145 |
+
),
|
| 146 |
+
Business_Intelligence_Agent.as_tool(
|
| 147 |
+
tool_name="Business_Intelligence_Agent",
|
| 148 |
+
tool_description="Analyse the dataframe and give some insights on it and will use for reports and visualizations",
|
| 149 |
+
),
|
| 150 |
+
Data_Scientist_Agent.as_tool(
|
| 151 |
+
tool_name="DataScientist_Agent",
|
| 152 |
+
tool_description="Analyse the dataframe and give some insights on it and will use for reports and visualizations",
|
| 153 |
+
),
|
| 154 |
+
autoviz_agent.as_tool(
|
| 155 |
+
tool_name="autoviz_agent",
|
| 156 |
+
tool_description="use this only for plots and vizualization",
|
| 157 |
+
),
|
| 158 |
+
|
| 159 |
+
],
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# --- Define Synthesizer Agent ---
|
| 164 |
+
synthesizer_agent = Agent(
|
| 165 |
+
name="synthesizer_agent",
|
| 166 |
+
instructions=(
|
| 167 |
+
"You review the dataframe analysis results, clean up the language if needed, "
|
| 168 |
+
"and present a clear final answer to the user."
|
| 169 |
+
),
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# # --- Main Async Runner ---
|
| 174 |
+
# #without_streamming===================================
|
| 175 |
+
# async def main(query):
|
| 176 |
+
# #
|
| 177 |
+
|
| 178 |
+
# # Run orchestrator
|
| 179 |
+
# orchestrator_result = await Runner.run(orchestrator_agent, query)
|
| 180 |
+
|
| 181 |
+
# for item in orchestrator_result.new_items:
|
| 182 |
+
# print(f"Intermediate result: {item}")
|
| 183 |
+
|
| 184 |
+
# # Run synthesizer
|
| 185 |
+
# synthesizer_result = await Runner.run(
|
| 186 |
+
# synthesizer_agent, orchestrator_result.to_input_list()
|
| 187 |
+
# )
|
| 188 |
+
|
| 189 |
+
# print(f"\n\nFinal Answer:\n{synthesizer_result.final_output}")
|
| 190 |
+
# return synthesizer_result.final_output
|
| 191 |
+
|
| 192 |
+
#=========with_redis_session_memory_stort_term=================
|
| 193 |
+
async def main(query):
|
| 194 |
+
user_id = "8919614347"
|
| 195 |
+
session_id = "123456789"
|
| 196 |
+
|
| 197 |
+
print("Creating session...")
|
| 198 |
+
session = RedisSession.from_url(
|
| 199 |
+
session_id,
|
| 200 |
+
url=Redis_url,
|
| 201 |
+
key_prefix=f"{user_id}:",
|
| 202 |
+
)
|
| 203 |
+
#
|
| 204 |
+
|
| 205 |
+
# Run orchestrator
|
| 206 |
+
orchestrator_result = await Runner.run(orchestrator_agent, query, session=session) # session=session
|
| 207 |
+
|
| 208 |
+
for item in orchestrator_result.new_items:
|
| 209 |
+
print(f"Intermediate result: {item}")
|
| 210 |
+
|
| 211 |
+
# Run synthesizer
|
| 212 |
+
synthesizer_result = await Runner.run(
|
| 213 |
+
synthesizer_agent, orchestrator_result.to_input_list()
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
print(f"\n\nFinal Answer:\n{synthesizer_result.final_output}")
|
| 217 |
+
return synthesizer_result.final_output
|
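# Usage sketch (requires REDIS_URL to point at a reachable Redis instance):
# asyncio.run(main("Show me the average salary per department"))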
| 218 |
+
|
| 219 |
+
#with_streamming
|
| 220 |
+
# --- Main Async Function with Streaming + Trace ---
|
| 221 |
+
# async def main(query: str):
|
| 222 |
+
# with trace("Dataframe analysis trace"):
|
| 223 |
+
# print("Analyzing your query...\n")
|
| 224 |
+
|
| 225 |
+
# # Streamed orchestration
|
| 226 |
+
# result = Runner.run_streamed(orchestrator_agent, query)
|
| 227 |
+
|
| 228 |
+
# async for event in result.stream_events():
|
| 229 |
+
# if (
|
| 230 |
+
# event.type == "run_item_stream_event"
|
| 231 |
+
# and event.item.type == "tool_call_item"
|
| 232 |
+
# ):
|
| 233 |
+
# print(f"\n🔧 Tool Call: {event.item.name}")
|
| 234 |
+
# elif (
|
| 235 |
+
# event.type == "run_item_stream_event"
|
| 236 |
+
# and event.item.type == "message_output_item"
|
| 237 |
+
# ):
|
| 238 |
+
# print(f"🧠 Model Output: {event.item.raw_item.output_text}")
|
| 239 |
+
# elif event.type == "run_step_stream_event":
|
| 240 |
+
# print(f"Step Event: {event.step_name}")
|
| 241 |
+
# else:
|
| 242 |
+
# print(f"Other Event: {event.type}")
|
| 243 |
+
|
| 244 |
+
# print("\nSynthesizing final output...\n")
|
| 245 |
+
|
| 246 |
+
# # Run the synthesizer for final cleanup
|
| 247 |
+
# synthesizer_result = await Runner.run(
|
| 248 |
+
# synthesizer_agent, result.to_input_list()
|
| 249 |
+
# )
|
| 250 |
+
|
| 251 |
+
# print(f"\n✅ Final Answer:\n{synthesizer_result.final_output}")
|
| 252 |
+
|
| 253 |
+
# if __name__ == "__main__":
|
| 254 |
+
# user_query = input("Enter the query: ")
|
| 255 |
+
# asyncio.run(main(user_query))
|
Routes/helpers/pandas_ai_agent.py
ADDED
|
@@ -0,0 +1,39 @@
|
+import pandas as pd
+from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
+from langchain_openai import ChatOpenAI
+import os
+import dotenv
+import asyncio
+
+from agents import Agent, ItemHelpers, MessageOutputItem, Runner, trace
+
+from agents import Agent, Runner, function_tool, CodeInterpreterTool, trace
+import asyncio
+import os
+import dotenv
+dotenv.load_dotenv()
+
+os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
+
+csv_path = r"C:\Users\Dell\Documents\MR-AI\openai_agents\healthcare-data-30.csv"
+
+# Load sample DataFrame
+df = pd.read_csv(csv_path)
+
+# Initialize LLM
+llm = ChatOpenAI(temperature=0)
+
+# Create the agent with the DataFrame
+agent = create_pandas_dataframe_agent(llm, df, verbose=True, allow_dangerous_code=True)
+
+# Ask a question
+# response = agent.invoke("How many people are older than 28?")
+# print(response)
+def df_agent(query: str):
+    """Run a natural-language query against the loaded DataFrame via the pandas agent."""
+    response = agent.invoke(query)
+    return response
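A quick usage sketch for df_agent above (illustrative only; it assumes the CSV path resolves on the host, and that agent.invoke() returns the usual LangChain dict with an "output" key):

# Illustrative ad-hoc call to the pandas agent wrapper above (not part of the commit).
if __name__ == "__main__":
    result = df_agent("How many people are older than 28?")
    print(result["output"] if isinstance(result, dict) else result)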
Routes/helpers/plotly_langchain_agent.py
ADDED
|
@@ -0,0 +1,207 @@
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from langchain_openai import ChatOpenAI
|
| 3 |
+
from langchain_experimental.agents import create_pandas_dataframe_agent
|
| 4 |
+
from langchain.tools import Tool
|
| 5 |
+
from typing import Any, Dict, Optional
|
| 6 |
+
import os
|
| 7 |
+
import uuid
|
| 8 |
+
import json
|
| 9 |
+
import plotly.io as pio
|
| 10 |
+
import plotly.graph_objects as go
|
| 11 |
+
from fastapi import FastAPI, HTTPException
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
+
from starlette.responses import FileResponse, JSONResponse
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from dotenv import load_dotenv
|
| 17 |
+
|
| 18 |
+
# Load environment variables from .env file
|
| 19 |
+
load_dotenv()
|
| 20 |
+
|
| 21 |
+
# ----------------- Configuration & Data Loading -----------------
|
| 22 |
+
# Now, retrieve the key from the environment (loaded from .env)
|
| 23 |
+
openai_key = os.getenv("OPENAI_API_KEY")
|
| 24 |
+
os.environ["OPENAI_API_KEY"] = openai_key
|
| 25 |
+
|
| 26 |
+
# # ... rest of your code ...
|
| 27 |
+
|
| 28 |
+
# # --- 1. CONFIGURATION & DATA LOADING ---
|
| 29 |
+
# # NOTE: In a production environment, file paths should be configurable and
|
| 30 |
+
# # API keys should be handled securely (e.g., using FastAPI secrets/dependencies).
|
| 31 |
+
|
| 32 |
+
# # 🚨 IMPORTANT: Replace with your actual key before running!
|
| 33 |
+
# os.environ["OPENAI_API_KEY"] = "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"
|
| 34 |
+
|
| 35 |
+
# Define a persistent directory for chart outputs
|
| 36 |
+
CHART_DIR = "charts"
|
| 37 |
+
os.makedirs(CHART_DIR, exist_ok=True)
|
| 38 |
+
|
| 39 |
+
# Initialize DataFrames
|
| 40 |
+
df1 = pd.DataFrame()
|
| 41 |
+
df2 = pd.DataFrame()
|
| 42 |
+
|
| 43 |
+
# Load DataFrames (adjust paths as necessary for your local environment)
|
| 44 |
+
# IMPORTANT: Since this is a FastAPI app, these paths must be local paths accessible to the server.
|
| 45 |
+
# The original /content/ paths from Colab are generally not valid here.
|
| 46 |
+
csv_file_path = "Visadataset.csv"
|
| 47 |
+
excel_file_path = "HRanalytics.xlsx"
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
df1 = pd.read_csv(csv_file_path)
|
| 51 |
+
print(f"Dataframe df1 (Visa Dataset) loaded from: {csv_file_path}")
|
| 52 |
+
except FileNotFoundError:
|
| 53 |
+
print(f"Error: CSV file at {csv_file_path} not found. df1 remains empty. Using empty DataFrame.")
|
| 54 |
+
|
| 55 |
+
try:
|
| 56 |
+
df2 = pd.read_excel(excel_file_path)
|
| 57 |
+
print(f"Dataframe df2 (HR Analytics) loaded from: {excel_file_path}")
|
| 58 |
+
except FileNotFoundError:
|
| 59 |
+
print(f"Error: Excel file at {excel_file_path} not found. df2 remains empty. Using empty DataFrame.")
|
| 60 |
+
|
| 61 |
+
# --- 2. VISUALIZATION TOOL DEFINITION (Modified for API Environment) ---
|
| 62 |
+
|
| 63 |
+
def plotly_visualization_tool(code: str) -> str:
|
| 64 |
+
"""
|
| 65 |
+
Executes Python code to generate a Plotly figure, saves it as an HTML file
|
| 66 |
+
in the CHART_DIR, and returns the **full file path**.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
code: A string containing the complete Python code for Plotly visualization.
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
A string containing the full path to the generated HTML file, or an error message.
|
| 73 |
+
"""
|
| 74 |
+
exec_globals = {
|
| 75 |
+
'go': go,
|
| 76 |
+
'pio': pio,
|
| 77 |
+
'pd': pd,
|
| 78 |
+
'df1': df1,
|
| 79 |
+
'df2': df2
|
| 80 |
+
}
|
| 81 |
+
figure_obj = None
|
| 82 |
+
|
| 83 |
+
try:
|
| 84 |
+
# 1. Execute the code generated by the agent
|
| 85 |
+
exec(code, exec_globals)
|
| 86 |
+
figure_obj = exec_globals.get('fig')
|
| 87 |
+
|
| 88 |
+
if not figure_obj or not isinstance(figure_obj, go.Figure):
|
| 89 |
+
return "ERROR: Visualization code executed, but no valid Plotly Figure object (go.Figure) was found in the 'fig' variable."
|
| 90 |
+
|
| 91 |
+
except Exception as e:
|
| 92 |
+
return f"ERROR during code execution for visualization: {e}"
|
| 93 |
+
|
| 94 |
+
# 2. Save the figure to a file and return the path
|
| 95 |
+
try:
|
| 96 |
+
filename = f"plotly_chart_{uuid.uuid4().hex[:8]}.html"
|
| 97 |
+
# Save to the local CHART_DIR
|
| 98 |
+
filepath = os.path.join(CHART_DIR, filename)
|
| 99 |
+
|
| 100 |
+
# Save the HTML file to disk
|
| 101 |
+
pio.write_html(
|
| 102 |
+
figure_obj,
|
| 103 |
+
file=filepath,
|
| 104 |
+
include_plotlyjs='cdn',
|
| 105 |
+
auto_open=False
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# *** REVISED RETURN ***: Return the relative path for easy API access
|
| 109 |
+
# The user will download the file using a separate endpoint.
|
| 110 |
+
return filename
|
| 111 |
+
|
| 112 |
+
except Exception as e:
|
| 113 |
+
return f"ERROR saving HTML file to path: {e}"
|
| 114 |
+
|
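As a sanity check, plotly_visualization_tool above can be exercised directly with a self-contained code string. This is an illustrative sketch, not part of the commit; it assumes df1 has a 'case_status' column, as the tool description below also assumes.

# Illustrative smoke test for plotly_visualization_tool.
# The code string must assign the figure to `fig`, exactly as the tool docstring requires.
_sample_code = (
    "import plotly.graph_objects as go\n"
    "counts = df1['case_status'].value_counts()\n"
    "fig = go.Figure([go.Bar(x=counts.index, y=counts.values)])\n"
    "fig.update_layout(title='Visa Status Count')\n"
)
# print(plotly_visualization_tool(_sample_code))  # expected: plotly_chart_<hex>.html, or an ERROR string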
| 115 |
+
# --- 3. AGENT SETUP ---
|
| 116 |
+
|
| 117 |
+
plotly_tool = Tool(
|
| 118 |
+
name="plotly_visualization",
|
| 119 |
+
func=plotly_visualization_tool,
|
| 120 |
+
description=(
|
| 121 |
+
"Use this tool ONLY when the user explicitly asks for a chart, plot, or visualization. "
|
| 122 |
+
"The tool saves the chart as an HTML file and returns the **filename only** (e.g., plotly_chart_xxxx.html). "
|
| 123 |
+
"The input MUST be a COMPLETE and self-contained Python code string using 'plotly.graph_objects as go' "
|
| 124 |
+
"to generate the requested chart using DataFrames 'df1' (Visadataset) or 'df2' (HRanalytics). "
|
| 125 |
+
"The final figure object MUST be assigned to the variable 'fig' and the code MUST NOT contain fig.show(). "
|
| 126 |
+
"Example input: \"import plotly.graph_objects as go\\nfig = go.Figure([go.Bar(x=df1['case_status'].value_counts().index, y=df1['case_status'].value_counts().values)])\\nfig.update_layout(title='Visa Status Count', xaxis_title='Status', yaxis_title='Count')\""
|
| 127 |
+
),
|
| 128 |
+
return_direct=False
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
llm = ChatOpenAI(
|
| 132 |
+
model="gpt-4o-mini",
|
| 133 |
+
api_key=os.environ["OPENAI_API_KEY"],
|
| 134 |
+
temperature=0.3,
|
| 135 |
+
timeout=60, # Increased timeout for potential long agent/tool calls
|
| 136 |
+
max_retries=3
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# Pass both DataFrames (df1, df2) to the agent
|
| 140 |
+
agent = create_pandas_dataframe_agent(
|
| 141 |
+
llm=llm,
|
| 142 |
+
df=[df1, df2],
|
| 143 |
+
verbose=True,
|
| 144 |
+
allow_dangerous_code=True,
|
| 145 |
+
agent_type="tool-calling",
|
| 146 |
+
extra_tools=[plotly_tool]
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# --- 4. FASTAPI APPLICATION ---
|
| 150 |
+
|
| 151 |
+
app = FastAPI(title="Pandas Agent API with Plotly Visualization")
|
| 152 |
+
|
| 153 |
+
class QueryRequest(BaseModel):
|
| 154 |
+
question: str
|
| 155 |
+
|
| 156 |
+
class QueryResponse(BaseModel):
|
| 157 |
+
message: str
|
| 158 |
+
chart_filename: Optional[str] = None
|
| 159 |
+
|
| 160 |
+
@app.get("/")
|
| 161 |
+
def read_root():
|
| 162 |
+
return {"message": "Pandas Agent API is running. Use the /query endpoint to ask questions."}
|
| 163 |
+
|
| 164 |
+
@app.post("/query", response_model=QueryResponse)
|
| 165 |
+
async def handle_query(request: QueryRequest):
|
| 166 |
+
"""
|
| 167 |
+
Handles a user query by invoking the LangChain agent.
|
| 168 |
+
If a chart is generated, it returns the filename for download.
|
| 169 |
+
"""
|
| 170 |
+
question = request.question
|
| 171 |
+
|
| 172 |
+
try:
|
| 173 |
+
response = agent.invoke({"input": question})
|
| 174 |
+
output = response.get("output", str(response))
|
| 175 |
+
|
| 176 |
+
# Check if the output is a filename returned by the tool (e.g., plotly_chart_xxxx.html)
|
| 177 |
+
if output.startswith("plotly_chart_") and output.endswith(".html"):
|
| 178 |
+
filename = output
|
| 179 |
+
return QueryResponse(
|
| 180 |
+
message=f"Agent generated a chart. Download the file using the /download/{filename} endpoint.",
|
| 181 |
+
chart_filename=filename
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# If no chart was generated, return the text output
|
| 185 |
+
return QueryResponse(message=output)
|
| 186 |
+
|
| 187 |
+
except Exception as e:
|
| 188 |
+
# Log the error internally
|
| 189 |
+
print(f"Error during agent invocation: {e}")
|
| 190 |
+
raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
|
| 191 |
+
|
| 192 |
+
@app.get("/download/{filename}")
|
| 193 |
+
async def download_chart(filename: str):
|
| 194 |
+
"""
|
| 195 |
+
Serves the generated Plotly HTML file for download.
|
| 196 |
+
"""
|
| 197 |
+
filepath = os.path.join(CHART_DIR, filename)
|
| 198 |
+
|
| 199 |
+
if not filename.startswith("plotly_chart_") or ".." in filename:
|
| 200 |
+
raise HTTPException(status_code=400, detail="Invalid filename format.")
|
| 201 |
+
|
| 202 |
+
if not os.path.exists(filepath):
|
| 203 |
+
raise HTTPException(status_code=404, detail="Chart file not found.")
|
| 204 |
+
|
| 205 |
+
return FileResponse(filepath, filename=filename, media_type="text/html")
|
| 206 |
+
|
| 207 |
+
# --- END OF FASTAPI CODE ---
|
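For reference, a minimal client-side sketch of the flow this app exposes. This is illustrative, not part of the commit; it assumes the module is served locally with uvicorn on port 8000.

# Illustrative client for the API above (assumption: `uvicorn <module>:app --port 8000` is running).
import requests

resp = requests.post("http://localhost:8000/query", json={"question": "Plot the count of each case_status"})
data = resp.json()
if data.get("chart_filename"):
    # Download the generated Plotly HTML chart via the /download endpoint.
    html = requests.get(f"http://localhost:8000/download/{data['chart_filename']}")
    with open(data["chart_filename"], "wb") as f:
        f.write(html.content)
else:
    print(data["message"])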
Routes/helpers/report_generation_helpers.py
ADDED
|
@@ -0,0 +1,242 @@
|
| 1 |
+
# from google import genai
|
| 2 |
+
# from datetime import datetime
|
| 3 |
+
# import requests
|
| 4 |
+
# import base64
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# gemini_api_key ="AIzaSyAAAUn6xgLwzCkdUkVViaYnLo40pVN0ABc"
|
| 10 |
+
|
| 11 |
+
# # Set your Gemini AI API key
|
| 12 |
+
# client = genai.Client(api_key=gemini_api_key)
|
| 13 |
+
# generation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
| 14 |
+
|
| 15 |
+
# # Context and executive summary
|
| 16 |
+
# context = """
|
| 17 |
+
# Sales data for Q1 2025 shows product A had a 30% increase in revenue,
|
| 18 |
+
# while product B remained flat. Marketing campaigns improved customer engagement.
|
| 19 |
+
# """
|
| 20 |
+
|
| 21 |
+
# executive_summary = """
|
| 22 |
+
# This report highlights quarterly sales performance.
|
| 23 |
+
# Product A shows significant growth, while Product B remains stable.
|
| 24 |
+
# Includes static CSS-only charts and tables for visual analysis.
|
| 25 |
+
# """
|
| 26 |
+
|
| 27 |
+
# # Prompt to generate HTML + CSS
|
| 28 |
+
# prompt = f"""
|
| 29 |
+
# Generate a best design in css static HTML report (only HTML and CSS, no JS) including:
|
| 30 |
+
# - Executive summary
|
| 31 |
+
# - Static graphs (bar or line charts using CSS)
|
| 32 |
+
# - Tables for data visualization if needed
|
| 33 |
+
# - Professional layout and styling
|
| 34 |
+
|
| 35 |
+
# Executive Summary:
|
| 36 |
+
# {executive_summary}
|
| 37 |
+
|
| 38 |
+
# Context:
|
| 39 |
+
# {context}
|
| 40 |
+
# """
|
| 41 |
+
|
| 42 |
+
# # # Generate content using Gemini AI
|
| 43 |
+
# # response = client.models.generate_content(
|
| 44 |
+
# # model="gemini-2.5-flash",
|
| 45 |
+
# # contents=prompt
|
| 46 |
+
|
| 47 |
+
# # )
|
| 48 |
+
|
| 49 |
+
# # # Get the generated HTML
|
| 50 |
+
# # html_code = response.text
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# # print(html_code)
|
| 54 |
+
|
| 55 |
+
# # # Save to file
|
| 56 |
+
# # with open("static_report.html", "w", encoding="utf-8") as f:
|
| 57 |
+
# # f.write(html_code)
|
| 58 |
+
|
| 59 |
+
# # print("Static HTML report generated: static_report.html")
|
| 60 |
+
|
| 61 |
+
# # this is for the production grade report generation
|
| 62 |
+
# #def return_html_report(format_type:str,report_name:str,include_citations:bool,success:bool,list_of_queries:list[str],theme:str,pageoptions:dict):
|
| 63 |
+
|
| 64 |
+
# def return_html_report():
|
| 65 |
+
# # Generate content using Gemini AI
|
| 66 |
+
# response = client.models.generate_content(
|
| 67 |
+
# model="gemini-2.5-flash",
|
| 68 |
+
# contents=prompt
|
| 69 |
+
|
| 70 |
+
# )
|
| 71 |
+
|
| 72 |
+
# # Get the generated HTML
|
| 73 |
+
# html_code = response.text
|
| 74 |
+
|
| 75 |
+
# # Find the index where the DOCTYPE starts
|
| 76 |
+
# start_index = html_code.find("<!DOCTYPE html>")
|
| 77 |
+
|
| 78 |
+
# # Extract from that point forward
|
| 79 |
+
# cleaned_html = html_code[start_index:]
|
| 80 |
+
|
| 81 |
+
# print(cleaned_html)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# print(cleaned_html)
|
| 85 |
+
|
| 86 |
+
# # Save to file
|
| 87 |
+
# # with open("static_report.html", "w", encoding="utf-8") as f:
|
| 88 |
+
# # f.write(html_code)
|
| 89 |
+
|
| 90 |
+
# # print("Static HTML report generated: static_report.html")
|
| 91 |
+
# return cleaned_html
|
| 92 |
+
|
| 93 |
+
# # html_content = return_html_report()
|
| 94 |
+
# # print(html_content)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# import requests
|
| 98 |
+
# import base64
|
| 99 |
+
# #
|
| 100 |
+
# # url = "https://srivatsavdamaraju-htmlnode.hf.space/api/html-to-pdf"
|
| 101 |
+
# # payload = {
|
| 102 |
+
# # "html_content": html_content,
|
| 103 |
+
# # "pdf_options": {
|
| 104 |
+
# # "format": "A4"
|
| 105 |
+
# # }
|
| 106 |
+
# # }
|
| 107 |
+
|
| 108 |
+
# # response = requests.post(url, json=payload)
|
| 109 |
+
|
| 110 |
+
# # result = response.json()
|
| 111 |
+
# # print("="*550)
|
| 112 |
+
# # print(result)
|
| 113 |
+
# # print("="*550)
|
| 114 |
+
|
| 115 |
+
# # # Save PDF to file
|
| 116 |
+
# # with open("output.pdf", "wb") as f:
|
| 117 |
+
# # f.write(base64.b64decode(result["pdf_base64"]))
|
| 118 |
+
# def html_to_pdf_via_api(html_content):
|
| 119 |
+
# url = "https://srivatsavdamaraju-htmlnode.hf.space/api/html-to-pdf"
|
| 120 |
+
# payload = {
|
| 121 |
+
# "html_content": html_content,
|
| 122 |
+
# "pdf_options": {
|
| 123 |
+
# "format": "A4"
|
| 124 |
+
# }
|
| 125 |
+
# }
|
| 126 |
+
|
| 127 |
+
# response = requests.post(url, json=payload)
|
| 128 |
+
# result = response.json()
|
| 129 |
+
# print(result)
|
| 130 |
+
# return result
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# def html_to_pdf_via_api_with_pdf(html_content, output_pdf_path):
|
| 135 |
+
# url = "https://srivatsavdamaraju-htmlnode.hf.space/api/html-to-pdf"
|
| 136 |
+
# payload = {
|
| 137 |
+
# "html_content": html_content,
|
| 138 |
+
# "pdf_options": {
|
| 139 |
+
# "format": "A4"
|
| 140 |
+
# }
|
| 141 |
+
# }
|
| 142 |
+
|
| 143 |
+
# response = requests.post(url, json=payload)
|
| 144 |
+
# result = response.json()
|
| 145 |
+
# print(result)
|
| 146 |
+
|
| 147 |
+
# # Save PDF to file
|
| 148 |
+
# # with open(output_pdf_path, "wb") as f:
|
| 149 |
+
# # f.write(base64.b64decode(result["pdf_base64"]))
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# # if __name__ == "__main__":
|
| 153 |
+
|
| 154 |
+
# # html_to_pdf_via_api_with_pdf(html_content, "generated.pdf")
|
| 155 |
+
# # print("PDF report generated: generated.pdf")
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
from google import genai
|
| 160 |
+
from datetime import datetime
|
| 161 |
+
import requests
|
| 162 |
+
import base64
|
| 163 |
+
|
| 164 |
+
# Initialize Gemini client
|
| 165 |
+
gemini_api_key ="AIzaSyAAAUn6xgLwzCkdUkVViaYnLo40pVN0ABc"
|
| 166 |
+
client = genai.Client(api_key=gemini_api_key)
|
| 167 |
+
|
| 168 |
+
def generate_prompt(context, executive_summary):
|
| 169 |
+
"""Generate the text prompt for Gemini."""
|
| 170 |
+
return f"""
|
| 171 |
+
Generate a professional static HTML report (only HTML and CSS, no JS) including:
|
| 172 |
+
- Executive summary
|
| 173 |
+
- Static CSS-based graphs or tables
|
| 174 |
+
- Elegant, professional styling
|
| 175 |
+
|
| 176 |
+
Executive Summary:
|
| 177 |
+
{executive_summary}
|
| 178 |
+
|
| 179 |
+
Context:
|
| 180 |
+
{context}
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
def return_html_report(context=None, executive_summary=None):
|
| 184 |
+
"""Generate HTML report using Gemini API."""
|
| 185 |
+
if not context:
|
| 186 |
+
context = "Default context: Sales increased by 10%."
|
| 187 |
+
if not executive_summary:
|
| 188 |
+
executive_summary = "Default summary: Q1 growth observed."
|
| 189 |
+
|
| 190 |
+
prompt = generate_prompt(context, executive_summary)
|
| 191 |
+
|
| 192 |
+
response = client.models.generate_content(
|
| 193 |
+
model="gemini-2.5-flash",
|
| 194 |
+
contents=prompt
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
html_code = response.text
|
| 198 |
+
start_index = html_code.find("<!DOCTYPE html>")
|
| 199 |
+
cleaned_html = html_code[start_index:] if start_index != -1 else html_code
|
| 200 |
+
print(cleaned_html)
|
| 201 |
+
return cleaned_html
|
| 202 |
+
|
| 203 |
+
# return_html_report()
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def html_to_pdf_via_api(html_content):
|
| 207 |
+
"""Send HTML to your Hugging Face PDF API."""
|
| 208 |
+
url = "https://srivatsavdamaraju-htmlnode.hf.space/api/html-to-pdf"
|
| 209 |
+
payload = {
|
| 210 |
+
"html_content": html_content,
|
| 211 |
+
"pdf_options": {"format": "A4"}
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
response = requests.post(url, json=payload)
|
| 215 |
+
result = response.json()
|
| 216 |
+
return result
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def generate_report_files(format_type="html", regenerate_flags={"html": True, "pdf": True}):
|
| 220 |
+
"""
|
| 221 |
+
Generate report based on flags to avoid redundant work.
|
| 222 |
+
- If regenerate_flags['html'] is False, skip HTML generation.
|
| 223 |
+
- If regenerate_flags['pdf'] is False, skip PDF conversion.
|
| 224 |
+
"""
|
| 225 |
+
html_content = None
|
| 226 |
+
pdf_result = None
|
| 227 |
+
|
| 228 |
+
# Generate HTML only if needed
|
| 229 |
+
if regenerate_flags.get("html", True):
|
| 230 |
+
html_content = return_html_report()
|
| 231 |
+
else:
|
| 232 |
+
print("Skipping HTML generation (flag set to False)")
|
| 233 |
+
|
| 234 |
+
# Generate PDF only if needed
|
| 235 |
+
if format_type == "pdf" and regenerate_flags.get("pdf", True):
|
| 236 |
+
if not html_content:
|
| 237 |
+
html_content = return_html_report()
|
| 238 |
+
pdf_result = html_to_pdf_via_api(html_content)
|
| 239 |
+
elif format_type == "pdf":
|
| 240 |
+
print("Skipping PDF generation (flag set to False)")
|
| 241 |
+
|
| 242 |
+
return {"html": html_content, "pdf": pdf_result}
|
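A hedged sketch of how generate_report_files() above might be called and the returned PDF persisted. Illustrative only, not part of the commit; it assumes the html-to-pdf service responds with a pdf_base64 field, as the earlier commented-out code in this file expects.

# Illustrative usage of generate_report_files.
import base64

out = generate_report_files(format_type="pdf", regenerate_flags={"html": True, "pdf": True})
if out["pdf"] and "pdf_base64" in out["pdf"]:
    with open("report.pdf", "wb") as f:
        f.write(base64.b64decode(out["pdf"]["pdf_base64"]))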
Routes/helpers/rough.py
ADDED
|
@@ -0,0 +1,161 @@
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
|
| 3 |
+
from langchain_openai import ChatOpenAI
|
| 4 |
+
import os
|
| 5 |
+
import dotenv
|
| 6 |
+
from agents import Agent, Runner, function_tool ,trace
|
| 7 |
+
import asyncio
|
| 8 |
+
|
| 9 |
+
# Load environment variables
|
| 10 |
+
dotenv.load_dotenv()
|
| 11 |
+
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
|
| 12 |
+
|
| 13 |
+
# Load CSV
|
| 14 |
+
csv_path = r"C:\Users\Dell\Documents\MR-AI\openai_agents\healthcare-data-30.csv"
|
| 15 |
+
df = pd.read_csv(csv_path)
|
| 16 |
+
|
| 17 |
+
# Initialize LLM
|
| 18 |
+
llm = ChatOpenAI(temperature=0)
|
| 19 |
+
|
| 20 |
+
# Create Pandas agent
|
| 21 |
+
pandas_agent = create_pandas_dataframe_agent(llm, df, verbose=True, allow_dangerous_code=True)
|
| 22 |
+
|
| 23 |
+
# Wrap it as a tool
|
| 24 |
+
@function_tool
|
| 25 |
+
def df_agent(query: str):
|
| 26 |
+
"""Query the healthcare dataset using natural language."""
|
| 27 |
+
response = pandas_agent.invoke(query)
|
| 28 |
+
return response
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
DataAnalyst_agent = Agent(
|
| 34 |
+
name="DataAnalyst_agent",
|
| 35 |
+
instructions=(
|
| 36 |
+
"You are a data analyst and you are given a dataframe tool. ""You need to analyse it and give some insights on it. "
|
| 37 |
+
"You will use the dataframe tool to analyse, give the report, and visualizations also."
|
| 38 |
+
),
|
| 39 |
+
model="gpt-4o-mini",
|
| 40 |
+
tools=[df_agent],
|
| 41 |
+
)
|
| 42 |
+
Business_Intelligence_Agent = Agent(
|
| 43 |
+
name="Business Intelligence Agent",
|
| 44 |
+
instructions="You are a business intelligence agent and you are given a dataframe tool and you need to analyse it and give some insights on it",
|
| 45 |
+
model="gpt-4o-mini",
|
| 46 |
+
tools=[df_agent],
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
Data_Scientist_Agent = Agent(
|
| 50 |
+
name="Data Scientist Agent",
|
| 51 |
+
instructions="You are a data scientist and you are given a dataframe tool and you need to analyse it and give some insights on it",
|
| 52 |
+
model="gpt-4o-mini",
|
| 53 |
+
tools=[df_agent],
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
#Rephraser nlp agent
|
| 60 |
+
rephraser_agent = Agent(
|
| 61 |
+
name="rephraser_agent",
|
| 62 |
+
instructions="You rephrase the user's message to make it more natural and fluent",
|
| 63 |
+
handoff_description="An english to rephraser",
|
| 64 |
+
model="gpt-4o-mini"
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# --- Define Orchestrator Agent ---
|
| 68 |
+
orchestrator_agent = Agent(
|
| 69 |
+
name="orchestrator_agent",
|
| 70 |
+
instructions=(
|
| 71 |
+
"You are a data analysis orchestrator. You receive a user's query and use the df_agent "
|
| 72 |
+
"tool to analyze the healthcare dataset and provide an insightful response."
|
| 73 |
+
"You never analyze the data on your own — always use df_agent."
|
| 74 |
+
),
|
| 75 |
+
tools=[
|
| 76 |
+
DataAnalyst_agent.as_tool(
|
| 77 |
+
tool_name="analyze_dataframe",
|
| 78 |
+
tool_description="Analyze the healthcare dataset based on user queries."
|
| 79 |
+
),
|
| 80 |
+
rephraser_agent.as_tool(
|
| 81 |
+
tool_name="rephrase",
|
| 82 |
+
tool_description="Rephrase the user's message to make it more natural and fluent",
|
| 83 |
+
),
|
| 84 |
+
Business_Intelligence_Agent.as_tool(
|
| 85 |
+
tool_name="Business_Intelligence_Agent",
|
| 86 |
+
tool_description="Analyse the dataframe and give some insights on it and will use for reports and visualizations",
|
| 87 |
+
),
|
| 88 |
+
Data_Scientist_Agent.as_tool(
|
| 89 |
+
tool_name="DataScientist_Agent",
|
| 90 |
+
tool_description="Analyse the dataframe and give some insights on it and will use for reports and visualizations",
|
| 91 |
+
),
|
| 92 |
+
],
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# --- Define Synthesizer Agent ---
|
| 97 |
+
synthesizer_agent = Agent(
|
| 98 |
+
name="synthesizer_agent",
|
| 99 |
+
instructions=(
|
| 100 |
+
"You review the dataframe analysis results, clean up the language if needed, "
|
| 101 |
+
"and present a clear final answer to the user."
|
| 102 |
+
),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# # --- Main Async Runner ---
|
| 107 |
+
# #without_streaming
|
| 108 |
+
async def main(query):
|
| 109 |
+
#
|
| 110 |
+
|
| 111 |
+
# Run orchestrator
|
| 112 |
+
orchestrator_result = await Runner.run(orchestrator_agent, query)
|
| 113 |
+
|
| 114 |
+
for item in orchestrator_result.new_items:
|
| 115 |
+
print(f"Intermediate result: {item}")
|
| 116 |
+
|
| 117 |
+
# Run synthesizer
|
| 118 |
+
synthesizer_result = await Runner.run(
|
| 119 |
+
synthesizer_agent, orchestrator_result.to_input_list()
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
print(f"\n\nFinal Answer:\n{synthesizer_result.final_output}")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
#with_streaming
|
| 126 |
+
# --- Main Async Function with Streaming + Trace ---
|
| 127 |
+
# async def main(query: str):
|
| 128 |
+
# with trace("Dataframe analysis trace"):
|
| 129 |
+
# print("Analyzing your query...\n")
|
| 130 |
+
|
| 131 |
+
# # Streamed orchestration
|
| 132 |
+
# result = Runner.run_streamed(orchestrator_agent, query)
|
| 133 |
+
|
| 134 |
+
# async for event in result.stream_events():
|
| 135 |
+
# if (
|
| 136 |
+
# event.type == "run_item_stream_event"
|
| 137 |
+
# and event.item.type == "tool_call_item"
|
| 138 |
+
# ):
|
| 139 |
+
# print(f"\n🔧 Tool Call: {event.item.name}")
|
| 140 |
+
# elif (
|
| 141 |
+
# event.type == "run_item_stream_event"
|
| 142 |
+
# and event.item.type == "message_output_item"
|
| 143 |
+
# ):
|
| 144 |
+
# print(f"🧠 Model Output: {event.item.raw_item.output_text}")
|
| 145 |
+
# elif event.type == "run_step_stream_event":
|
| 146 |
+
# print(f"Step Event: {event.step_name}")
|
| 147 |
+
# else:
|
| 148 |
+
# print(f"Other Event: {event.type}")
|
| 149 |
+
|
| 150 |
+
# print("\nSynthesizing final output...\n")
|
| 151 |
+
|
| 152 |
+
# # Run the synthesizer for final cleanup
|
| 153 |
+
# synthesizer_result = await Runner.run(
|
| 154 |
+
# synthesizer_agent, result.to_input_list()
|
| 155 |
+
# )
|
| 156 |
+
|
| 157 |
+
# print(f"\n✅ Final Answer:\n{synthesizer_result.final_output}")
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
user_query = input("Enter the query: ")
|
| 161 |
+
asyncio.run(main(user_query))
|
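The agents-as-tools pattern above generalizes: to add another specialist, define a new Agent that shares the df_agent tool and register it on the orchestrator via as_tool(). The names below are illustrative, not part of the commit.

# Illustrative extension of the orchestrator pattern above.
Report_Writer_Agent = Agent(
    name="Report Writer Agent",
    instructions="You turn dataframe analysis results into a short written report.",
    model="gpt-4o-mini",
    tools=[df_agent],
)
# Then add, inside orchestrator_agent's tools=[...] list above:
#     Report_Writer_Agent.as_tool(
#         tool_name="Report_Writer_Agent",
#         tool_description="Write a short report from dataframe analysis results.",
#     ),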
Routes/main_agent_chat_bot_v2.py
ADDED
|
@@ -0,0 +1,178 @@
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
from fastapi import APIRouter, HTTPException
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from typing import Optional, List, Dict
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from Redis.sessions_new import *
|
| 9 |
+
from .helpers.main_agent_best_as_of_now import main
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
main_chatbot_route_v2 = APIRouter(prefix="/main_chatbot_v2", tags=["main_chatbot_v2"])
|
| 13 |
+
|
| 14 |
+
# ==================== UPDATED PYDANTIC MODELS ====================
|
| 15 |
+
|
| 16 |
+
class Author(BaseModel):
|
| 17 |
+
role: str
|
| 18 |
+
|
| 19 |
+
class ChatRequest(BaseModel):
|
| 20 |
+
user_login_id: str
|
| 21 |
+
session_id: Optional[str] = None
|
| 22 |
+
query: str
|
| 23 |
+
org_id: Optional[str] = None
|
| 24 |
+
metadata: Optional[Dict] = None
|
| 25 |
+
list_of_files_path: Optional[List[str]] = None
|
| 26 |
+
author: Optional[Author] = None
|
| 27 |
+
create_time: Optional[str] = None
|
| 28 |
+
token: Optional[str] = None
|
| 29 |
+
|
| 30 |
+
class ChatResponse(BaseModel):
|
| 31 |
+
session_id: str
|
| 32 |
+
user_message: str
|
| 33 |
+
assistant_response: str
|
| 34 |
+
is_new_session: bool
|
| 35 |
+
session_title: str
|
| 36 |
+
timestamp: str
|
| 37 |
+
files_processed: Optional[List[str]] = None
artifacts_paths: Optional[List[str]] = None  # artifact file paths returned by the /chat endpoint below
|
| 38 |
+
|
| 39 |
+
# ==================== ENDPOINTS ====================
|
| 40 |
+
|
| 41 |
+
@main_chatbot_route_v2.get("/")
|
| 42 |
+
async def root():
|
| 43 |
+
return {
|
| 44 |
+
"message": "AI Data Analysis Chatbot API is running!",
|
| 45 |
+
"version": "2.0.0",
|
| 46 |
+
"features": ["Session Management", "Chat History", "Context Awareness", "Auto Title", "File Processing"]
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
@main_chatbot_route_v2.post("/chat", response_model=ChatResponse)
|
| 50 |
+
async def chat(request: ChatRequest):
|
| 51 |
+
"""
|
| 52 |
+
Main chat endpoint with automatic session management and file processing
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
user_login_id: User identifier
|
| 56 |
+
session_id: Session ID (None for new chat)
|
| 57 |
+
query: User's message
|
| 58 |
+
org_id: Optional organization ID
|
| 59 |
+
metadata: Optional metadata (source, language, etc.)
|
| 60 |
+
list_of_files_path: Optional list of file paths to process
|
| 61 |
+
author: Optional author information with role
|
| 62 |
+
create_time: Optional timestamp of request creation
|
| 63 |
+
token: Optional authentication token
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
ChatResponse with session details and AI response
|
| 67 |
+
"""
|
| 68 |
+
try:
|
| 69 |
+
is_new_session = False
|
| 70 |
+
|
| 71 |
+
# ==================== STEP 1: Session Detection ====================
|
| 72 |
+
if request.session_id is None:
|
| 73 |
+
# NEW CHAT → Create new session
|
| 74 |
+
session_data = create_session(
|
| 75 |
+
user_login_id=request.user_login_id,
|
| 76 |
+
org_id=request.org_id,
|
| 77 |
+
metadata=request.metadata
|
| 78 |
+
)
|
| 79 |
+
session_id = session_data["session_id"]
|
| 80 |
+
is_new_session = True
|
| 81 |
+
conversation_history = []
|
| 82 |
+
else:
|
| 83 |
+
# EXISTING CHAT → Load session and history
|
| 84 |
+
try:
|
| 85 |
+
session_data = get_session(request.user_login_id, request.session_id)
|
| 86 |
+
session_id = request.session_id
|
| 87 |
+
conversation_history = get_message_history(
|
| 88 |
+
request.user_login_id,
|
| 89 |
+
session_id,
|
| 90 |
+
limit=5
|
| 91 |
+
)
|
| 92 |
+
except HTTPException as e:
|
| 93 |
+
if e.status_code == 404:
|
| 94 |
+
# Session expired → Create new one
|
| 95 |
+
session_data = create_session(
|
| 96 |
+
user_login_id=request.user_login_id,
|
| 97 |
+
org_id=request.org_id,
|
| 98 |
+
metadata=request.metadata
|
| 99 |
+
)
|
| 100 |
+
session_id = session_data["session_id"]
|
| 101 |
+
is_new_session = True
|
| 102 |
+
conversation_history = []
|
| 103 |
+
else:
|
| 104 |
+
raise
|
| 105 |
+
|
| 106 |
+
# ==================== STEP 2: Save User Message ====================
|
| 107 |
+
# Include files info in the message if provided
|
| 108 |
+
user_message_content = request.query
|
| 109 |
+
if request.list_of_files_path:
|
| 110 |
+
user_message_content += f"\n[Files: {', '.join(request.list_of_files_path)}]"
|
| 111 |
+
|
| 112 |
+
add_message(
|
| 113 |
+
user_login_id=request.user_login_id,
|
| 114 |
+
session_id=session_id,
|
| 115 |
+
role="user",
|
| 116 |
+
content=user_message_content
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# ==================== STEP 3: Process with AI Agent ====================
|
| 120 |
+
# Format context from conversation history
|
| 121 |
+
context = format_conversation_context(conversation_history) if conversation_history else ""
|
| 122 |
+
|
| 123 |
+
# Prepare additional context for the agent
|
| 124 |
+
agent_context = {
|
| 125 |
+
"query": request.query,
|
| 126 |
+
"files": request.list_of_files_path,
|
| 127 |
+
"conversation_history": context,
|
| 128 |
+
"org_id": request.org_id,
|
| 129 |
+
"metadata": request.metadata,
|
| 130 |
+
"author_role": request.author.role if request.author else None
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
# Call AI agent - modify based on your main() function signature
|
| 134 |
+
# Option 1: If main() accepts multiple parameters
|
| 135 |
+
# result = await main(request.query, files=request.list_of_files_path, context=context)
|
| 136 |
+
|
| 137 |
+
# Option 2: If main() accepts a single query parameter (current)
|
| 138 |
+
print("File paths received in request:", request.list_of_files_path)
|
| 139 |
+
result = await main(request.query,file_paths=request.list_of_files_path)
|
| 140 |
+
print("Result from main():", result)
|
| 141 |
+
|
| 142 |
+
# Option 3: If you want to pass everything
|
| 143 |
+
# result = await main(agent_context)
|
| 144 |
+
|
| 145 |
+
# ==================== STEP 4: Save Assistant Response ====================
|
| 146 |
+
add_message(
|
| 147 |
+
user_login_id=request.user_login_id,
|
| 148 |
+
session_id=session_id,
|
| 149 |
+
role="assistant",
|
| 150 |
+
content=result
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# ==================== STEP 5: Auto-Generate Title ====================
|
| 154 |
+
update_session_title_if_needed(
|
| 155 |
+
request.user_login_id,
|
| 156 |
+
session_id,
|
| 157 |
+
is_new_session
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# Get updated session with title
|
| 161 |
+
session_data = get_session(request.user_login_id, session_id)
|
| 162 |
+
|
| 163 |
+
# ==================== STEP 6: Return Response ====================
|
| 164 |
+
return ChatResponse(
|
| 165 |
+
session_id=session_id,
|
| 166 |
+
user_message=request.query,
|
| 167 |
+
assistant_response=result,
|
| 168 |
+
is_new_session=is_new_session,
|
| 169 |
+
session_title=session_data.get("title", "New Chat"),
|
| 170 |
+
timestamp=datetime.now().isoformat(),
|
| 171 |
+
files_processed=request.list_of_files_path,
|
| 172 |
+
artifacts_paths=[]
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
except HTTPException:
|
| 176 |
+
raise
|
| 177 |
+
except Exception as e:
|
| 178 |
+
raise HTTPException(status_code=500, detail=f"Error in chat: {str(e)}")
|
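For reference, a hedged example of the request/response shape the /chat endpoint above expects. Illustrative only, not part of the commit; it assumes the app mounts main_chatbot_route_v2 and runs locally on port 8000, and the file path shown is hypothetical.

# Illustrative client call for POST /main_chatbot_v2/chat.
import requests

payload = {
    "user_login_id": "user_001",
    "session_id": None,  # None starts a new session; pass an existing id to continue one
    "query": "Summarize the uploaded dataset",
    "list_of_files_path": ["uploads/healthcare-data-30.csv"],  # hypothetical path
}
r = requests.post("http://localhost:8000/main_chatbot_v2/chat", json=payload)
body = r.json()
print(body["session_id"], body["assistant_response"][:200])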
Routes/main_chat_bot.py
ADDED
|
@@ -0,0 +1,588 @@
|
| 1 |
+
# from fastapi import APIRouter, HTTPException
|
| 2 |
+
# from pydantic import BaseModel
|
| 3 |
+
# from typing import Optional, List, Dict
|
| 4 |
+
# import json
|
| 5 |
+
# import redis
|
| 6 |
+
# import os
|
| 7 |
+
# import uuid
|
| 8 |
+
# from datetime import datetime
|
| 9 |
+
# from .helpers.main_agent_helpers import main
|
| 10 |
+
# from dotenv import load_dotenv
|
| 11 |
+
|
| 12 |
+
# load_dotenv()
|
| 13 |
+
|
| 14 |
+
# # Initialize router
|
| 15 |
+
# main_chatbot_route = APIRouter(prefix="/main_chatbot", tags=["main_chatbot"])
|
| 16 |
+
|
| 17 |
+
# # ==================== REDIS CLIENT ====================
|
| 18 |
+
|
| 19 |
+
# def get_redis_client():
|
| 20 |
+
# """Initialize Redis client"""
|
| 21 |
+
# try:
|
| 22 |
+
# REDIS_URL = os.getenv("REDIS_URL")
|
| 23 |
+
# REDIS_HOST = os.getenv("REDIS_HOST", "127.0.0.1")
|
| 24 |
+
# REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
|
| 25 |
+
# REDIS_PASSWORD = os.getenv("REDIS_PASSWORD")
|
| 26 |
+
|
| 27 |
+
# if REDIS_URL:
|
| 28 |
+
# redis_client = redis.from_url(
|
| 29 |
+
# REDIS_URL,
|
| 30 |
+
# decode_responses=True,
|
| 31 |
+
# socket_connect_timeout=5,
|
| 32 |
+
# socket_timeout=5
|
| 33 |
+
# )
|
| 34 |
+
# else:
|
| 35 |
+
# redis_client = redis.StrictRedis(
|
| 36 |
+
# host=REDIS_HOST,
|
| 37 |
+
# port=REDIS_PORT,
|
| 38 |
+
# password=REDIS_PASSWORD,
|
| 39 |
+
# decode_responses=True,
|
| 40 |
+
# socket_connect_timeout=5,
|
| 41 |
+
# socket_timeout=5
|
| 42 |
+
# )
|
| 43 |
+
|
| 44 |
+
# redis_client.ping()
|
| 45 |
+
# print(f"✅ Redis connected for chatbot")
|
| 46 |
+
# return redis_client
|
| 47 |
+
# except Exception as e:
|
| 48 |
+
# print(f"❌ Redis connection failed: {e}")
|
| 49 |
+
# raise HTTPException(status_code=500, detail=f"Redis connection failed: {str(e)}")
|
| 50 |
+
|
| 51 |
+
# redis_client = get_redis_client()
|
| 52 |
+
|
| 53 |
+
# # ==================== PYDANTIC MODELS ====================
|
| 54 |
+
|
| 55 |
+
# class ChatRequest(BaseModel):
|
| 56 |
+
# user_login_id: str
|
| 57 |
+
# session_id: Optional[str] = None # None = NEW CHAT
|
| 58 |
+
# query: str
|
| 59 |
+
# org_id: Optional[str] = None
|
| 60 |
+
# metadata: Optional[Dict] = None
|
| 61 |
+
|
| 62 |
+
# class ChatResponse(BaseModel):
|
| 63 |
+
# session_id: str
|
| 64 |
+
# user_message: str
|
| 65 |
+
# assistant_response: str
|
| 66 |
+
# is_new_session: bool
|
| 67 |
+
# session_title: str
|
| 68 |
+
# timestamp: str
|
| 69 |
+
|
| 70 |
+
# class MessageResponse(BaseModel):
|
| 71 |
+
# message_id: str
|
| 72 |
+
# role: str
|
| 73 |
+
# content: str
|
| 74 |
+
# timestamp: str
|
| 75 |
+
|
| 76 |
+
# class ChatHistoryResponse(BaseModel):
|
| 77 |
+
# session_id: str
|
| 78 |
+
# title: str
|
| 79 |
+
# created_at: str
|
| 80 |
+
# message_count: int
|
| 81 |
+
# messages: List[MessageResponse]
|
| 82 |
+
|
| 83 |
+
# # ==================== SESSION MANAGEMENT FUNCTIONS ====================
|
| 84 |
+
|
| 85 |
+
# def create_session(user_login_id: str, org_id: Optional[str] = None, metadata: Optional[Dict] = None) -> dict:
|
| 86 |
+
# """Create a new chat session"""
|
| 87 |
+
# session_id = str(uuid.uuid4())
|
| 88 |
+
# session_data = {
|
| 89 |
+
# "session_id": session_id,
|
| 90 |
+
# "user_login_id": user_login_id,
|
| 91 |
+
# "org_id": org_id,
|
| 92 |
+
# "created_at": datetime.now().isoformat(),
|
| 93 |
+
# "title": "New Chat",
|
| 94 |
+
# "message_count": 0,
|
| 95 |
+
# "metadata": metadata or {}
|
| 96 |
+
# }
|
| 97 |
+
|
| 98 |
+
# # Store session in Redis with 24 hour TTL
|
| 99 |
+
# # Key pattern: session:{user_login_id}:{session_id}
|
| 100 |
+
# redis_client.setex(
|
| 101 |
+
# f"session:{user_login_id}:{session_id}",
|
| 102 |
+
# 86400, # 24 hours
|
| 103 |
+
# json.dumps(session_data)
|
| 104 |
+
# )
|
| 105 |
+
|
| 106 |
+
# # Initialize empty message history
|
| 107 |
+
# # Key pattern: messages:{user_login_id}:{session_id}
|
| 108 |
+
# redis_client.setex(
|
| 109 |
+
# f"messages:{user_login_id}:{session_id}",
|
| 110 |
+
# 86400, # 24 hours
|
| 111 |
+
# json.dumps([])
|
| 112 |
+
# )
|
| 113 |
+
|
| 114 |
+
# return session_data
|
| 115 |
+
|
| 116 |
+
# def get_session(user_login_id: str, session_id: str) -> dict:
|
| 117 |
+
# """Get session data from Redis"""
|
| 118 |
+
# session_key = f"session:{user_login_id}:{session_id}"
|
| 119 |
+
# session_data = redis_client.get(session_key)
|
| 120 |
+
|
| 121 |
+
# if not session_data:
|
| 122 |
+
# raise HTTPException(status_code=404, detail="Session not found or expired")
|
| 123 |
+
|
| 124 |
+
# return json.loads(session_data)
|
| 125 |
+
|
| 126 |
+
# def add_message(user_login_id: str, session_id: str, role: str, content: str) -> str:
|
| 127 |
+
# """Add message to session chat history"""
|
| 128 |
+
# message_id = str(uuid.uuid4())
|
| 129 |
+
# message_data = {
|
| 130 |
+
# "message_id": message_id,
|
| 131 |
+
# "role": role, # "user" or "assistant"
|
| 132 |
+
# "content": content,
|
| 133 |
+
# "timestamp": datetime.now().isoformat()
|
| 134 |
+
# }
|
| 135 |
+
|
| 136 |
+
# # Get current message history
|
| 137 |
+
# messages_key = f"messages:{user_login_id}:{session_id}"
|
| 138 |
+
# messages_data = redis_client.get(messages_key)
|
| 139 |
+
|
| 140 |
+
# if messages_data:
|
| 141 |
+
# messages = json.loads(messages_data)
|
| 142 |
+
# else:
|
| 143 |
+
# messages = []
|
| 144 |
+
|
| 145 |
+
# # Add new message
|
| 146 |
+
# messages.append(message_data)
|
| 147 |
+
|
| 148 |
+
# # Update messages in Redis
|
| 149 |
+
# redis_client.setex(messages_key, 86400, json.dumps(messages))
|
| 150 |
+
|
| 151 |
+
# # Update session message count
|
| 152 |
+
# session_key = f"session:{user_login_id}:{session_id}"
|
| 153 |
+
# session_data = redis_client.get(session_key)
|
| 154 |
+
# if session_data:
|
| 155 |
+
# session = json.loads(session_data)
|
| 156 |
+
# session["message_count"] = len(messages)
|
| 157 |
+
# redis_client.setex(session_key, 86400, json.dumps(session))
|
| 158 |
+
|
| 159 |
+
# return message_id
|
| 160 |
+
|
| 161 |
+
# def get_message_history(user_login_id: str, session_id: str, limit: Optional[int] = None) -> List[Dict]:
|
| 162 |
+
# """Get message history for a session"""
|
| 163 |
+
# messages_key = f"messages:{user_login_id}:{session_id}"
|
| 164 |
+
# messages_data = redis_client.get(messages_key)
|
| 165 |
+
|
| 166 |
+
# if not messages_data:
|
| 167 |
+
# return []
|
| 168 |
+
|
| 169 |
+
# messages = json.loads(messages_data)
|
| 170 |
+
|
| 171 |
+
# # Return last 'limit' messages if specified
|
| 172 |
+
# if limit:
|
| 173 |
+
# return messages[-limit:] if len(messages) > limit else messages
|
| 174 |
+
|
| 175 |
+
# return messages
|
| 176 |
+
|
| 177 |
+
# def generate_session_title(user_login_id: str, session_id: str) -> str:
|
| 178 |
+
# """Generate a title for the session based on first user message"""
|
| 179 |
+
# try:
|
| 180 |
+
# # Get message history
|
| 181 |
+
# messages = get_message_history(user_login_id, session_id)
|
| 182 |
+
|
| 183 |
+
# if not messages:
|
| 184 |
+
# return "New Chat"
|
| 185 |
+
|
| 186 |
+
# # Get first user message
|
| 187 |
+
# first_user_message = next(
|
| 188 |
+
# (msg["content"] for msg in messages if msg["role"] == "user"),
|
| 189 |
+
# None
|
| 190 |
+
# )
|
| 191 |
+
|
| 192 |
+
# if not first_user_message:
|
| 193 |
+
# return "New Chat"
|
| 194 |
+
|
| 195 |
+
# # Create simple title from first message (first 6 words)
|
| 196 |
+
# words = first_user_message.split()[:6]
|
| 197 |
+
# title = " ".join(words) + ("..." if len(first_user_message.split()) > 6 else "")
|
| 198 |
+
|
| 199 |
+
# # Update session with new title
|
| 200 |
+
# session_key = f"session:{user_login_id}:{session_id}"
|
| 201 |
+
# session_data = redis_client.get(session_key)
|
| 202 |
+
|
| 203 |
+
# if session_data:
|
| 204 |
+
# session = json.loads(session_data)
|
| 205 |
+
# session["title"] = title
|
| 206 |
+
# redis_client.setex(session_key, 86400, json.dumps(session))
|
| 207 |
+
|
| 208 |
+
# return title
|
| 209 |
+
|
| 210 |
+
# except Exception as e:
|
| 211 |
+
# print(f"Error generating session title: {e}")
|
| 212 |
+
# return "New Chat"
|
| 213 |
+
|
| 214 |
+
# def update_session_title_if_needed(user_login_id: str, session_id: str, is_new_session: bool):
|
| 215 |
+
# """Update session title after first message"""
|
| 216 |
+
# try:
|
| 217 |
+
# session = get_session(user_login_id, session_id)
|
| 218 |
+
|
| 219 |
+
# # Only update if it's a new session or title is still "New Chat"
|
| 220 |
+
# if is_new_session or session.get("title") == "New Chat":
|
| 221 |
+
# generate_session_title(user_login_id, session_id)
|
| 222 |
+
|
| 223 |
+
# except Exception as e:
|
| 224 |
+
# print(f"Error updating session title: {e}")
|
| 225 |
+
|
| 226 |
+
# def format_conversation_context(messages: List[Dict], max_messages: int = 5) -> str:
|
| 227 |
+
# """Format conversation history as context for the AI agent"""
|
| 228 |
+
# if not messages:
|
| 229 |
+
# return ""
|
| 230 |
+
|
| 231 |
+
# # Get last N messages for context
|
| 232 |
+
# recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
|
| 233 |
+
|
| 234 |
+
# context_lines = []
|
| 235 |
+
# for msg in recent_messages:
|
| 236 |
+
# context_lines.append(f"{msg['role']}: {msg['content']}")
|
| 237 |
+
|
| 238 |
+
# return "\n".join(context_lines)
|
| 239 |
+
|
| 240 |
+
# # ==================== ENDPOINTS ====================
|
| 241 |
+
|
| 242 |
+
# @main_chatbot_route.get("/")
|
| 243 |
+
# async def root():
|
| 244 |
+
# """Root endpoint"""
|
| 245 |
+
# return {
|
| 246 |
+
# "message": "AI Data Analysis Chatbot API is running!",
|
| 247 |
+
# "version": "2.0.0",
|
| 248 |
+
# "features": [
|
| 249 |
+
# "Session Management",
|
| 250 |
+
# "Auto New Chat Detection",
|
| 251 |
+
# "Chat History",
|
| 252 |
+
# "Context Awareness",
|
| 253 |
+
# "Auto Title Generation"
|
| 254 |
+
# ]
|
| 255 |
+
# }
|
| 256 |
+
|
| 257 |
+
# @main_chatbot_route.post("/chat", response_model=ChatResponse)
|
| 258 |
+
# async def chat(request: ChatRequest):
|
| 259 |
+
# """
|
| 260 |
+
# Main chat endpoint with automatic session management
|
| 261 |
+
|
| 262 |
+
# Flow:
|
| 263 |
+
# 1. If session_id is None → Create new session
|
| 264 |
+
# 2. If session_id exists → Load existing session
|
| 265 |
+
# 3. Save user message
|
| 266 |
+
# 4. Process with AI agent (with conversation context)
|
| 267 |
+
# 5. Save assistant response
|
| 268 |
+
# 6. Auto-generate title if new session
|
| 269 |
+
# 7. Return response
|
| 270 |
+
|
| 271 |
+
# Args:
|
| 272 |
+
# user_login_id: User identifier
|
| 273 |
+
# session_id: Session ID (None for new chat)
|
| 274 |
+
# query: User's message
|
| 275 |
+
# org_id: Optional organization ID
|
| 276 |
+
# metadata: Optional metadata
|
| 277 |
+
|
| 278 |
+
# Returns:
|
| 279 |
+
# ChatResponse with session details and AI response
|
| 280 |
+
# """
|
| 281 |
+
# try:
|
| 282 |
+
# is_new_session = False
|
| 283 |
+
# session_data = None
|
| 284 |
+
|
| 285 |
+
# # ==================== STEP 1: Session Detection ====================
|
| 286 |
+
# if request.session_id is None:
|
| 287 |
+
# # NEW CHAT → Create new session
|
| 288 |
+
# session_data = create_session(
|
| 289 |
+
# user_login_id=request.user_login_id,
|
| 290 |
+
# org_id=request.org_id,
|
| 291 |
+
# metadata=request.metadata
|
| 292 |
+
# )
|
| 293 |
+
# session_id = session_data["session_id"]
|
| 294 |
+
# is_new_session = True
|
| 295 |
+
# conversation_history = []
|
| 296 |
+
|
| 297 |
+
# else:
|
| 298 |
+
# # EXISTING CHAT → Load session and history
|
| 299 |
+
# try:
|
| 300 |
+
# session_data = get_session(request.user_login_id, request.session_id)
|
| 301 |
+
# session_id = request.session_id
|
| 302 |
+
|
| 303 |
+
# # Get conversation history for context
|
| 304 |
+
# conversation_history = get_message_history(
|
| 305 |
+
# request.user_login_id,
|
| 306 |
+
# session_id,
|
| 307 |
+
# limit=5 # Last 5 messages for context
|
| 308 |
+
# )
|
| 309 |
+
|
| 310 |
+
# except HTTPException as e:
|
| 311 |
#             if e.status_code == 404:
#                 # Session expired → Create new one
#                 session_data = create_session(
#                     user_login_id=request.user_login_id,
#                     org_id=request.org_id,
#                     metadata=request.metadata
#                 )
#                 session_id = session_data["session_id"]
#                 is_new_session = True
#                 conversation_history = []
#             else:
#                 raise
#
#         # ==================== STEP 2: Save User Message ====================
#         user_message_id = add_message(
#             user_login_id=request.user_login_id,
#             session_id=session_id,
#             role="user",
#             content=request.query
#         )
#
#         # ==================== STEP 3: Process with AI Agent ====================
#         # Format context from conversation history
#         context = format_conversation_context(conversation_history) if conversation_history else ""
#
#         # Call your AI agent
#         # If your main() function accepts context, modify this:
#         # result = await main(request.query, context=context)
#         result = await main(request.query)
#
#         assistant_response = str(result)
#
#         # ==================== STEP 4: Save Assistant Response ====================
#         assistant_message_id = add_message(
#             user_login_id=request.user_login_id,
#             session_id=session_id,
#             role="assistant",
#             content=assistant_response
#         )
#
#         # ==================== STEP 5: Auto-Generate Title ====================
#         update_session_title_if_needed(
#             request.user_login_id,
#             session_id,
#             is_new_session
#         )
#
#         # Get updated session with title
#         session_data = get_session(request.user_login_id, session_id)
#
#         # ==================== STEP 6: Return Response ====================
#         return ChatResponse(
#             session_id=session_id,
#             user_message=request.query,
#             assistant_response=assistant_response,
#             is_new_session=is_new_session,
#             session_title=session_data.get("title", "New Chat"),
#             timestamp=datetime.now().isoformat()
#         )
#
#     except HTTPException:
#         raise
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=f"Error in chat: {str(e)}")
#
# @main_chatbot_route.get("/history/{user_login_id}/{session_id}", response_model=ChatHistoryResponse)
# async def get_chat_history(user_login_id: str, session_id: str):
#     """
#     Get complete chat history for a session
#
#     Args:
#         user_login_id: User identifier
#         session_id: Session ID
#
#     Returns:
#         ChatHistoryResponse with all messages
#     """
#     try:
#         # Get session data
#         session_data = get_session(user_login_id, session_id)
#
#         # Get all messages
#         messages = get_message_history(user_login_id, session_id)
#
#         # Convert to MessageResponse objects
#         message_responses = [
#             MessageResponse(
#                 message_id=msg["message_id"],
#                 role=msg["role"],
#                 content=msg["content"],
#                 timestamp=msg["timestamp"]
#             )
#             for msg in messages
#         ]
#
#         return ChatHistoryResponse(
#             session_id=session_id,
#             title=session_data.get("title", "New Chat"),
#             created_at=session_data.get("created_at"),
#             message_count=len(messages),
#             messages=message_responses
#         )
#
#     except HTTPException:
#         raise
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=f"Error fetching history: {str(e)}")
#
# @main_chatbot_route.get("/sessions/{user_login_id}")
# async def get_user_sessions(user_login_id: str):
#     """
#     Get all sessions for a user
#
#     Args:
#         user_login_id: User identifier
#
#     Returns:
#         List of all sessions for the user
#     """
#     try:
#         sessions = []
#
#         # Scan for all session keys for this user
#         pattern = f"session:{user_login_id}:*"
#         for key in redis_client.scan_iter(match=pattern):
#             session_data = redis_client.get(key)
#             if session_data:
#                 session = json.loads(session_data)
#                 sessions.append(session)
#
#         # Sort by created_at (most recent first)
#         sessions.sort(key=lambda x: x.get("created_at", ""), reverse=True)
#
#         return {
#             "user_login_id": user_login_id,
#             "total_sessions": len(sessions),
#             "sessions": sessions
#         }
#
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=f"Error fetching sessions: {str(e)}")
#
# @main_chatbot_route.delete("/sessions/{user_login_id}/{session_id}")
# async def delete_session(user_login_id: str, session_id: str):
#     """
#     Delete a session and its messages
#
#     Args:
#         user_login_id: User identifier
#         session_id: Session ID to delete
#
#     Returns:
#         Deletion confirmation
#     """
#     try:
#         # Verify session exists
#         get_session(user_login_id, session_id)
#
#         # Delete session and messages
#         session_key = f"session:{user_login_id}:{session_id}"
#         messages_key = f"messages:{user_login_id}:{session_id}"
#
#         redis_client.delete(session_key)
#         redis_client.delete(messages_key)
#
#         return {
#             "message": f"Session {session_id} deleted successfully",
#             "session_id": session_id,
#             "user_login_id": user_login_id
#         }
#
#     except HTTPException:
#         raise
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=f"Error deleting session: {str(e)}")
#
# @main_chatbot_route.get("/health")
# async def health():
#     """Health check endpoint"""
#     try:
#         redis_client.ping()
#         redis_status = "connected"
#
#         # Count total sessions
#         total_sessions = len(list(redis_client.scan_iter(match="session:*")))
#
#     except:
#         redis_status = "disconnected"
#         total_sessions = 0
#
#     return {
#         "status": "ok",
#         "service": "Chatbot API with Session Management",
#         "redis_status": redis_status,
#         "total_sessions": total_sessions,
#         "session_ttl": "24 hours"
#     }
#
# #============putsession_title===================
# @main_chatbot_route.put("/sessions/{user_login_id}/{session_id}/title")
# async def update_session_title(user_login_id: str, session_id: str, title: str):
#     try:
#         session_key = f"session:{user_login_id}:{session_id}"
#         session_data = redis_client.get(session_key)
#
#         if session_data:
#             session = json.loads(session_data)
#             session["title"] = title
#             redis_client.setex(session_key, 86400, json.dumps(session))
#             return {"message": "Session title updated successfully."}
#         else:
#             raise HTTPException(status_code=404, detail="Session not found.")
#
#     except HTTPException:
#         raise
#     except Exception as e:
#         raise HTTPException(status_code=500, detail=f"Error updating session title: {str(e)}")

#======================main_chatbot=====================
from fastapi import APIRouter, HTTPException
from datetime import datetime
from Redis.sessions_new import *
from .helpers.main_agent_helpers import main

main_chatbot_route = APIRouter(prefix="/main_chatbot", tags=["main_chatbot"])

@main_chatbot_route.get("/")
async def root():
    return {
        "message": "AI Data Analysis Chatbot API is running!",
        "version": "2.0.0",
        "features": ["Session Management", "Chat History", "Context Awareness", "Auto Title"]
    }

@main_chatbot_route.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        is_new_session = False

        if request.session_id is None:
            session_data = create_session(request.user_login_id, request.org_id, request.metadata)
            session_id = session_data["session_id"]
            is_new_session = True
            conversation_history = []
        else:
            try:
                session_data = get_session(request.user_login_id, request.session_id)
                session_id = request.session_id
                conversation_history = get_message_history(request.user_login_id, session_id, limit=5)
            except HTTPException as e:
                if e.status_code == 404:
                    session_data = create_session(request.user_login_id, request.org_id, request.metadata)
                    session_id = session_data["session_id"]
                    is_new_session = True
                    conversation_history = []
                else:
                    raise

        add_message(request.user_login_id, session_id, "user", request.query)

        context = format_conversation_context(conversation_history) if conversation_history else ""
        result = await main(request.query)  # You can modify to pass context if needed

        add_message(request.user_login_id, session_id, "assistant", result)

        update_session_title_if_needed(request.user_login_id, session_id, is_new_session)
        session_data = get_session(request.user_login_id, session_id)

        return ChatResponse(
            session_id=session_id,
            user_message=request.query,
            assistant_response=result,
            is_new_session=is_new_session,
            session_title=session_data.get("title", "New Chat"),
            timestamp=datetime.now().isoformat()
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error in chat: {str(e)}")
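A minimal sketch of calling the /chat route defined above, assuming the router is included in a FastAPI app served at http://localhost:8000 (host, port, and the identifiers are assumptions; the field names follow how the handler reads ChatRequest):

import requests

BASE_URL = "http://localhost:8000"  # assumed host/port for the FastAPI app

# First turn: no session_id, so the server creates a new session
payload = {
    "user_login_id": "user-123",   # placeholder identifier
    "org_id": "org-456",           # placeholder identifier
    "metadata": {},
    "query": "Give me a quick summary of the uploaded dataset",
    "session_id": None,
}
data = requests.post(f"{BASE_URL}/main_chatbot/chat", json=payload).json()
print(data["session_id"], data["assistant_response"])

# Follow-up turn: reuse the returned session_id so the server loads history
payload["session_id"] = data["session_id"]
payload["query"] = "Now break that down by month"
print(requests.post(f"{BASE_URL}/main_chatbot/chat", json=payload).json()["session_title"])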
Routes/main_chat_bot.py_main.txt
ADDED
@@ -0,0 +1,24 @@
from fastapi import FastAPI
from pydantic import BaseModel
from .helpers.main_agent_helpers import main
from fastapi import APIRouter

main_chatbot_route = APIRouter(prefix="/main_chatbot", tags=["main_chatbot"])


# --- Request Model ---
class ChatRequest(BaseModel):
    query: str

# --- Endpoints ---
@main_chatbot_route.get("/")
async def root():
    return {"message": "AI Data Analysis API is running!"}

@main_chatbot_route.post("/chat")
async def chat(request: ChatRequest):
    try:
        result = await main(request.query)
        return {"response": str(result)}  # <-- Wrap in a dictionary and str() to be safe
    except Exception as e:
        return {"error": str(e)}
agent_tools/Hybrid_Rag_agent.py
ADDED
@@ -0,0 +1,93 @@
import asyncio

from agents import Agent, ItemHelpers, MessageOutputItem, Runner, trace

from agents import Agent, Runner, function_tool
import asyncio
import os
import dotenv


from agents import Agent, ModelSettings, function_tool, Runner


@function_tool
def vector_db_agent_tool(query: str) -> str:
    print("vector_db_agent(query)")

@function_tool
def neo4j_agent_tool(query: str) -> str:
    print("neo4j_agent(query)")


# async def main():
#     result = await Runner.run(
#         Dataframe_agent,
#         "so analayse the df and give some insights on it",
#     )
#     print(result)
# if __name__ == "__main__":
#     asyncio.run(main())

#custom agents for the project and why


orchestrator_agent = Agent(
    name="orchestrator_agent",
    instructions=(
        "You are a translation agent. You use the tools given to you to translate."
        "If asked for multiple translations, you call the relevant tools in order."
        "You never translate on your own, you always use the provided tools."
    ),
    tools=[
        vector_db_agent_tool.as_tool(
            tool_name="vector_db_agent",
            tool_description="Analyse the dataframe and give some insights on it and will use for reports and visualizations",
        ),
        neo4j_agent_tool.as_tool(
            tool_name="neo4j_agent",
            tool_description="Analyse the dataframe and give some insights on it and will use for reports and visualizations",
        ),
    ],
)

synthesizer_agent = Agent(
    name="synthesizer_agent",
    instructions="You inspect translations, correct them if needed, and produce a final concatenated response.",
)


async def main():
    msg = input("Hi! What would you like translated, and to which languages? ")

    # Run the entire orchestration in a single trace
    with trace("Orchestrator evaluator"):
        orchestrator_result = await Runner.run(orchestrator_agent, msg)

        for item in orchestrator_result.new_items:
            if isinstance(item, MessageOutputItem):
                text = ItemHelpers.text_message_output(item)
                if text:
                    print(f" - Translation step: {text}")

        synthesizer_result = await Runner.run(
            synthesizer_agent, orchestrator_result.to_input_list()
        )

    print(f"\n\nFinal response:\n{synthesizer_result.final_output}")


if __name__ == "__main__":
    asyncio.run(main())
agent_tools/Reflection_agent.py
ADDED
@@ -0,0 +1,80 @@
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Literal
import os
import dotenv
dotenv.load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

from agents import Agent, ItemHelpers, Runner, TResponseInputItem, trace

"""
This example shows the LLM as a judge pattern. The first agent generates an outline for a story.
The second agent judges the outline and provides feedback. We loop until the judge is satisfied
with the outline.
"""

story_outline_generator = Agent(
    name="story_outline_generator",
    instructions=(
        "You generate a very short story outline based on the user's input. "
        "If there is any feedback provided, use it to improve the outline."
    ),
)


@dataclass
class EvaluationFeedback:
    feedback: str
    score: Literal["pass", "needs_improvement", "fail"]


evaluator = Agent[None](
    name="evaluator",
    instructions=(
        "You evaluate a story outline and decide if it's good enough. "
        "If it's not good enough, you provide feedback on what needs to be improved. "
        "Never give it a pass on the first try. After 5 attempts, you can give it a pass if the story outline is good enough - do not go for perfection"
    ),
    output_type=EvaluationFeedback,
)


async def main() -> None:
    msg = input("What kind of story would you like to hear? ")
    input_items: list[TResponseInputItem] = [{"content": msg, "role": "user"}]

    latest_outline: str | None = None

    # We'll run the entire workflow in a single trace
    with trace("LLM as a judge"):
        while True:
            story_outline_result = await Runner.run(
                story_outline_generator,
                input_items,
            )

            input_items = story_outline_result.to_input_list()
            latest_outline = ItemHelpers.text_message_outputs(story_outline_result.new_items)
            print("Story outline generated")

            evaluator_result = await Runner.run(evaluator, input_items)
            result: EvaluationFeedback = evaluator_result.final_output

            print(f"Evaluator score: {result.score}")

            if result.score == "pass":
                print("Story outline is good enough, exiting.")
                break

            print("Re-running with feedback")

            input_items.append({"content": f"Feedback: {result.feedback}", "role": "user"})

    print(f"Final story outline: {latest_outline}")


if __name__ == "__main__":
    asyncio.run(main())
agent_tools/autovis_tool.py
ADDED
@@ -0,0 +1,124 @@
# from autoviz import AutoViz_Class
# import os
# import pandas as pd
# import uuid

# def run_autoviz(
#     filename,
#     sep=",",
#     depVar="",
#     dfte=None,
#     header=0,
#     verbose=2,
#     lowess=False,
#     chart_format="html",
#     max_rows_analyzed=150000,
#     max_cols_analyzed=30,
#     save_plot_dir=None
# ):
#     # Generate unique directory using uuid
#     vis_in = str(uuid.uuid4())
#     save_plot_dir = f"./{chart_format}_{vis_in}"
#     os.makedirs(save_plot_dir, exist_ok=True)

#     # Read the CSV file to check if the path is correct
#     df = pd.read_csv(filename)
#     # print(df.head())

#     # Run AutoViz
#     AV = AutoViz_Class()
#     print("Running AutoViz...")
#     dft = AV.AutoViz(
#         filename,
#         sep=sep,
#         depVar=depVar,
#         dfte=dfte,
#         header=header,
#         verbose=verbose,
#         lowess=lowess,
#         chart_format=chart_format,
#         max_rows_analyzed=max_rows_analyzed,
#         max_cols_analyzed=max_cols_analyzed,
#         save_plot_dir=save_plot_dir
#     )
#     print(dft)
#     print(f"Plots saved in: {save_plot_dir}")
#     return dft

# # Example usage:
# run_autoviz(
#     filename=r"C:\Users\Dell\Documents\MR-AI\openai_agents\healthcare-data-30.csv"
# )

from fastapi import FastAPI, Request
from pydantic import BaseModel
from autoviz import AutoViz_Class
import os
import pandas as pd
import uuid
from fastapi import APIRouter

Autoviz_router = APIRouter(prefix="/autoviz", tags=["autoviz"])


class AutoVizParams(BaseModel):
    filename: str
    sep: str = ","
    depVar: str = ""
    header: int = 0
    verbose: int = 2
    lowess: bool = False
    chart_format: str = "html"
    max_rows_analyzed: int = 150000
    max_cols_analyzed: int = 30

def run_autoviz(
    filename,
    sep=",",
    depVar="",
    dfte=None,
    header=0,
    verbose=2,
    lowess=False,
    chart_format="html",
    max_rows_analyzed=150000,
    max_cols_analyzed=30,
    save_plot_dir=None
):
    vis_in = str(uuid.uuid4())
    save_plot_dir = f"./{chart_format}_{vis_in}"
    os.makedirs(save_plot_dir, exist_ok=True)
    df = pd.read_csv(filename)
    AV = AutoViz_Class()
    dft = AV.AutoViz(
        filename,
        sep=sep,
        depVar=depVar,
        dfte=dfte,
        header=header,
        verbose=verbose,
        lowess=lowess,
        chart_format=chart_format,
        max_rows_analyzed=max_rows_analyzed,
        max_cols_analyzed=max_cols_analyzed,
        save_plot_dir=save_plot_dir
    )
    return {"message": "AutoViz run complete", "plots_dir": save_plot_dir}

@Autoviz_router.post("/run_autoviz")
async def autoviz_api(params: AutoVizParams):
    result = run_autoviz(
        filename=params.filename,
        sep=params.sep,
        depVar=params.depVar,
        dfte=None,
        header=params.header,
        verbose=params.verbose,
        lowess=params.lowess,
        chart_format=params.chart_format,
        max_rows_analyzed=params.max_rows_analyzed,
        max_cols_analyzed=params.max_cols_analyzed,
        save_plot_dir=None
    )
    return result
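A minimal sketch of invoking the /autoviz/run_autoviz route defined above, assuming Autoviz_router is included in a FastAPI app served at http://localhost:8000 (host, port, and the CSV path are assumptions; the CSV must be readable by the server process):

import requests

params = {
    "filename": "data/sample.csv",  # assumed path on the server
    "chart_format": "html",
}
resp = requests.post("http://localhost:8000/autoviz/run_autoviz", json=params)
print(resp.json())  # e.g. {"message": "AutoViz run complete", "plots_dir": "./html_<uuid>"}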
agent_tools/code_intepreter.py
ADDED
@@ -0,0 +1,67 @@
import asyncio
import os
import dotenv
import pandas as pd

from agents import Agent, CodeInterpreterTool, Runner, trace

dotenv.load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

CSV_PATH = r"C:\Users\Dell\Documents\MR-AI\openai_agents\healthcare-data-30.csv"

async def main():
    if not os.path.exists(CSV_PATH):
        raise FileNotFoundError(f"❌ CSV file not found at {CSV_PATH}")

    df = pd.read_csv(CSV_PATH)
    print(f"📄 Loaded CSV data (first 5 rows):\n{df.head()}\n")

    # Prepare textual context for the model
    df_preview = df.head(10).to_markdown()
    data_context = (
        f"The dataset preview is:\n{df_preview}\n\n"
        f"Column headers: {', '.join(df.columns)}"
    )

    # Define the agent
    agent = Agent(
        name="CSV Code Interpreter",
        model="gpt-4.1",
        instructions=(
            "You are a data analyst who loves working with CSV data. "
            "You can analyze and visualize data from pandas DataFrames using Python code. "
            "Use the provided CSV data to answer questions.\n"
        ),
        tools=[
            CodeInterpreterTool(
                tool_config={"type": "code_interpreter", "container": {"type": "auto"}}
            )
        ],
    )

    query = input("💬 Ask a question about the CSV file: ")

    with trace("CSV code interpreter example"):
        print("\n🔍 Processing query...")

        # ✅ Combine the CSV context + user query in one go
        result = Runner.run_streamed(
            agent,
            f"Dataset context:\n{data_context}\n\nUser query: {query}"
        )

        async for event in result.stream_events():
            if (
                event.type == "run_item_stream_event"
                and event.item.type == "tool_call_item"
                and event.item.raw_item.type == "code_interpreter_call"
            ):
                print(f"\n💻 Code executed:\n```\n{event.item.raw_item.code}\n```\n")
            elif event.type == "run_item_stream_event":
                print(f"Other event: {event.item.type}")

    print(f"\n✅ Final output: {result.final_output}")

if __name__ == "__main__":
    asyncio.run(main())
alembic/README
ADDED
@@ -0,0 +1 @@
Generic single-database configuration.
alembic/__pycache__/env.cpython-312.pyc
ADDED
Binary file (3.15 kB)
alembic/env.py
ADDED
@@ -0,0 +1,94 @@
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
import sys
import os
from pathlib import Path

# Add the project root to Python path
project_root = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(project_root))

from backend.database import Base, DATABASE_URL
from backend import models  # Import all models to register them
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

target_metadata = Base.metadata

# Set the database URL for Alembic
config.set_main_option('sqlalchemy.url', DATABASE_URL)

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.

    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
alembic/script.py.mako
ADDED
@@ -0,0 +1,26 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
alembic/versions/049cf12dc407_removed_data_column_and_added_file_.py
ADDED
@@ -0,0 +1,52 @@
"""removed data column and added file_metadata column in conversation_data

Revision ID: 049cf12dc407
Revises: 519b15d0dca6
Create Date: 2025-10-27 16:27:22.904582

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = '049cf12dc407'
down_revision: Union[str, None] = '519b15d0dca6'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index('idx_convo_id_is_saved', table_name='stored_convoId_data')
    op.drop_table('stored_convoId_data')
    op.add_column('conversation_data', sa.Column('is_saved', sa.Boolean(), nullable=False))
    op.add_column('conversation_data', sa.Column('file_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
    op.add_column('conversation_data', sa.Column('response', postgresql.JSONB(astext_type=sa.Text()), nullable=False))
    op.add_column('conversation_data', sa.Column('user_query', postgresql.JSONB(astext_type=sa.Text()), nullable=False))
    op.drop_column('conversation_data', 'data')
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('conversation_data', sa.Column('data', postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False))
    op.drop_column('conversation_data', 'user_query')
    op.drop_column('conversation_data', 'response')
    op.drop_column('conversation_data', 'file_metadata')
    op.drop_column('conversation_data', 'is_saved')
    op.create_table('stored_convoId_data',
        sa.Column('id', sa.INTEGER(), server_default=sa.text('nextval(\'"stored_convoId_data_id_seq"\'::regclass)'), autoincrement=True, nullable=False),
        sa.Column('convo_id', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.Column('user_query', postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
        sa.Column('response', postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
        sa.Column('file_metadata', postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=True),
        sa.Column('is_saved', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=True),
        sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=True),
        sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=True),
        sa.PrimaryKeyConstraint('id', name='stored_convoId_data_pkey')
    )
    op.create_index('idx_convo_id_is_saved', 'stored_convoId_data', ['convo_id', 'is_saved'], unique=False)
    # ### end Alembic commands ###
alembic/versions/0f27f624c0f9_make_created_by_non_nullable.py
ADDED
@@ -0,0 +1,30 @@
"""make created_by non-nullable

Revision ID: 0f27f624c0f9
Revises: 872b723d49c9
Create Date: 2025-10-26 05:01:28.791953

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '0f27f624c0f9'
down_revision: Union[str, None] = '872b723d49c9'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.alter_column('organizations', 'created_by', existing_type=sa.UUID(), nullable=False)
    op.alter_column('plans', 'created_by', existing_type=sa.UUID(), nullable=False)
    op.alter_column('third_party_providers', 'created_by', existing_type=sa.UUID(), nullable=False)


def downgrade() -> None:
    op.alter_column('organizations', 'created_by', existing_type=sa.UUID(), nullable=True)
    op.alter_column('plans', 'created_by', existing_type=sa.UUID(), nullable=True)
    op.alter_column('third_party_providers', 'created_by', existing_type=sa.UUID(), nullable=True)
alembic/versions/2cb6dd9a9f5b_improve_models_add_indexes_cascades_.py
ADDED
@@ -0,0 +1,242 @@
"""Improve models: add indexes, cascades, timezone-aware timestamps, soft delete, and relationships

Revision ID: 2cb6dd9a9f5b
Revises: c378ad11cd73
Create Date: 2025-10-24 11:36:37.693119

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = '2cb6dd9a9f5b'
down_revision: Union[str, None] = 'c378ad11cd73'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('conversation_data', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
    op.alter_column('conversation_data', 'created_at',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        nullable=False)
    op.add_column('datasets', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
    op.alter_column('datasets', 'submitted_time',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        nullable=False)
    op.alter_column('datasets', 'status',
        existing_type=postgresql.ENUM('pending', 'approved', 'rejected', 'published', name='datasetstatus'),
        nullable=False)
    op.alter_column('datasets', 'approved_at',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        existing_nullable=True)
    op.alter_column('datasets', 'created_at',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        nullable=False)
    op.create_index(op.f('ix_datasets_provider_id'), 'datasets', ['provider_id'], unique=False)
    op.create_index(op.f('ix_datasets_reviewed_by'), 'datasets', ['reviewed_by'], unique=False)
    op.create_index(op.f('ix_datasets_status'), 'datasets', ['status'], unique=False)
    op.drop_constraint('datasets_provider_id_fkey', 'datasets', type_='foreignkey')
    op.drop_constraint('datasets_reviewed_by_fkey', 'datasets', type_='foreignkey')
    op.create_foreign_key(None, 'datasets', 'users', ['reviewed_by'], ['id'], ondelete='SET NULL')
    op.create_foreign_key(None, 'datasets', 'third_party_providers', ['provider_id'], ['id'], ondelete='RESTRICT')
    op.drop_column('datasets', 'is_approved')
    op.add_column('leads', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
    op.alter_column('leads', 'status',
        existing_type=postgresql.ENUM('new', 'contacted', 'closed', name='leadstatus'),
        nullable=False)
    op.alter_column('leads', 'created_at',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        nullable=False)
    op.alter_column('leads', 'updated_at',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        nullable=False)
    op.create_index(op.f('ix_leads_dataset_id'), 'leads', ['dataset_id'], unique=False)
    op.create_index(op.f('ix_leads_organization_id'), 'leads', ['organization_id'], unique=False)
    op.create_index(op.f('ix_leads_provider_id'), 'leads', ['provider_id'], unique=False)
    op.create_index(op.f('ix_leads_status'), 'leads', ['status'], unique=False)
    op.drop_constraint('leads_organization_id_fkey', 'leads', type_='foreignkey')
    op.drop_constraint('leads_provider_id_fkey', 'leads', type_='foreignkey')
    op.drop_constraint('leads_dataset_id_fkey', 'leads', type_='foreignkey')
    op.create_foreign_key(None, 'leads', 'datasets', ['dataset_id'], ['id'], ondelete='RESTRICT')
    op.create_foreign_key(None, 'leads', 'organizations', ['organization_id'], ['id'], ondelete='CASCADE')
    op.create_foreign_key(None, 'leads', 'third_party_providers', ['provider_id'], ['id'], ondelete='RESTRICT')
    op.add_column('organizations', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
    op.alter_column('organizations', 'is_active',
        existing_type=sa.BOOLEAN(),
        nullable=False)
    op.alter_column('organizations', 'created_at',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        nullable=False)
    op.alter_column('organizations', 'updated_at',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        nullable=False)
    op.create_index(op.f('ix_organizations_plan_id'), 'organizations', ['plan_id'], unique=False)
    op.drop_constraint('organizations_plan_id_fkey', 'organizations', type_='foreignkey')
    op.create_foreign_key(None, 'organizations', 'plans', ['plan_id'], ['id'], ondelete='RESTRICT')
    op.add_column('plans', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
    op.alter_column('plans', 'created_at',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        nullable=False)
    op.add_column('third_party_providers', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
    op.alter_column('third_party_providers', 'is_active',
        existing_type=sa.BOOLEAN(),
        nullable=False)
    op.alter_column('third_party_providers', 'created_at',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        nullable=False)
    op.add_column('users', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
    op.alter_column('users', 'password',
        existing_type=sa.VARCHAR(length=255),
        type_=sa.String(length=500),
        existing_nullable=False)
    op.alter_column('users', 'is_active',
        existing_type=sa.BOOLEAN(),
        nullable=False)
    op.alter_column('users', 'last_login',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        existing_nullable=True)
    op.alter_column('users', 'created_at',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        nullable=False)
    op.alter_column('users', 'updated_at',
        existing_type=postgresql.TIMESTAMP(),
        type_=sa.DateTime(timezone=True),
        nullable=False)
    op.create_index(op.f('ix_users_is_active'), 'users', ['is_active'], unique=False)
    op.create_index(op.f('ix_users_organization_id'), 'users', ['organization_id'], unique=False)
    op.create_index(op.f('ix_users_provider_id'), 'users', ['provider_id'], unique=False)
    op.create_index(op.f('ix_users_role'), 'users', ['role'], unique=False)
    op.drop_constraint('users_provider_id_fkey', 'users', type_='foreignkey')
    op.drop_constraint('users_organization_id_fkey', 'users', type_='foreignkey')
    op.create_foreign_key(None, 'users', 'third_party_providers', ['provider_id'], ['id'], ondelete='SET NULL')
    op.create_foreign_key(None, 'users', 'organizations', ['organization_id'], ['id'], ondelete='SET NULL')
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint(None, 'users', type_='foreignkey')
    op.drop_constraint(None, 'users', type_='foreignkey')
    op.create_foreign_key('users_organization_id_fkey', 'users', 'organizations', ['organization_id'], ['id'])
    op.create_foreign_key('users_provider_id_fkey', 'users', 'third_party_providers', ['provider_id'], ['id'])
    op.drop_index(op.f('ix_users_role'), table_name='users')
    op.drop_index(op.f('ix_users_provider_id'), table_name='users')
    op.drop_index(op.f('ix_users_organization_id'), table_name='users')
    op.drop_index(op.f('ix_users_is_active'), table_name='users')
    op.alter_column('users', 'updated_at',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True)
    op.alter_column('users', 'created_at',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True)
    op.alter_column('users', 'last_login',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        existing_nullable=True)
    op.alter_column('users', 'is_active',
        existing_type=sa.BOOLEAN(),
        nullable=True)
    op.alter_column('users', 'password',
        existing_type=sa.String(length=500),
        type_=sa.VARCHAR(length=255),
        existing_nullable=False)
    op.drop_column('users', 'deleted_at')
    op.alter_column('third_party_providers', 'created_at',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True)
    op.alter_column('third_party_providers', 'is_active',
        existing_type=sa.BOOLEAN(),
        nullable=True)
    op.drop_column('third_party_providers', 'deleted_at')
    op.alter_column('plans', 'created_at',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True)
    op.drop_column('plans', 'deleted_at')
    op.drop_constraint(None, 'organizations', type_='foreignkey')
    op.create_foreign_key('organizations_plan_id_fkey', 'organizations', 'plans', ['plan_id'], ['id'])
    op.drop_index(op.f('ix_organizations_plan_id'), table_name='organizations')
    op.alter_column('organizations', 'updated_at',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True)
    op.alter_column('organizations', 'created_at',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True)
    op.alter_column('organizations', 'is_active',
        existing_type=sa.BOOLEAN(),
        nullable=True)
    op.drop_column('organizations', 'deleted_at')
    op.drop_constraint(None, 'leads', type_='foreignkey')
    op.drop_constraint(None, 'leads', type_='foreignkey')
    op.drop_constraint(None, 'leads', type_='foreignkey')
    op.create_foreign_key('leads_dataset_id_fkey', 'leads', 'datasets', ['dataset_id'], ['id'])
    op.create_foreign_key('leads_provider_id_fkey', 'leads', 'third_party_providers', ['provider_id'], ['id'])
    op.create_foreign_key('leads_organization_id_fkey', 'leads', 'organizations', ['organization_id'], ['id'])
    op.drop_index(op.f('ix_leads_status'), table_name='leads')
    op.drop_index(op.f('ix_leads_provider_id'), table_name='leads')
    op.drop_index(op.f('ix_leads_organization_id'), table_name='leads')
    op.drop_index(op.f('ix_leads_dataset_id'), table_name='leads')
    op.alter_column('leads', 'updated_at',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True)
    op.alter_column('leads', 'created_at',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True)
    op.alter_column('leads', 'status',
        existing_type=postgresql.ENUM('new', 'contacted', 'closed', name='leadstatus'),
        nullable=True)
    op.drop_column('leads', 'deleted_at')
    op.add_column('datasets', sa.Column('is_approved', sa.BOOLEAN(), autoincrement=False, nullable=True))
    op.drop_constraint(None, 'datasets', type_='foreignkey')
    op.drop_constraint(None, 'datasets', type_='foreignkey')
    op.create_foreign_key('datasets_reviewed_by_fkey', 'datasets', 'users', ['reviewed_by'], ['id'])
    op.create_foreign_key('datasets_provider_id_fkey', 'datasets', 'third_party_providers', ['provider_id'], ['id'])
    op.drop_index(op.f('ix_datasets_status'), table_name='datasets')
    op.drop_index(op.f('ix_datasets_reviewed_by'), table_name='datasets')
    op.drop_index(op.f('ix_datasets_provider_id'), table_name='datasets')
    op.alter_column('datasets', 'created_at',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True)
    op.alter_column('datasets', 'approved_at',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        existing_nullable=True)
    op.alter_column('datasets', 'status',
        existing_type=postgresql.ENUM('pending', 'approved', 'rejected', 'published', name='datasetstatus'),
        nullable=True)
    op.alter_column('datasets', 'submitted_time',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True)
    op.drop_column('datasets', 'deleted_at')
    op.alter_column('conversation_data', 'created_at',
        existing_type=sa.DateTime(timezone=True),
        type_=postgresql.TIMESTAMP(),
        nullable=True)
    op.drop_column('conversation_data', 'deleted_at')
    # ### end Alembic commands ###
alembic/versions/3b084b14f4b1_add_created_by_field_to_datasets.py
ADDED
@@ -0,0 +1,34 @@
"""add created_by field to datasets

Revision ID: 3b084b14f4b1
Revises: 2cb6dd9a9f5b
Create Date: 2025-10-24 15:41:23.744370

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '3b084b14f4b1'
down_revision: Union[str, None] = '2cb6dd9a9f5b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('datasets', sa.Column('created_by', sa.UUID(), nullable=False))
    op.create_index(op.f('ix_datasets_created_by'), 'datasets', ['created_by'], unique=False)
    op.create_foreign_key(None, 'datasets', 'users', ['created_by'], ['id'], ondelete='CASCADE')
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint(None, 'datasets', type_='foreignkey')
    op.drop_index(op.f('ix_datasets_created_by'), table_name='datasets')
    op.drop_column('datasets', 'created_by')
    # ### end Alembic commands ###
alembic/versions/519b15d0dca6_added_two_columns_in_convo_table.py
ADDED
@@ -0,0 +1,50 @@
"""added two columns in convo table

Revision ID: 519b15d0dca6
Revises: 0f27f624c0f9
Create Date: 2025-10-27 13:07:07.655070

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = '519b15d0dca6'
down_revision: Union[str, None] = '0f27f624c0f9'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index('idx_convo_id_is_saved', table_name='stored_convoId_data')
    op.drop_table('stored_convoId_data')
    op.add_column('conversation_data', sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False))
    op.add_column('conversation_data', sa.Column('user_id', sa.UUID(), nullable=False))
    op.create_index(op.f('ix_conversation_data_user_id'), 'conversation_data', ['user_id'], unique=False)
    op.create_foreign_key(None, 'conversation_data', 'users', ['user_id'], ['id'], ondelete='CASCADE')
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint(None, 'conversation_data', type_='foreignkey')
    op.drop_index(op.f('ix_conversation_data_user_id'), table_name='conversation_data')
    op.drop_column('conversation_data', 'user_id')
    op.drop_column('conversation_data', 'updated_at')
    op.create_table('stored_convoId_data',
        sa.Column('id', sa.INTEGER(), server_default=sa.text('nextval(\'"stored_convoId_data_id_seq"\'::regclass)'), autoincrement=True, nullable=False),
        sa.Column('convo_id', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
        sa.Column('user_query', postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
        sa.Column('response', postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
        sa.Column('file_metadata', postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=True),
        sa.Column('is_saved', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=True),
        sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=True),
        sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=True),
        sa.PrimaryKeyConstraint('id', name='stored_convoId_data_pkey')
    )
    op.create_index('idx_convo_id_is_saved', 'stored_convoId_data', ['convo_id', 'is_saved'], unique=False)
    # ### end Alembic commands ###
alembic/versions/872b723d49c9_added_created_by_fields_in_required_.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""added created_by fields in required tables

Revision ID: 872b723d49c9
Revises: 3b084b14f4b1
Create Date: 2025-10-26 04:45:04.299414

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '872b723d49c9'
down_revision: Union[str, None] = '3b084b14f4b1'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('organizations', sa.Column('created_by', sa.UUID(), nullable=True))
    op.create_index(op.f('ix_organizations_created_by'), 'organizations', ['created_by'], unique=False)
    op.create_foreign_key(None, 'organizations', 'users', ['created_by'], ['id'], ondelete='CASCADE')
    op.add_column('plans', sa.Column('created_by', sa.UUID(), nullable=True))
    op.create_index(op.f('ix_plans_created_by'), 'plans', ['created_by'], unique=False)
    op.create_foreign_key(None, 'plans', 'users', ['created_by'], ['id'], ondelete='CASCADE')
    op.add_column('third_party_providers', sa.Column('created_by', sa.UUID(), nullable=True))
    op.create_index(op.f('ix_third_party_providers_created_by'), 'third_party_providers', ['created_by'], unique=False)
    op.create_foreign_key(None, 'third_party_providers', 'users', ['created_by'], ['id'], ondelete='CASCADE')
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint(None, 'third_party_providers', type_='foreignkey')
    op.drop_index(op.f('ix_third_party_providers_created_by'), table_name='third_party_providers')
    op.drop_column('third_party_providers', 'created_by')
    op.drop_constraint(None, 'plans', type_='foreignkey')
    op.drop_index(op.f('ix_plans_created_by'), table_name='plans')
    op.drop_column('plans', 'created_by')
    op.drop_constraint(None, 'organizations', type_='foreignkey')
    op.drop_index(op.f('ix_organizations_created_by'), table_name='organizations')
    op.drop_column('organizations', 'created_by')
    # ### end Alembic commands ###
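On the model side, the created_by columns added here would correspond to a declaration along these lines. This is only a sketch under assumptions (the real models live elsewhere in the repo and are not shown in this diff; the id column is a placeholder), with table and column names taken from the migration:

# Sketch only: illustrates the column shape implied by revision 872b723d49c9.
import uuid
from sqlalchemy import Column, ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Organization(Base):
    __tablename__ = "organizations"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Nullable FK to users.id, indexed; rows are removed when the referenced user is deleted.
    created_by = Column(UUID(as_uuid=True),
                        ForeignKey("users.id", ondelete="CASCADE"),
                        nullable=True, index=True)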
alembic/versions/__pycache__/049cf12dc407_removed_data_column_and_added_file_.cpython-312.pyc
ADDED
Binary file (4.69 kB)
alembic/versions/__pycache__/0f27f624c0f9_make_created_by_non_nullable.cpython-312.pyc
ADDED
Binary file (1.82 kB)
alembic/versions/__pycache__/2cb6dd9a9f5b_improve_models_add_indexes_cascades_.cpython-312.pyc
ADDED
Binary file (15.9 kB)
alembic/versions/__pycache__/3b084b14f4b1_add_created_by_field_to_datasets.cpython-312.pyc
ADDED
Binary file (1.89 kB)
alembic/versions/__pycache__/519b15d0dca6_added_two_columns_in_convo_table.cpython-312.pyc
ADDED
Binary file (4.29 kB)
alembic/versions/__pycache__/872b723d49c9_added_created_by_fields_in_required_.cpython-312.pyc
ADDED
Binary file (3.23 kB)
alembic/versions/__pycache__/ac5b502d055a_added_userdatasetsmetadata_table.cpython-312.pyc
ADDED
Binary file (7.54 kB)
alembic/versions/__pycache__/c378ad11cd73_initial_migration_capture_current_schema.cpython-312.pyc
ADDED
Binary file (1.02 kB)
alembic/versions/ac5b502d055a_added_userdatasetsmetadata_table.py
ADDED
@@ -0,0 +1,98 @@
"""Added UserDatasetsMetadata Table

Revision ID: ac5b502d055a
Revises: 049cf12dc407
Create Date: 2025-10-29 11:41:42.626176

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = 'ac5b502d055a'
down_revision: Union[str, None] = '049cf12dc407'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('user_datasets_metadata',
    sa.Column('id', sa.UUID(), nullable=False),
    sa.Column('user_id', sa.UUID(), nullable=False),
    sa.Column('user_metadata', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
    sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
    sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
    sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_user_datasets_metadata_user_id'), 'user_datasets_metadata', ['user_id'], unique=False)
    op.drop_index('idx_convo_id_is_saved', table_name='stored_convoId_data')
    op.drop_table('stored_convoId_data')
    op.drop_index('ix_conversation_data_convo_id', table_name='conversation_data')
    op.create_index(op.f('ix_conversation_data_convo_id'), 'conversation_data', ['convo_id'], unique=False)
    op.alter_column('datasets', 'created_by',
               existing_type=sa.UUID(),
               nullable=True)
    op.drop_constraint('datasets_created_by_fkey', 'datasets', type_='foreignkey')
    op.create_foreign_key(None, 'datasets', 'users', ['created_by'], ['id'], ondelete='SET NULL')
    op.alter_column('organizations', 'created_by',
               existing_type=sa.UUID(),
               nullable=True)
    op.drop_constraint('organizations_created_by_fkey', 'organizations', type_='foreignkey')
    op.create_foreign_key(None, 'organizations', 'users', ['created_by'], ['id'], ondelete='SET NULL')
    op.alter_column('plans', 'created_by',
               existing_type=sa.UUID(),
               nullable=True)
    op.drop_constraint('plans_created_by_fkey', 'plans', type_='foreignkey')
    op.create_foreign_key(None, 'plans', 'users', ['created_by'], ['id'], ondelete='SET NULL')
    op.alter_column('third_party_providers', 'created_by',
               existing_type=sa.UUID(),
               nullable=True)
    op.drop_constraint('third_party_providers_created_by_fkey', 'third_party_providers', type_='foreignkey')
    op.create_foreign_key(None, 'third_party_providers', 'users', ['created_by'], ['id'], ondelete='SET NULL')
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint(None, 'third_party_providers', type_='foreignkey')
    op.create_foreign_key('third_party_providers_created_by_fkey', 'third_party_providers', 'users', ['created_by'], ['id'], ondelete='CASCADE')
    op.alter_column('third_party_providers', 'created_by',
               existing_type=sa.UUID(),
               nullable=False)
    op.drop_constraint(None, 'plans', type_='foreignkey')
    op.create_foreign_key('plans_created_by_fkey', 'plans', 'users', ['created_by'], ['id'], ondelete='CASCADE')
    op.alter_column('plans', 'created_by',
               existing_type=sa.UUID(),
               nullable=False)
    op.drop_constraint(None, 'organizations', type_='foreignkey')
    op.create_foreign_key('organizations_created_by_fkey', 'organizations', 'users', ['created_by'], ['id'], ondelete='CASCADE')
    op.alter_column('organizations', 'created_by',
               existing_type=sa.UUID(),
               nullable=False)
    op.drop_constraint(None, 'datasets', type_='foreignkey')
    op.create_foreign_key('datasets_created_by_fkey', 'datasets', 'users', ['created_by'], ['id'], ondelete='CASCADE')
    op.alter_column('datasets', 'created_by',
               existing_type=sa.UUID(),
               nullable=False)
    op.drop_index(op.f('ix_conversation_data_convo_id'), table_name='conversation_data')
    op.create_index('ix_conversation_data_convo_id', 'conversation_data', ['convo_id'], unique=True)
    op.create_table('stored_convoId_data',
    sa.Column('id', sa.INTEGER(), server_default=sa.text('nextval(\'"stored_convoId_data_id_seq"\'::regclass)'), autoincrement=True, nullable=False),
    sa.Column('convo_id', sa.VARCHAR(length=255), autoincrement=False, nullable=False),
    sa.Column('user_query', postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
    sa.Column('response', postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
    sa.Column('file_metadata', postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=True),
    sa.Column('is_saved', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=True),
    sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=True),
    sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP'), autoincrement=False, nullable=True),
    sa.PrimaryKeyConstraint('id', name='stored_convoId_data_pkey')
    )
    op.create_index('idx_convo_id_is_saved', 'stored_convoId_data', ['convo_id', 'is_saved'], unique=False)
    op.drop_index(op.f('ix_user_datasets_metadata_user_id'), table_name='user_datasets_metadata')
    op.drop_table('user_datasets_metadata')
    # ### end Alembic commands ###
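Note the behavioural flip in this revision: the created_by foreign keys on datasets, organizations, plans, and third_party_providers move from ON DELETE CASCADE to ON DELETE SET NULL, so deleting a user now nulls out created_by instead of deleting the dependent rows, which is also why the columns are made nullable. A sketch of the post-migration column shape, with the surrounding model assumed:

# Sketch only; mirrors the schema after revision ac5b502d055a for one affected table.
from sqlalchemy import Column, ForeignKey
from sqlalchemy.dialects.postgresql import UUID

created_by = Column(UUID(as_uuid=True),
                    ForeignKey("users.id", ondelete="SET NULL"),
                    nullable=True)  # must stay nullable so SET NULL can apply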
alembic/versions/c378ad11cd73_initial_migration_capture_current_schema.py
ADDED
@@ -0,0 +1,30 @@
"""Initial migration - capture current schema

Revision ID: c378ad11cd73
Revises:
Create Date: 2025-10-24 11:31:53.592361

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'c378ad11cd73'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    pass
    # ### end Alembic commands ###
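This baseline revision is intentionally empty: it only anchors the revision chain at the point where the schema already existed in the database. The usual workflow with such a baseline (a sketch, assuming the same alembic.ini as above) is to stamp the database at this revision and let later revisions be autogenerated against the models:

# Sketch; the standard Alembic baseline workflow for an existing database.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")
command.stamp(cfg, "c378ad11cd73")  # mark the existing schema as being at the baseline
command.revision(cfg, message="describe the change", autogenerate=True)  # diff models vs. DB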
app.py
ADDED
@@ -0,0 +1,91 @@
import json
import warnings
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

# Import routes
# from routes import collections, ingestion, sessions, chat, system
from Redis.sessions import Redis_session_router
from Redis.sessions_new import redis_session_route_new
from Redis.sessions_old import Redis_session_router_old
from Routes.main_chat_bot import main_chatbot_route
from Routes.generate_report import Report_Generation_Router
from s3.file_insertion_s3 import s3_bucket_router
from backend.main import login_apis_router
from vector_db.retrival_qa_agent import RetrievalQA_router


from Routes.main_agent_chat_bot_v2 import main_chatbot_route_v2
from vector_db import qdrant_crud

from vector_db.qdrant_crud import Qdrant_router
from agent_tools.autovis_tool import Autoviz_router
from main_chat.main_chat import main_chat_router_v3
from DB_store_backup.agentic_context_convoid_management import Db_store_router

# Suppress warnings
warnings.filterwarnings("ignore", message="Qdrant client version.*is incompatible.*")

app = FastAPI(title="Combined AI Agent with Qdrant Collections and Redis Session Management")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

print([route.path for route in app.routes])

# Include routers
app.include_router(main_chatbot_route)
app.include_router(Report_Generation_Router)
app.include_router(Redis_session_router)
app.include_router(redis_session_route_new)
# app.include_router(Redis_session_router_old)
app.include_router(Db_store_router)
app.include_router(s3_bucket_router)
app.include_router(Autoviz_router)
app.include_router(main_chat_router_v3)

app.include_router(Qdrant_router)
app.include_router(RetrievalQA_router)
app.include_router(main_chatbot_route_v2)

# ================================== login and user management routes ==================================
app.include_router(login_apis_router)


# ------------------- MIDDLEWARE -------------------

@app.middleware("http")
async def add_success_flag(request: Request, call_next):
    response = await call_next(request)

    # Only modify JSON responses
    if "application/json" in response.headers.get("content-type", ""):
        try:
            body = b"".join([chunk async for chunk in response.body_iterator])
            data = json.loads(body.decode())

            # Add success flag
            data["success"] = 200 <= response.status_code < 300

            # Build new JSONResponse (auto handles Content-Length)
            response = JSONResponse(
                content=data,
                status_code=response.status_code,
                headers={k: v for k, v in response.headers.items() if k.lower() != "content-length"},
            )
        except Exception:
            # fallback if response is not JSON parseable
            pass
    return response
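For context, the success-flag middleware above can be exercised in isolation. Below is a self-contained sketch with an invented /ping route; it also rebuilds the response from the already-buffered bytes in the non-JSON fallback, since by that point response.body_iterator has been consumed and returning the original response would yield an empty body:

# Stand-alone demo of the success-flag middleware pattern; /ping is invented for the demo.
import json

from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient

demo = FastAPI()

@demo.middleware("http")
async def add_success_flag(request: Request, call_next):
    response = await call_next(request)
    if "application/json" in response.headers.get("content-type", ""):
        # Buffer the streamed body once, then decide how to rewrap it.
        body = b"".join([chunk async for chunk in response.body_iterator])
        headers = {k: v for k, v in response.headers.items() if k.lower() != "content-length"}
        try:
            data = json.loads(body.decode())
            data["success"] = 200 <= response.status_code < 300
            return JSONResponse(content=data, status_code=response.status_code, headers=headers)
        except Exception:
            # Body iterator is already consumed, so return the buffered bytes as-is.
            return Response(content=body, status_code=response.status_code, headers=headers)
    return response

@demo.get("/ping")
def ping():
    return {"message": "pong"}

client = TestClient(demo)
print(client.get("/ping").json())  # {'message': 'pong', 'success': True}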
backend/__init__.py
ADDED
File without changes
backend/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (166 Bytes)