"""FAQ chatbot service for the MLSC Coherence 25 Hackathon.

At import time this module builds a retrieval-augmented-generation pipeline
(Cohere embeddings in an in-memory vector store, answered by a Groq-hosted
chat model) and exposes it through a FastAPI application.

Required environment variables: ``GROQ_API_KEY`` and ``COHERE``.
"""

import getpass
import os
import re

# from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from groq import Groq
from langchain import hub
from langchain.chat_models import init_chat_model
from langchain_cohere import CohereEmbeddings
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from pydantic.main import BaseModel
from typing_extensions import List, TypedDict

# For local development the keys can come from a .env file (``load_dotenv()``)
# or be prompted for interactively, e.g.:
#     if not os.environ.get("GROQ_API_KEY"):
#         os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
# Never print the key values themselves.

# Chat model that writes the answers. os.environ[...] raises KeyError at
# import time if the key is missing.
llm = init_chat_model(
    "qwen-qwq-32b", model_provider="groq", api_key=os.environ["GROQ_API_KEY"]
)

# Embedding model used both to index the chunks and to embed queries.
# (HuggingFaceInferenceAPIEmbeddings with "sentence-transformers/all-MiniLM-L6-v2"
# is the imported alternative backend.)
embeddings = CohereEmbeddings(
    cohere_api_key=os.environ["COHERE"],
    model="embed-english-v3.0",
    user_agent="langchain-cohere-embeddings",
)
vector_store = InMemoryVectorStore(embedding=embeddings)


def _read_text(path: str) -> str:
    """Return the whole contents of the text file at *path*.

    Uses a ``with`` block so the file handle is closed right away instead of
    being left open until garbage collection.
    """
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


# Raw source texts.
# NOTE(review): in this module these are only referenced from commented-out
# code; the live index is built from ``comb.md`` via ``md_loader``. They are
# still read at startup, so the files must exist.
data_1 = _read_text("data_1.txt")
data_2 = _read_text("data_2.txt")
data_3 = _read_text("data_3.txt")
data_4 = _read_text("data_4.txt")
comb = _read_text("comb.txt")

# Loader for the Markdown knowledge base that gets split and indexed.
md_loader = UnstructuredMarkdownLoader("comb.md")
# --- Indexing ----------------------------------------------------------------
# Split the Markdown knowledge base into overlapping chunks and embed them.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
all_splits = text_splitter.split_documents(md_loader.load())
docs = [
    Document(page_content=chunk.page_content, metadata=chunk.metadata)
    for chunk in all_splits
]
_ = vector_store.add_documents(documents=docs)

# NOTE(review): fetched from LangChain Hub over the network at import time, but
# ``prompt`` is not used anywhere in this module -- the hand-written messages
# below are sent instead. Consider removing this call.
prompt = hub.pull("rlm/rag-prompt")

# --- Prompts -----------------------------------------------------------------
system_message = (
    "You are a helpful and professional FAQ chatbot for the MLSC Coherence 25 "
    "Hackathon. Your role is to: "
    "1. Provide accurate and concise answers based on the provided context "
    "2. Be friendly but professional in tone "
    "3. If you don't know the answer, simply say \"I don't have information about that\" "
    "4. Keep responses brief and to the point "
    "5. Focus on providing factual information from the context "
    "6. Never mention \"the provided context\" or similar phrases in your responses "
    "7. Never explain why you don't know something - just state that you don't know "
    "8. \n"
    "Be direct and avoid unnecessary explanations"
)

human_message_template = (
    "Context: {context} Question: {question} "
    "Please provide a clear and concise answer based on the context above."
)


# --- RAG graph: retrieve -> generate -----------------------------------------
class State(TypedDict):
    """State shared by the graph nodes."""

    question: str  # the user's question (graph input)
    context: List[Document]  # chunks filled in by ``retrieve``
    answer: str  # model output filled in by ``generate``


def retrieve(state: State):
    """Look up the stored chunks most similar to ``state["question"]``."""
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}


def generate(state: State):
    """Ask the chat model to answer the question from the retrieved chunks."""
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = [
        SystemMessage(content=system_message),
        HumanMessage(
            content=human_message_template.format(
                context=docs_content, question=state["question"]
            )
        ),
    ]
    # NOTE(review): debug output -- prints the full prompt (including the
    # retrieved context) on every request. Consider switching to logging.
    print(messages)
    response = llm.invoke(messages)
    return {"answer": response.content}


graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()
# Example: graph.invoke({"question": "Who should i contact for help ?"})["answer"]

# --- HTTP API ----------------------------------------------------------------
app = FastAPI()

# NOTE(review): every origin is allowed *with* credentials. That is tolerable
# for a public, cookie-less FAQ endpoint; restrict ``origins`` if auth is added.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["*"],
)


@app.get("/ping")
async def ping():
    """Liveness check."""
    return "Pong!"


class Query(BaseModel):
    """JSON request body accepted by ``/chat``."""

    question: str


# qwen-qwq-32b is a reasoning model whose raw output can carry its chain of
# thought inside <think>...</think>; that must not reach end users. (The
# previous pattern, r'.*?', only ever matched the empty string and so stripped
# nothing.) DOTALL lets the block span lines; the lazy quantifier stops at the
# first closing tag.
_THINK_BLOCK = re.compile(r"<think>.*?</think>", flags=re.DOTALL)


def _clean_answer(answer: str) -> str:
    """Remove any <think>...</think> blocks and surrounding whitespace."""
    return _THINK_BLOCK.sub("", answer).strip()


@app.get("/chat")
@app.post("/chat")
def chat(request: Query):
    """Answer ``request.question`` with the RAG graph.

    Registered for both GET and POST (the same handler used to be duplicated
    under one name). It is a plain ``def`` rather than ``async def`` because
    ``graph.invoke`` blocks on network calls; FastAPI runs sync endpoints in a
    worker thread, so other requests are not stalled meanwhile.

    NOTE(review): the GET variant also expects a JSON request body, which many
    clients and proxies do not send with GET -- POST is the reliable method.
    """
    result = graph.invoke({"question": request.question})
    return {"response": _clean_answer(result["answer"])}


@app.get("/")
async def root():
    """Default landing route."""
    return {"message": "Hello World"}