# sahay-ai / app.py
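"""Sahay AI: a Gradio chatbot for the PM-KISAN scheme.

Pipeline: the scheme PDF is chunked and indexed in a FAISS vector store,
an IBM Granite model answers questions through a LangChain conversational
retrieval chain, and an AIF360 audit reports a fairness metric for a
simulated retrieval step.
"""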
import gradio as gr
import torch
import fitz # PyMuPDF
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
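# Note: recent LangChain releases also ship these Hugging Face classes in the
# separate `langchain-huggingface` package; the community imports used here
# still work but may emit deprecation warnings.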
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
import pandas as pd
from aif360.datasets import StandardDataset
from aif360.metrics import BinaryLabelDatasetMetric
import time
# --- Caching (simple global variables for Gradio) ---
_llm = None
_qa_chain = None
# --- Core AI and Data Processing Functions ---
def load_llm():
    """Loads the IBM Granite LLM in 4-bit precision on the GPU."""
    global _llm
    if _llm is None:
        print("Loading LLM for the first time...")
        llm_model_name = "ibm-granite/granite-3.3-8b-instruct"
        if not torch.cuda.is_available():
            raise RuntimeError("ZeroGPU requires a GPU. Please ensure hardware is set correctly.")
        # Quantization is configured explicitly; the bare `load_in_4bit=True`
        # kwarg is deprecated in recent transformers releases.
        quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
        model = AutoModelForCausalLM.from_pretrained(
            llm_model_name, quantization_config=quant_config, device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
        # No `device=` here: accelerate has already placed the quantized model,
        # and passing a device to the pipeline would raise an error.
        pipe = pipeline(
            "text-generation", model=model, tokenizer=tokenizer,
            max_new_tokens=512, temperature=0.1
        )
        _llm = HuggingFacePipeline(pipeline=pipe)
    return _llm
def load_and_process_pdf(pdf_path="PMKisanSamanNidhi.PDF"):
    """Loads and processes the PDF into a FAISS vector store."""
    print("Loading and processing PDF...")
    doc = fitz.open(pdf_path)
    text = "".join(page.get_text() for page in doc)
    doc.close()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    docs = text_splitter.create_documents([text])
    embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
    vector_db = FAISS.from_documents(docs, embedding_model)
    return vector_db
def create_conversational_chain(llm, vector_db):
    """Creates the LangChain conversational retrieval chain."""
    # The template must contain {context} and {question} placeholders,
    # otherwise the retrieved documents never reach the model.
    prompt_template = """You are a polite and professional AI assistant for the PM-KISAN scheme... (rest of prompt)

Context: {context}

Question: {question}
Answer:"""
    QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
    chain = ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=vector_db.as_retriever(), memory=memory,
        return_source_documents=True, combine_docs_chain_kwargs={"prompt": QA_PROMPT}
    )
    return chain
def get_qa_chain():
    """Initializes and returns the QA chain."""
    global _qa_chain
    if _qa_chain is None:
        llm = load_llm()
        vector_db = load_and_process_pdf()
        _qa_chain = create_conversational_chain(llm, vector_db)
    return _qa_chain
def run_fairness_audit():
    """Performs the fairness audit and returns a formatted string."""
    df_display = pd.DataFrame({
        'query': ["loan for my farm", "help for my crops", "scheme for women", "grant for female farmer"],
        'gender_text': ['male', 'male', 'female', 'female'],
        'expected_doc': ['doc1', 'doc1', 'doc2', 'doc2']
    })

    def simulate_retriever(query):
        return "doc2" if "women" in query or "female" in query else "doc1"

    df_display['retrieved_doc'] = df_display['query'].apply(simulate_retriever)
    df_display['favorable_outcome'] = (df_display['retrieved_doc'] == df_display['expected_doc']).astype(int)
    # AIF360 expects numeric labels and protected attributes.
    df_for_aif = pd.DataFrame()
    df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
    df_for_aif['favorable_outcome'] = df_display['favorable_outcome']
    aif_dataset = StandardDataset(df_for_aif, label_name='favorable_outcome', favorable_classes=[1],
                                  protected_attribute_names=['gender'], privileged_classes=[[1]])
    metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
    # SPD = P(favorable | unprivileged) - P(favorable | privileged); 0.0 means parity.
    spd = metric.statistical_parity_difference()
    report = f"""
### 🤖 IBM AIF360 - Fairness Audit Results
**Metric: Statistical Parity Difference (SPD):** {spd:.4f}
**Interpretation:** An SPD of 0.0 indicates perfect fairness in this simulation.
---
**Raw Audit Data:**
```
{df_display.to_string()}
```
"""
    return report
# --- Gradio UI ---
def chat_response(message, history):
    """Handles the user's message and streams the bot's response."""
    qa_chain = get_qa_chain()
    result = qa_chain.invoke({"question": message})
    response = result["answer"]
    # Append the retrieved source snippets to the response
    source_docs = result.get("source_documents", [])
    if source_docs:
        response += "\n\n--- \n*Sources used to generate this answer:*"
        for i, doc in enumerate(source_docs):
            cleaned_content = ' '.join(doc.page_content.split())
            response += f"\n\n> **Source {i+1}:** \"{cleaned_content[:150]}...\""
    # Stream the reply character by character for a typing effect. The Chatbot
    # output expects the full [user, bot] pair list, not a bare string.
    history = history + [[message, ""]]
    for i in range(len(response)):
        time.sleep(0.005)
        history[-1][1] = response[:i + 1]
        yield history
# Initialize the AI model on startup
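# (Doing this at import time means the model and FAISS index are built once,
# so the first user request does not pay the startup cost.)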
print("Initializing AI Chain...")
get_qa_chain()
print("AI Chain Ready.")
with gr.Blocks(theme=gr.themes.Soft(), title="Sahay AI") as demo:
    gr.Markdown("# 🇮🇳 Chat with Sahay AI 💬")
    gr.Markdown("Your trusted guide to the PM-KISAN scheme, powered by IBM Granite.")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                value=[[None, "Welcome! Ask me a question about the PM-KISAN scheme."]],
                label="Conversation",
                bubble_full_width=False,
                avatar_images=(None, "https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg")
            )
            msg = gr.Textbox(label="Your Question", placeholder="e.g., Who is eligible for this scheme?")
            submit_btn = gr.Button("Send", variant="primary")
        with gr.Column(scale=1):
            gr.Markdown("### Actions & Connect")
            audit_button = gr.Button("Run Fairness Audit")
            audit_report = gr.Markdown(visible=False)
            whatsapp_link = "https://wa.me/15551234567?text=Hello%20Sahay%20AI!"
            gr.Markdown(f"📱 [Chat on WhatsApp]({whatsapp_link})")
            gr.Markdown("⭐ [View Project on GitHub](https://github.com)")
    # Event handlers (must be attached inside the Blocks context)
    msg.submit(chat_response, [msg, chatbot], chatbot)
    submit_btn.click(chat_response, [msg, chatbot], chatbot)
    msg.submit(lambda: "", None, msg)  # Clear textbox on submit
    submit_btn.click(lambda: "", None, msg)  # Clear textbox on click

    def show_audit():
        report = run_fairness_audit()
        return gr.update(value=report, visible=True)

    audit_button.click(show_audit, outputs=audit_report)
if __name__ == "__main__":
    # Queuing is required for the streaming (generator) chat handler on
    # older Gradio versions; on newer ones it is enabled by default anyway.
    demo.queue()
    demo.launch()