|
|
import asyncio |
|
|
import os |
|
|
from agents import Agent, Runner |
|
|
from agents.extensions.memory import RedisSession |
|
|
import dotenv |
|
|
import redis.asyncio as aioredis |
|
|
|
|
|
# Load variables from a local .env file into os.environ (python-dotenv), so
# REDIS_URL can be configured without exporting it in the shell.
dotenv.load_dotenv()

# Redis connection URL used by the session helpers in this module;
# None when REDIS_URL is not set.
Redis_url = os.environ.get("REDIS_URL")
|
|
|
|
|
|
|
|
|
|
|
from transformers import pipeline |
|
|
import re |
|
|
from datetime import datetime |
|
|
import json |
|
|
import requests |
|
|
|
|
|
|
|
|
# Load the summarization model once, at import time. Any failure is reported
# and leaves ``summarizer = None`` so the rest of the module still imports.
try:
    print("π Loading BART model for summarization...")
    summarizer = pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        tokenizer="facebook/bart-large-cnn",
        framework="pt",
    )
except Exception as e:
    print(f"Error loading BART model: {e}")
    summarizer = None

# Only announce readiness when the pipeline actually loaded; previously this
# printed unconditionally (and the literal was split across two lines, which
# is a SyntaxError).
if summarizer is not None:
    print("✅ Summarization pipeline ready!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_sessions(user_login_id: str):
    """
    Get all session IDs for a given user (based on key prefix).

    Lists every Redis key that starts with ``"<user_login_id>:"`` and returns
    each key with that leading prefix removed.

    NOTE(review): one entry is returned per matching Redis *key*. If the
    session store keeps several keys per session (e.g. separate message and
    counter keys), the same session will show up more than once — confirm
    against the RedisSession key layout before relying on ``len()`` of this.

    Args:
        user_login_id: User whose keys are listed; used verbatim as the key
            prefix (glob metacharacters in it are escaped, not interpreted).

    Returns:
        A list of key suffixes (str), in Redis scan order.
    """
    prefix = f"{user_login_id}:"
    # Escape Redis glob metacharacters so an ID such as "a*" cannot match
    # keys belonging to other users.
    escaped_prefix = "".join(
        "\\" + ch if ch in "\\*?[]" else ch for ch in prefix
    )
    pattern = escaped_prefix + "*"

    client = await aioredis.from_url(Redis_url)
    try:
        sessions = []
        # SCAN walks the keyspace incrementally; KEYS blocks the whole server
        # until it has examined every key.
        async for key in client.scan_iter(match=pattern):
            name = key.decode() if isinstance(key, bytes) else key
            # Strip only the leading prefix; str.replace would also remove any
            # later occurrence of it inside the session ID.
            sessions.append(name[len(prefix):])
    finally:
        # Always release the connection, even if the scan fails.
        await client.close()

    return sessions
|
|
|
|
|
|
|
|
async def get_session_history(user_login_id: str, session_id: str):
    """
    Retrieve chat history for a given user's session.

    Opens a ``RedisSession`` for ``session_id`` under the key prefix
    ``"<user_login_id>:"``, checks connectivity, and reads all stored items.

    Args:
        user_login_id: User that owns the session (used as the key prefix).
        session_id: Session identifier passed to ``RedisSession.from_url``.

    Returns:
        On success, a list of ``{"role": ..., "content": ...}`` dicts, one per
        stored item (missing keys default to ``"unknown"`` and ``""``).
        On any failure, a dict ``{"error": "<message>"}``; this function does
        not raise.
    """
    session = None
    try:
        session = RedisSession.from_url(
            session_id,
            url=Redis_url,
            key_prefix=f"{user_login_id}:",
        )

        if not await session.ping():
            raise Exception("Redis connection failed")

        items = await session.get_items()
        history = [
            {"role": msg.get("role", "unknown"), "content": msg.get("content", "")}
            for msg in items
        ]

        # Drop the reference before closing so the error path below never
        # tries to close the same session a second time if close() fails.
        opened, session = session, None
        await opened.close()
        return history

    except Exception as e:
        if session is not None:
            # An earlier step failed after the session was opened: release it
            # (previously leaked). A cleanup failure must not mask the error
            # being reported, so it is deliberately ignored.
            try:
                await session.close()
            except Exception:
                pass
        return {"error": str(e)}
|
|
|
|
|
|
|
|
async def delete_session(user_login_id: str, session_id: str):
    """
    Delete a specific session for a given user.

    Opens a ``RedisSession`` for ``session_id`` under the key prefix
    ``"<user_login_id>:"``, checks connectivity, and calls
    ``clear_session()`` on it. Which Redis keys that removes is determined by
    ``RedisSession`` itself.

    Args:
        user_login_id: User that owns the session (used as the key prefix).
        session_id: Session identifier passed to ``RedisSession.from_url``.

    Returns:
        ``{"status": "success", "message": "Session <id> deleted"}`` on
        success, or ``{"status": "error", "message": "<error>"}`` on any
        failure; this function does not raise.
    """
    session = None
    try:
        session = RedisSession.from_url(
            session_id,
            url=Redis_url,
            key_prefix=f"{user_login_id}:",
        )

        if not await session.ping():
            raise Exception("Redis connection failed")

        await session.clear_session()

        # Drop the reference before closing so the error path below never
        # tries to close the same session a second time if close() fails.
        opened, session = session, None
        await opened.close()
        return {"status": "success", "message": f"Session {session_id} deleted"}

    except Exception as e:
        if session is not None:
            # An earlier step failed after the session was opened: release it
            # (previously leaked). A cleanup failure must not mask the error
            # being reported, so it is deliberately ignored.
            try:
                await session.close()
            except Exception:
                pass
        return {"status": "error", "message": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def main_demo():
    """
    Demonstrate the session helpers end to end against the configured Redis.

    Runs two agent turns into a demo session, then prints the user's session
    list, that session's history, and the result of deleting it.
    """
    user_id = "vatsav_user2"
    session_id = ":uuid_12345"

    print("Creating session...")
    session = RedisSession.from_url(
        session_id,
        url=Redis_url,
        key_prefix=f"{user_id}:",
    )
    try:
        agent = Agent(name="Assistant", instructions="Be concise.")
        await Runner.run(agent, "Hello!", session=session)
        await Runner.run(agent, "How are you?", session=session)
    finally:
        # Release the connection even if an agent run fails.
        await session.close()

    print("\n--- All Sessions ---")
    # Fetch once and reuse (previously queried Redis twice).
    sessions = await get_sessions(user_id)
    print(sessions)
    print("lenth of the sessions: ", len(sessions))

    print("\n--- Session History ---")
    history = await get_session_history(user_id, session_id)
    if isinstance(history, dict):
        # get_session_history reports failures as {"error": ...} instead of
        # raising; iterating that dict below would crash with a TypeError.
        print(f"Could not load history: {history['error']}")
    else:
        print("lenght of the history: ", len(history))
        print(history)
        print("\nHistoryends=======================:")
        for msg in history:
            print(f"{msg['role']}: {msg['content']}")

    print("\n--- Delete Session ---")
    print(await delete_session(user_id, session_id))


if __name__ == "__main__":
    asyncio.run(main_demo())
|
|
|