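"""Streamlit UI for AskUSTH: a chatbot that answers admissions and student
questions over uploaded .txt files. Uploaded documents are indexed with a
RAPTOR-style summary tree (see the local Raptor module) into a Chroma vector
store using a Vietnamese bi-encoder, and questions are answered with Google
Gemini; with no documents indexed, it falls back to a plain Gemini chat chain.
"""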
import os

import streamlit as st
from langchain_chroma import Chroma
from langchain_community.document_loaders import TextLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings

import Raptor
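# Raptor is assumed to be a local module implementing the RAPTOR indexing step
# (recursive embedding, clustering, and summarization). From its usage below,
# recursive_embed_cluster_summarize is expected to return a dict keyed by tree
# level, where each value is a tuple whose element at index 1 is a DataFrame
# with a "summaries" column.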

st.title("Chat with AskUSTH")
| if "gemini_api" not in st.session_state: | |
| st.session_state.gemini_api = None | |
| if "rag" not in st.session_state: | |
| st.session_state.rag = None | |
| if "llm" not in st.session_state: | |
| st.session_state.llm = None | |
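
# ChatGoogleGenerativeAI reads the key from the GOOGLE_API_KEY environment
# variable, so the user-supplied key is exported before the model is built.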
def get_chat_google_model(api_key):
    os.environ["GOOGLE_API_KEY"] = api_key
    return ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
    )
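
# Embeddings come from a Vietnamese bi-encoder; it runs on CPU here, and
# normalize_embeddings=False keeps the raw (unnormalized) vectors.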
def get_embedding_model():
    model_name = "bkai-foundation-models/vietnamese-bi-encoder"
    model_kwargs = {"device": "cpu"}
    encode_kwargs = {"normalize_embeddings": False}
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
| if "embd" not in st.session_state: | |
| st.session_state.embd = get_embedding_model() | |
| if "model" not in st.session_state: | |
| st.session_state.model = None | |
| if "save_dir" not in st.session_state: | |
| st.session_state.save_dir = None | |
| if "uploaded_files" not in st.session_state: | |
| st.session_state.uploaded_files = set() | |
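
# Key-entry form. The markdown text (in Vietnamese) tells the user to create a
# Google Gemini API key at the linked page and paste it below.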
def request_api_key():
    st.markdown(
        """
        Để sử dụng Google Gemini, bạn cần cung cấp API key. Tạo key của bạn [tại đây](https://ai.google.dev/gemini-api/docs/get-started/tutorial?lang=python&hl=vi) và dán vào bên dưới.
        """
    )
    key = st.text_input("Key:", "")
    if st.button("Save") and key != "":
        st.session_state.gemini_api = key
        st.rerun()

if st.session_state.gemini_api is None:
    request_api_key()

if st.session_state.gemini_api and st.session_state.model is None:
    st.session_state.model = get_chat_google_model(st.session_state.gemini_api)

if st.session_state.save_dir is None:
    save_dir = "./Documents"
    os.makedirs(save_dir, exist_ok=True)
    st.session_state.save_dir = save_dir

def load_txt(file_path):
    loader = TextLoader(file_path=file_path, encoding="utf-8")
    return loader.load()
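
# Sidebar upload flow (the label "Chọn file txt" means "Choose txt files").
# Each selected file is persisted under st.session_state.save_dir, and the RAG
# chain is invalidated (set to None) whenever the set of uploaded files
# changes, so it is rebuilt below.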
with st.sidebar:
    uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
    if st.session_state.gemini_api:
        if uploaded_files:
            documents = []
            uploaded_file_names = set()
            for uploaded_file in uploaded_files:
                uploaded_file_names.add(uploaded_file.name)
                file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
                if uploaded_file.name not in st.session_state.uploaded_files:
                    with open(file_path, mode="wb") as w:
                        w.write(uploaded_file.getvalue())
                # Load every current file, not only the new ones, so a rebuilt
                # chain indexes the full document set.
                documents.extend(load_txt(file_path))
            if uploaded_file_names != st.session_state.uploaded_files:
                st.session_state.uploaded_files = uploaded_file_names
                st.session_state.rag = None
        else:
            st.session_state.uploaded_files = set()
            st.session_state.rag = None

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
def compute_rag_chain(_model, _embd, docs_texts):
    results = Raptor.recursive_embed_cluster_summarize(_model, _embd, docs_texts, level=1, n_levels=3)
    all_texts = docs_texts.copy()
    for level in sorted(results.keys()):
        summaries = results[level][1]["summaries"].tolist()
        all_texts.extend(summaries)
    vectorstore = Chroma.from_texts(texts=all_texts, embedding=_embd)
    retriever = vectorstore.as_retriever()
| template = """ | |
| Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. \n | |
| Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. \n | |
| Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời.\n | |
| Dưới đây là thông tin liên quan mà bạn cần sử dụng tới:\n | |
| {context}\n | |
| hãy trả lời:\n | |
| {question} | |
| """ | |
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | _model
        | StrOutputParser()
    )
    return rag_chain
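
# Build the RAG chain once per document set, then rerun so the fresh chain is
# used immediately. `documents` is the module-level list filled in the sidebar.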
def load_rag():
    docs_texts = [d.page_content for d in documents]
    st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
    st.rerun()

if st.session_state.uploaded_files and st.session_state.model is not None:
    if st.session_state.rag is None:
        load_rag()
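
# Fallback path: with no documents indexed, chat goes through a plain Gemini
# chain with a short Vietnamese system prompt ("You are an AI assistant
# supporting admissions and students").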
if st.session_state.model is not None:
    if st.session_state.llm is None:
        mess = ChatPromptTemplate.from_messages(
            [
                ("system", "Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên"),
                ("human", "{input}"),
            ]
        )
        st.session_state.llm = mess | st.session_state.model
| if "chat_history" not in st.session_state: | |
| st.session_state.chat_history = [] | |
| for message in st.session_state.chat_history: | |
| with st.chat_message(message["role"]): | |
| st.write(message["content"]) | |
# Chat box (the placeholder "Bạn muốn hỏi gì?" means "What do you want to ask?").
prompt = st.chat_input("Bạn muốn hỏi gì?")
if st.session_state.model is not None and prompt:
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)
    with st.chat_message("assistant"):
        if st.session_state.rag is not None:
            # The RAG chain ends in StrOutputParser, so it returns a string.
            response = st.session_state.rag.invoke(prompt)
        else:
            # The plain chain returns an AIMessage; extract its text content.
            response = st.session_state.llm.invoke({"input": prompt}).content
        st.write(response)
    st.session_state.chat_history.append({"role": "assistant", "content": response})
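
# Run locally with: streamlit run app.py
# (assuming this script is saved as app.py; any filename works as the
# argument to `streamlit run`).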