frozen8569 committed
Commit a36d7cd · verified · 1 Parent(s): b632e3b

Update app.py
Files changed (1): app.py +124 -183
app.py CHANGED
@@ -1,4 +1,4 @@
-import streamlit as st
 import torch
 import fitz  # PyMuPDF
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
@@ -9,220 +9,161 @@ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts import PromptTemplate
-
-# For Fairness Audit
 import pandas as pd
 from aif360.datasets import StandardDataset
 from aif360.metrics import BinaryLabelDatasetMetric

-# --- Page Configuration ---
-st.set_page_config(
-    page_title="Sahay AI 🇮🇳",
-    page_icon="🤖",
-    layout="wide",
-    initial_sidebar_state="expanded"
-)

-# --- Caching for Performance ---
-@st.cache_resource
 def load_llm():
-    """Loads the IBM Granite LLM with quantization for efficiency."""
-    llm_model_name = "ibm-granite/granite-3.3-8b-instruct"
-
-    # Ensure we are on a CUDA-enabled device (GPU)
-    if not torch.cuda.is_available():
-        raise RuntimeError("This application requires a GPU to run. Please ensure the Space is configured with a T4 GPU.")
-
-    model = AutoModelForCausalLM.from_pretrained(
-        llm_model_name,
-        torch_dtype=torch.bfloat16,
-        load_in_4bit=True
-    )
-    tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
-
-    # --- THIS IS THE KEY FIX ---
-    # We explicitly tell the pipeline to use the first GPU (device=0).
-    # This prevents it from falling back to the slow CPU.
-    pipe = pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        max_new_tokens=512,
-        temperature=0.1,
-        device=0  # Force the pipeline to use the GPU
-    )
-    return HuggingFacePipeline(pipeline=pipe)
-
-@st.cache_resource
-def load_and_process_pdf(pdf_path):
-    """Loads, chunks, and embeds the PDF into a FAISS vector store."""
-    try:
-        doc = fitz.open(pdf_path)
-        text = "".join(page.get_text() for page in doc)
-        if not text:
-            st.error("Could not extract text from the PDF.")
-            return None
-    except Exception as e:
-        st.error(f"Error reading PDF file: {e}. Make sure 'PMKisanSamanNidhi.PDF' is uploaded to the Space.")
-        return None
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
     docs = text_splitter.create_documents([text])
-
-    model_name = "ibm-granite/granite-embedding-278m-multilingual"
-    embedding_model = HuggingFaceEmbeddings(model_name=model_name)
-
     vector_db = FAISS.from_documents(docs, embedding_model)
     return vector_db

-# --- Conversational Chain ---
-def create_conversational_chain(_llm, _vector_db):
     """Creates the LangChain conversational retrieval chain."""
-    prompt_template = """You are a polite and professional AI assistant for the PM-KISAN scheme.
-Use the following context to answer the user's question precisely.
-If the question is not related to the provided context, you must state: "I can only answer questions related to the PM-KISAN scheme."
-Do not make up information. Present answers in a clear, easy-to-read format, using bullet points if helpful.
-
-Context: {context}
-Question: {question}
-
-Helpful Answer:"""
-
     QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
-
     chain = ConversationalRetrievalChain.from_llm(
-        llm=_llm,
-        retriever=_vector_db.as_retriever(search_kwargs={'k': 3}),
-        memory=memory,
-        return_source_documents=True,
-        combine_docs_chain_kwargs={"prompt": QA_PROMPT}
     )
     return chain

-# --- IBM AIF360 Fairness Audit ---
-def run_fairness_audit():
-    """Performs and displays a simulated fairness audit."""
-    st.subheader("🤖 IBM AIF360 - Fairness Audit")
-    st.info("""
-    This is a simulation to demonstrate how we can check for bias in our information retriever.
-    We test whether queries related to different demographic groups (e.g., male vs. female farmers)
-    get relevant results at a similar rate. A fair system should provide equally good information to all groups.
-    """)

-    # Step 1: Create the raw data
-    test_data = {
         'query': ["loan for my farm", "help for my crops", "scheme for women", "grant for female farmer"],
-        'gender_text': ['male', 'male', 'female', 'female'],  # Renamed to avoid confusion
         'expected_doc': ['doc1', 'doc1', 'doc2', 'doc2']
-    }
-    df_display = pd.DataFrame(test_data)  # This copy is shown to the user
-
     def simulate_retriever(query):
         return "doc2" if "women" in query or "female" in query else "doc1"
     df_display['retrieved_doc'] = df_display['query'].apply(simulate_retriever)
     df_display['favorable_outcome'] = (df_display['retrieved_doc'] == df_display['expected_doc']).astype(int)
-
-    # --- THIS IS THE FIX ---
-    # Step 2: Create a purely numerical DataFrame for aif360.
-    # The library requires all input columns to be numeric.
     df_for_aif = pd.DataFrame()
     df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
     df_for_aif['favorable_outcome'] = df_display['favorable_outcome']

-    # Step 3: Call aif360 with the numerical DataFrame
-    aif_dataset = StandardDataset(df_for_aif,
-                                  label_name='favorable_outcome',
-                                  favorable_classes=[1],
-                                  protected_attribute_names=['gender'],
-                                  privileged_classes=[[1]])  # Use 1 for 'male'
-
-    # Use the numerical representation for groups
-    metric = BinaryLabelDatasetMetric(aif_dataset,
-                                      unprivileged_groups=[{'gender': 0}],  # female is 0
-                                      privileged_groups=[{'gender': 1}])  # male is 1
     spd = metric.statistical_parity_difference()
-
-    st.markdown("---")
-    col1, col2 = st.columns(2)
-    with col1:
-        st.metric(label="**Metric: Statistical Parity Difference (SPD)**", value=f"{spd:.4f}")
-    with col2:
-        st.success("An SPD of **0.0** indicates perfect fairness in this simulation.")

-    # Show the original, readable data to the user
-    with st.expander("Show Raw Audit Data"):
-        st.dataframe(df_display)
-
-# --- Main Application UI ---
-if __name__ == "__main__":

-    with st.sidebar:
-        st.image("https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg", width=100)
-        st.title("🇮🇳 Sahay AI")
-        st.markdown("### About")
-        st.markdown("An AI assistant for the **PM-KISAN** scheme, built on **IBM's Granite** foundation models.")
-        st.markdown("---")
-
-        st.markdown("### Actions")
-        if st.button("Run Fairness Audit", use_container_width=True):
-            st.session_state.run_audit = True
-        st.markdown("---")
-
-        st.markdown("### Connect")
-        whatsapp_link = "https://wa.me/15551234567?text=Hello%20Sahay%20AI!"
-        st.markdown(f"📱 [Chat on WhatsApp]({whatsapp_link})")
-        st.markdown("⭐ [View Project on GitHub](https://github.com)")
-        st.markdown("---")
-
-    st.header("Chat with Sahay AI 💬")
-    st.markdown("Your trusted guide to the PM-KISAN scheme.")
-
-    if st.session_state.get('run_audit', False):
-        run_fairness_audit()
-        st.session_state.run_audit = False

-    if "messages" not in st.session_state:
-        st.session_state.messages = []
-        st.session_state.messages.append({
-            "role": "assistant",
-            "content": "Welcome! How can I help you understand the PM-KISAN scheme today? You can ask me questions like:\n- What is this scheme about?\n- Who is eligible?\n- *इस योजना के लिए कौन पात्र है?* (Who is eligible for this scheme?)"
-        })
-
-    try:
-        llm = load_llm()
-        vector_db = load_and_process_pdf("PMKisanSamanNidhi.PDF")
-        if vector_db and "qa_chain" not in st.session_state:
-            st.session_state.qa_chain = create_conversational_chain(llm, vector_db)
-    except Exception as e:
-        st.error("Could not initialize the AI model. The service might be temporarily unavailable. Please try again later.")
-        st.stop()
-
-    for message in st.session_state.messages:
-        with st.chat_message(message["role"]):
-            st.markdown(message["content"])
-
-    if prompt := st.chat_input("Ask a question about the PM-KISAN scheme..."):
-        st.session_state.messages.append({"role": "user", "content": prompt})
-        with st.chat_message("user"):
-            st.markdown(prompt)

-        with st.chat_message("assistant"):
-            with st.spinner("🧠 Thinking..."):
-                if "qa_chain" in st.session_state:
-                    result = st.session_state.qa_chain.invoke({"question": prompt})
-                    response = result["answer"]
-                    source_docs = result.get("source_documents", [])
-
-                    if source_docs:
-                        response += "\n\n--- \n*Sources used to generate this answer:*"
-                        for i, doc in enumerate(source_docs):
-                            cleaned_content = ' '.join(doc.page_content.split())
-                            response += f"\n\n> **Source {i+1}:** \"{cleaned_content[:150]}...\""
-
-                    st.markdown(response)
-                else:
-                    response = "Sorry, the application is not properly initialized."
-                    st.error(response)
-
-        st.session_state.messages.append({"role": "assistant", "content": response})
 
+import gradio as gr
 import torch
 import fitz  # PyMuPDF
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
 from langchain.prompts import PromptTemplate
 import pandas as pd
 from aif360.datasets import StandardDataset
 from aif360.metrics import BinaryLabelDatasetMetric
+import time

+# --- Caching (simple global variables for Gradio) ---
+_llm = None
+_qa_chain = None

+# --- Core AI and Data Processing Functions ---
 def load_llm():
+    """Loads the IBM Granite LLM, forcing it to use the GPU."""
+    global _llm
+    if _llm is None:
+        print("Loading LLM for the first time...")
+        llm_model_name = "ibm-granite/granite-3.3-8b-instruct"
+
+        if not torch.cuda.is_available():
+            raise RuntimeError("ZeroGPU requires a GPU. Please ensure the hardware is set correctly.")
+
+        model = AutoModelForCausalLM.from_pretrained(
+            llm_model_name, torch_dtype=torch.bfloat16, load_in_4bit=True
+        )
+        tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
+
+        pipe = pipeline(
+            "text-generation", model=model, tokenizer=tokenizer,
+            max_new_tokens=512, temperature=0.1, device=0
+        )
+        _llm = HuggingFacePipeline(pipeline=pipe)
+    return _llm
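A hedged aside on the 4-bit loading above: recent transformers releases deprecate passing `load_in_4bit=True` directly to `from_pretrained` in favor of an explicit `BitsAndBytesConfig`, and a model dispatched by accelerate generally cannot also be pinned via the pipeline's `device` argument. A minimal sketch of an equivalent loader under those assumptions (same model name as above; the exact flags are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

# Illustrative 4-bit config; bnb_4bit_compute_dtype mirrors the bfloat16 dtype used above.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

model = AutoModelForCausalLM.from_pretrained(
    "ibm-granite/granite-3.3-8b-instruct",
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place the quantized weights on the GPU
)
tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct")

# No explicit device argument: the model is already dispatched by accelerate.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer,
                max_new_tokens=512, temperature=0.1)
```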
+
+def load_and_process_pdf(pdf_path="PMKisanSamanNidhi.PDF"):
+    """Loads and processes the PDF into a FAISS vector store."""
+    print("Loading and processing PDF...")
+    doc = fitz.open(pdf_path)
+    text = "".join(page.get_text() for page in doc)
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
     docs = text_splitter.create_documents([text])
+    embedding_model = HuggingFaceEmbeddings(model_name="ibm-granite/granite-embedding-278m-multilingual")
     vector_db = FAISS.from_documents(docs, embedding_model)
     return vector_db
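LangChain's FAISS wrapper exposes `similarity_search`, which offers a quick sanity check of the index this function builds; a small illustrative sketch (the query string is made up):

```python
# Illustrative retrieval check against the store returned above.
vector_db = load_and_process_pdf()
for doc in vector_db.similarity_search("Who is eligible for PM-KISAN?", k=3):
    print(doc.page_content[:100], "...")
```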

+def create_conversational_chain(llm, vector_db):
     """Creates the LangChain conversational retrieval chain."""
+    prompt_template = """You are a polite and professional AI assistant for the PM-KISAN scheme.
+Use the following context to answer the user's question precisely.
+If the question is not related to the provided context, you must state: "I can only answer questions related to the PM-KISAN scheme."
+Do not make up information. Present answers in a clear, easy-to-read format, using bullet points if helpful.
+
+Context: {context}
+Question: {question}
+
+Helpful Answer:"""
     QA_PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
     chain = ConversationalRetrievalChain.from_llm(
+        llm=llm, retriever=vector_db.as_retriever(), memory=memory,
+        return_source_documents=True, combine_docs_chain_kwargs={"prompt": QA_PROMPT}
     )
     return chain

+def get_qa_chain():
+    """Initializes and returns the QA chain, building it on first use."""
+    global _qa_chain
+    if _qa_chain is None:
+        llm = load_llm()
+        vector_db = load_and_process_pdf()
+        _qa_chain = create_conversational_chain(llm, vector_db)
+    return _qa_chain
 
+def run_fairness_audit():
+    """Performs the fairness audit and returns a formatted string."""
+    df_display = pd.DataFrame({
         'query': ["loan for my farm", "help for my crops", "scheme for women", "grant for female farmer"],
+        'gender_text': ['male', 'male', 'female', 'female'],
         'expected_doc': ['doc1', 'doc1', 'doc2', 'doc2']
+    })
     def simulate_retriever(query):
         return "doc2" if "women" in query or "female" in query else "doc1"
     df_display['retrieved_doc'] = df_display['query'].apply(simulate_retriever)
     df_display['favorable_outcome'] = (df_display['retrieved_doc'] == df_display['expected_doc']).astype(int)
+
     df_for_aif = pd.DataFrame()
     df_for_aif['gender'] = df_display['gender_text'].map({'male': 1, 'female': 0})
     df_for_aif['favorable_outcome'] = df_display['favorable_outcome']

+    aif_dataset = StandardDataset(df_for_aif, label_name='favorable_outcome', favorable_classes=[1],
+                                  protected_attribute_names=['gender'], privileged_classes=[[1]])
+    metric = BinaryLabelDatasetMetric(aif_dataset, unprivileged_groups=[{'gender': 0}], privileged_groups=[{'gender': 1}])
     spd = metric.statistical_parity_difference()

+    report = f"""
+### 🤖 IBM AIF360 - Fairness Audit Results
+**Metric: Statistical Parity Difference (SPD):** {spd:.4f}
+**Interpretation:** An SPD of 0.0 indicates perfect fairness in this simulation.
+---
+**Raw Audit Data:**
+```
+{df_display.to_string()}
+```
+"""
+    return report
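For intuition, the statistical parity difference reported here is P(favorable | unprivileged) - P(favorable | privileged). In the simulated table every query retrieves its expected document, so both group rates are 1.0 and the SPD is 0.0. A hand recomputation with pandas, as a sketch of the same arithmetic:

```python
import pandas as pd

# Recompute SPD by hand for the simulated audit table above.
df = pd.DataFrame({
    "gender": [1, 1, 0, 0],             # 1 = male (privileged), 0 = female (unprivileged)
    "favorable_outcome": [1, 1, 1, 1],  # every query retrieved its expected doc
})
rate_unpriv = df.loc[df["gender"] == 0, "favorable_outcome"].mean()
rate_priv = df.loc[df["gender"] == 1, "favorable_outcome"].mean()
print(rate_unpriv - rate_priv)  # 0.0, matching statistical_parity_difference()
```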
+
+# --- Gradio UI ---
+def chat_response(message, history):
+    """Handles the user's message and streams back the bot's response."""
+    qa_chain = get_qa_chain()
+    result = qa_chain.invoke({"question": message})
+    response = result["answer"]
+
+    # Append sources to the response
+    source_docs = result.get("source_documents", [])
+    if source_docs:
+        response += "\n\n--- \n*Sources used to generate this answer:*"
+        for i, doc in enumerate(source_docs):
+            cleaned_content = ' '.join(doc.page_content.split())
+            response += f"\n\n> **Source {i+1}:** \"{cleaned_content[:150]}...\""
+
+    # Stream the reply character by character. A gr.Chatbot output expects the
+    # full history (a list of [user, bot] pairs), not a bare string.
+    history = history + [[message, ""]]
+    for i in range(len(response)):
+        time.sleep(0.005)
+        history[-1][1] = response[:i + 1]
+        yield history
+
+# Initialize the AI chain on startup
+print("Initializing AI Chain...")
+get_qa_chain()
+print("AI Chain Ready.")
+
+with gr.Blocks(theme=gr.themes.Soft(), title="Sahay AI") as demo:
+    gr.Markdown("# 🇮🇳 Chat with Sahay AI 💬")
+    gr.Markdown("Your trusted guide to the PM-KISAN scheme, powered by IBM Granite.")

+    with gr.Row():
+        with gr.Column(scale=3):
+            chatbot = gr.Chatbot(
+                value=[[None, "Welcome! Ask me a question about the PM-KISAN scheme."]],
+                label="Conversation",
+                bubble_full_width=False,
+                avatar_images=(None, "https://upload.wikimedia.org/wikipedia/commons/5/51/IBM_logo.svg")
+            )
+            msg = gr.Textbox(label="Your Question", placeholder="e.g., Who is eligible for this scheme?")
+            submit_btn = gr.Button("Send", variant="primary")
+
+        with gr.Column(scale=1):
+            gr.Markdown("### Actions & Connect")
+            audit_button = gr.Button("Run Fairness Audit")
+            audit_report = gr.Markdown(visible=False)
+
+            whatsapp_link = "https://wa.me/15551234567?text=Hello%20Sahay%20AI!"
+            gr.Markdown(f"📱 [Chat on WhatsApp]({whatsapp_link})")
+            gr.Markdown("⭐ [View Project on GitHub](https://github.com)")
+
+    # Event handlers
+    msg.submit(chat_response, [msg, chatbot], chatbot)
+    submit_btn.click(chat_response, [msg, chatbot], chatbot)
+    msg.submit(lambda: "", None, msg)  # Clear the textbox on submit
+    submit_btn.click(lambda: "", None, msg)  # Clear the textbox on click
+
+    def show_audit():
+        report = run_fairness_audit()
+        return gr.update(value=report, visible=True)
+    audit_button.click(show_audit, outputs=audit_report)

+if __name__ == "__main__":
+    demo.launch()
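One deployment note: because `chat_response` is a generator, streaming depends on Gradio's request queue. Gradio 4.x enables the queue by default, but on 3.x releases it must be turned on explicitly; a sketch under that assumption:

```python
if __name__ == "__main__":
    demo.queue()   # generator (streaming) handlers require the queue on Gradio 3.x
    demo.launch()
```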