import gradio as gr
import torch
import fitz  # PyMuPDF
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
import pandas as pd
from aif360.datasets import StandardDataset
from aif360.metrics import BinaryLabelDatasetMetric
import time

# --- Caching (simple global variables for Gradio) ---
_llm = None
_qa_chain = None

# --- Core AI and Data Processing Functions ---
def load_llm():
    """Loads the IBM Granite LLM in 4-bit precision on the GPU."""
    global _llm
    if _llm is None:
        print("Loading LLM for the first time...")
        llm_model_name = "ibm-granite/granite-3.3-8b-instruct"
        if not torch.cuda.is_available():
            raise RuntimeError("ZeroGPU requires a GPU. Please ensure hardware is set correctly.")
        # `load_in_4bit` as a bare from_pretrained kwarg is deprecated; pass a
        # BitsAndBytesConfig instead (assumes bitsandbytes is installed in the Space).
        model = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
            ),
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        # A model dispatched with device_map="auto" must not also be given an
        # explicit `device` in pipeline(); transformers rejects the combination.
        pipe = pipeline(
            "text-generation", model=model, tokenizer=tokenizer,
            max_new_tokens=512, do_sample=True, temperature=0.1
        )
        _llm = HuggingFacePipeline(pipeline=pipe)
    return _llm

def load_and_process_pdf(pdf_path="PMKisanSamanNidhi.PDF"):
    """Loads and processes the PDF into a FAISS vector store."""
    print("Loading and processing PDF...")
    doc = fitz.open(pdf_path)
    # Join pages with a newline so words at page boundaries are not merged.
    text = "\n".join(page.get_text() for page in doc)
    doc.close()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.create_documents([text])
    embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
    vector_db = FAISS.from_documents(docs, embedding_model)
    return vector_db
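
# Note: as_retriever() below uses LangChain's default of the 4 most similar chunks
# per query (version-dependent); if answers miss relevant passages, it can be
# raised, e.g. vector_db.as_retriever(search_kwargs={"k": 6}).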

def create_conversational_chain(llm, vector_db):
    """Creates the LangChain conversational retrieval chain."""
    # The template must contain {context} and {question}; otherwise the retrieved
    # chunks are silently dropped when the prompt is formatted.
    prompt_template = """You are a polite and professional AI assistant for the PM-KISAN scheme... (rest of prompt)

Context: {context}

Question: {question}

Answer:"""
    QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
    chain = ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=vector_db.as_retriever(), memory=memory,
        return_source_documents=True, combine_docs_chain_kwargs={"prompt": QA_PROMPT}
    )
    return chain
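
# ConversationalRetrievalChain first condenses the new question plus the chat
# history into a standalone question, retrieves chunks for it, then answers with
# QA_PROMPT; the memory object keeps the running history under "chat_history".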

def get_qa_chain():
    """Initializes and returns the QA chain."""
    global _qa_chain
    if _qa_chain is None:
        llm = load_llm()
        vector_db = load_and_process_pdf()
        _qa_chain = create_conversational_chain(llm, vector_db)
    return _qa_chain

def run_fairness_audit():
    """Performs the fairness audit and returns a formatted string."""
    df_display = pd.DataFrame({
        'query': ["loan for my farm", "help for my crops", "scheme for women", "grant for female farmer"],
        'gender_text': ['male', 'male', 'female', 'female'],
        'expected_doc': ['doc1', 'doc1', 'doc2', 'doc2']
    })

    def simulate_retriever(query):
        return "doc2" if "women" in query or "female" in query else "doc1"

    df_display['retrieved_doc'] = df_display['query'].apply(simulate_retriever)
    df_display['favorable_outcome'] = (df_display['retrieved_doc'] == df_display['expected_doc']).astype(int)
    df_for_aif = pd.DataFrame()
    df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
    df_for_aif['favorable_outcome'] = df_display['favorable_outcome']
    aif_dataset = StandardDataset(df_for_aif, label_name='favorable_outcome', favorable_classes=[1],
                                  protected_attribute_names=['gender'], privileged_classes=[[1]])
    metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
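    # Statistical Parity Difference: P(favorable | unprivileged) - P(favorable | privileged).
    # 0.0 means both groups receive the favorable outcome at the same rate; negative
    # values indicate the unprivileged group (here, 'female') is disadvantaged.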
    spd = metric.statistical_parity_difference()
    report = f"""
### 🤖 IBM AIF360 - Fairness Audit Results
**Metric: Statistical Parity Difference (SPD):** {spd:.4f}
**Interpretation:** An SPD of 0.0 indicates perfect fairness in this simulation.
---
**Raw Audit Data:**
```
{df_display.to_string()}
```
"""
    return report

# --- Gradio UI ---
def chat_response(message, history):
    """Handles the user's message and streams the bot's response into the chat."""
    qa_chain = get_qa_chain()
    result = qa_chain.invoke({"question": message})
    response = result["answer"]
    # Add sources to the response
    source_docs = result.get("source_documents", [])
    if source_docs:
        response += "\n\n--- \n*Sources used to generate this answer:*"
        for i, doc in enumerate(source_docs):
            cleaned_content = ' '.join(doc.page_content.split())
            response += f"\n\n> **Source {i+1}:** \"{cleaned_content[:150]}...\""
    # gr.Chatbot expects the full message history, not a bare string, so append
    # the new turn and re-yield it with a growing answer for a typing effect.
    history = history + [[message, ""]]
    for i in range(len(response)):
        time.sleep(0.005)
        history[-1][1] = response[:i + 1]
        yield history
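
# Gradio treats generator functions as streaming handlers: each yielded value
# replaces the Chatbot contents, which produces the incremental typing effect.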

# Initialize the AI model on startup
print("Initializing AI Chain...")
get_qa_chain()
print("AI Chain Ready.")

with gr.Blocks(theme=gr.themes.Soft(), title="Sahay AI") as demo:
    gr.Markdown("# 🇮🇳 Chat with Sahay AI 💬")
    gr.Markdown("Your trusted guide to the PM-KISAN scheme, powered by IBM Granite.")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                value=[[None, "Welcome! Ask me a question about the PM-KISAN scheme."]],
                label="Conversation",
                bubble_full_width=False,
                avatar_images=(None, "https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg")
            )
            msg = gr.Textbox(label="Your Question", placeholder="e.g., Who is eligible for this scheme?")
            submit_btn = gr.Button("Send", variant="primary")
        with gr.Column(scale=1):
            gr.Markdown("### Actions & Connect")
            audit_button = gr.Button("Run Fairness Audit")
            audit_report = gr.Markdown(visible=False)
            whatsapp_link = "https://wa.me/15551234567?text=Hello%20Sahay%20AI!"
            gr.Markdown(f"📱 [Chat on WhatsApp]({whatsapp_link})")
            gr.Markdown("⭐ [View Project on GitHub](https://github.com)")

    # Event handlers
    msg.submit(chat_response, [msg, chatbot], chatbot)
    submit_btn.click(chat_response, [msg, chatbot], chatbot)
    msg.submit(lambda: "", None, msg)  # Clear textbox on submit
    submit_btn.click(lambda: "", None, msg)  # Clear textbox on submit

    def show_audit():
        report = run_fairness_audit()
        return gr.update(value=report, visible=True)

    audit_button.click(show_audit, outputs=audit_report)

if __name__ == "__main__":
    demo.launch()