Update chat model initialization in main.py to use "qwen-qwq-32b" and add a new ping function in ping.gs for testing the chatbot API response.
ec59be7
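As a quick smoke test of the change described above (not part of the diff itself), the new ping route and the chat endpoint can be exercised from a separate script. The sketch below assumes the API is served locally on port 8000 and that the routes are mounted at /ping and /chat; adjust the base URL and paths to the actual deployment.

import requests

BASE = "http://localhost:8000"  # assumed local dev server

print(requests.get(f"{BASE}/ping").json())  # expected: "Pong!"

reply = requests.post(f"{BASE}/chat", json={"question": "Who should I contact for help?"})
print(reply.json()["response"])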
import os
import getpass
from groq import Groq
from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain import hub
from langgraph.graph import START, StateGraph
from pydantic import BaseModel
from typing_extensions import List, TypedDict
from langchain_cohere import CohereEmbeddings
import re
# from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
'''
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
'''

# load_dotenv()
# print(f"GROQ_API_KEY: {os.getenv('GROQ_API_KEY')}")
# print(f"HUGGING_FACE_API_KEY: {os.getenv('HUGGING_FACE_API_KEY')}")

llm = init_chat_model("qwen-qwq-32b", model_provider="groq", api_key=os.environ["GROQ_API_KEY"])
'''
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=os.getenv('HUGGING_FACE_API_KEY'),
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=os.getenv('HUGGING_FACE_API_KEY'), model_name="sentence-transformers/all-MiniLM-L6-v2"
)
'''
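# Cohere's embed-english-v3.0 supplies the document and query embeddings; the
# Hugging Face inference setup above is kept, commented out, as an alternative.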
embeddings = CohereEmbeddings(
    cohere_api_key=os.environ['COHERE'],
    model="embed-english-v3.0",
    user_agent="langchain-cohere-embeddings"
)
vector_store = InMemoryVectorStore(embedding=embeddings)

# Raw data files (read here but kept for reference; not indexed directly)
data_1 = open(r'data_1.txt', 'r').read()
data_2 = open(r'data_2.txt', 'r').read()
data_3 = open(r'data_3.txt', 'r').read()
data_4 = open(r'data_4.txt', 'r').read()
comb = open(r'comb.txt', 'r').read()
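# comb.md is split into overlapping ~500-character chunks and indexed in the
# in-memory vector store for similarity search at query time.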
md_loader = UnstructuredMarkdownLoader('comb.md')
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
# all_splits = text_splitter.split_text(data_1 + "\n\n" + data_2 + "\n\n" + data_3 + "\n\n" + data_4)
# all_splits = text_splitter.split_text(comb)
all_splits = text_splitter.split_documents(md_loader.load())
# docs = [Document(page_content=text) for text in all_splits]
docs = [Document(page_content=text.page_content, metadata=text.metadata) for text in all_splits]
_ = vector_store.add_documents(documents=docs)
prompt = hub.pull("rlm/rag-prompt")

# Custom prompt used in place of the hub prompt above
system_message = """You are a helpful and professional FAQ chatbot for the MLSC Coherence 25 Hackathon. Your role is to:
1. Provide accurate and concise answers based on the provided context
2. Be friendly but professional in tone
3. If you don't know the answer, simply say "I don't have information about that"
4. Keep responses brief and to the point
5. Focus on providing factual information from the context
6. Never mention "the provided context" or similar phrases in your responses
7. Never explain why you don't know something - just state that you don't know
8. Be direct and avoid unnecessary explanations"""

human_message_template = """Context: {context}
Question: {question}
Please provide a clear and concise answer based on the context above."""
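# RAG flow: `retrieve` pulls the chunks most relevant to the incoming question
# from the vector store, `generate` formats them into the prompt and calls the
# Groq-hosted model; LangGraph chains the two steps into a sequential graph below.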
class State(TypedDict):
    question: str
    context: List[Document]
    answer: str


def retrieve(state: State):
    retrieved_docs = vector_store.similarity_search(state["question"])
    return {"context": retrieved_docs}


def generate(state: State):
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = [
        SystemMessage(content=system_message),
        HumanMessage(content=human_message_template.format(
            context=docs_content,
            question=state["question"]
        ))
    ]
    print(messages)
    response = llm.invoke(messages)
    return {"answer": response.content}


graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()
'''
response = graph.invoke({"question": "Who should i contact for help ?"})
print(response["answer"])
'''

app = FastAPI()

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    allow_headers=["*"],
)
@app.get("/ping")  # assumed route path
async def ping():
    return "Pong!"


class Query(BaseModel):
    question: str


@app.post("/chat")  # assumed route path
async def chat(request: Query):
    response = graph.invoke({"question": request.question})
    response = response["answer"]
    # Strip the model's <think>...</think> reasoning block from the reply
    response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
    # response = response[4:]
    return {"response": response}
@app.get("/")  # assumed route path
async def root():
    return {"message": "Hello World"}
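The file as shown has no entry point of its own; a minimal sketch for running it locally, assuming uvicorn is installed alongside FastAPI (module name, host, and port are illustrative):

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)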