import os
from dotenv import load_dotenv
import re
import pickle
import faiss
import numpy as np
from typing import List, Dict
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from rank_bm25 import BM25Okapi
import nltk
from nltk.corpus import stopwords
import requests
import json
from openai import OpenAI
import logging
import generate_indexes

# Import order above is kept exactly as the author wrote it.

# Read .env first, then (re)build the on-disk indexes that are loaded below.
load_dotenv()
generate_indexes.main()

# ---------------- Logging Setup ----------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# English stopword list used by the query preprocessing.
nltk.download("stopwords")
STOPWORDS = set(stopwords.words("english"))

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ---------------- Paths & Models ----------------
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
CROSS_ENCODER = "cross-encoder/ms-marco-MiniLM-L-6-v2"

OUT_DIR = "data/index_merged"
FAISS_PATH = os.path.join(OUT_DIR, "faiss_merged.index")
BM25_PATH = os.path.join(OUT_DIR, "bm25_merged.pkl")
META_PATH = os.path.join(OUT_DIR, "meta_merged.pkl")

# Substrings that cause a query to be rejected outright.
BLOCKED_TERMS = [
    "weather",
    "cricket",
    "movie",
    "song",
    "football",
    "holiday",
    "travel",
    "recipe",
    "music",
    "game",
    "sports",
    "politics",
    "election",
]

# Reference phrases a query is compared against for the semantic guardrail.
FINANCE_DOMAINS = [
    "financial reporting",
    "balance sheet",
    "income statement",
    "assets and liabilities",
    "equity",
    "revenue",
    "profit and loss",
    "goodwill impairment",
    "cash flow",
    "dividends",
    "taxation",
    "investment",
    "valuation",
    "capital structure",
    "ownership interests",
    "subsidiaries",
    "shareholders equity",
    "expenses",
    "earnings",
    "debt",
    "amortization",
    "depreciation",
]

# Lower-case spellings of the only company users may ask about.
ALLOWED_COMPANY = ["make my trip", "mmt"]

# crude regex to detect "company-like" words (any capitalized word(s) followed by Ltd, Inc, Company, etc.)
# The name part must be genuinely capitalised; only the legal-form suffix is
# matched case-insensitively (scoped flag). A global IGNORECASE would let
# "[A-Z]" match lower-case letters, so ordinary phrases such as
# "cash and bank" would be flagged as a company mention.
COMPANY_PATTERN = re.compile(
    r"\b([A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*\s+"
    r"(?i:Ltd|Limited|Inc|Corporation|Corp|LLC|Group|Company|Bank))\b"
)

# ---------------- Load Indexes ----------------
logger.info("Loading FAISS, BM25, metadata, and models...")
try:
    faiss_index = faiss.read_index(FAISS_PATH)

    # NOTE(security): pickle.load executes arbitrary code embedded in the
    # file. Only load index files produced by this project.
    with open(BM25_PATH, "rb") as f:
        bm25_obj = pickle.load(f)
    bm25 = bm25_obj["bm25"]

    with open(META_PATH, "rb") as f:
        meta: List[Dict] = pickle.load(f)

    embed_model = SentenceTransformer(EMBED_MODEL)
    reranker = CrossEncoder(CROSS_ENCODER)

    api_key = os.getenv("HF_API_KEY")
    if not api_key:
        logger.error("HF_API_KEY environment variable not set. Please check your .env file or environment.")
        raise ValueError("HF_API_KEY environment variable not set.")

    # OpenAI-compatible client pointed at the Hugging Face router.
    client = OpenAI(
        base_url="https://router.huggingface.co/v1",
        api_key=api_key,
    )
except Exception as e:
    logger.error(f"Error loading models or indexes: {e}")
    raise

# ---------------- Hugging Face Mistral API ----------------
# NOTE(security): a literal access token used to be kept here in a
# commented-out line. It has been removed; that token must be treated as
# leaked and revoked. Credentials are read from HF_API_KEY only.
HF_MODEL = "mistralai/Mistral-7B-Instruct-v0.2:featherless-ai"


def get_mistral_answer(query: str, context: str) -> str:
    """
    Ask the hosted Mistral 7B Instruct model to answer *query* from *context*.

    Args:
        query: The user's question.
        context: Retrieved passage text placed in front of the question.

    Returns:
        The model's reply as a string. If the API call fails, the error is
        logged and a string starting with "Error fetching answer from LLM:"
        is returned instead of raising.
    """
    prompt = f"Context:\n{context}\n\nQuestion: {query}\nAnswer in full sentences using context."
    try:
        logger.info(f"Calling Mistral API for query: {query}")
        completion = client.chat.completions.create(
            model=HF_MODEL,
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
        )
        answer = str(completion.choices[0].message.content)
        logger.info(f"Mistral API response: {answer}")
        return answer
    except Exception as e:
        logger.error(f"Error in Mistral API call: {e}")
        return f"Error fetching answer from LLM: {e}"


# ---------------- Guardrails ----------------
# Embeddings of the finance reference phrases, computed once at import time.
finance_embeds = embed_model.encode(FINANCE_DOMAINS, convert_to_tensor=True)


def _is_allowed_company(name: str) -> bool:
    """Return True if *name* contains an ALLOWED_COMPANY entry, ignoring case and whitespace."""
    # Whitespace is dropped on both sides so that "MakeMyTrip Limited" and
    # "Make My Trip Ltd" both match the allow-list entry "make my trip".
    squashed = "".join(name.lower().split())
    return any("".join(allowed.lower().split()) in squashed for allowed in ALLOWED_COMPANY)


def validate_query(query: str, threshold: float = 0.5) -> bool:
    """
    Input guardrail: decide whether *query* may enter the RAG pipeline.

    Checks, in order:
      1. Reject if any BLOCKED_TERMS entry is a substring of the lower-cased query.
      2. If COMPANY_PATTERN finds company-like names, reject unless at least
         one of them matches ALLOWED_COMPANY.
      3. Accept only if the highest cosine similarity between the query
         embedding and the FINANCE_DOMAINS embeddings exceeds *threshold*.

    Args:
        query: Raw user query.
        threshold: Minimum (exclusive) similarity needed to accept.

    Returns:
        True if the query is accepted, False otherwise.
    """
    q_lower = query.lower()

    # Blocklist check
    if any(bad in q_lower for bad in BLOCKED_TERMS):
        logger.info("[Guardrail] Rejected by blocklist.")
        return False

    # Company mentions: the original test `ALLOWED_COMPANY in c.lower()`
    # checked a list against a string and raised TypeError on every hit.
    companies_found = COMPANY_PATTERN.findall(query)
    if companies_found and not any(_is_allowed_company(c) for c in companies_found):
        logger.info(f"[Guardrail] Rejected: company mention {companies_found}, not {ALLOWED_COMPANY}.")
        return False

    # Semantic similarity check with financial domain
    q_emb = embed_model.encode(query, convert_to_tensor=True)
    sim_scores = util.cos_sim(q_emb, finance_embeds)
    max_score = float(sim_scores.max())

    if max_score > threshold:
        logger.info(f"[Guardrail] Accepted (semantic match {max_score:.2f})")
        return True

    logger.info(f"[Guardrail] Rejected (low semantic score {max_score:.2f})")
    return False


# ------------------- Output Guardrail ------------------
def validate_output(answer: str, context_docs: List[Dict]) -> str:
    """
    Output guardrail: pass *answer* through only if it appears in the context.

    The test is a case-insensitive substring match of the whole answer
    against the space-joined "content" of *context_docs*.

    Args:
        answer: Candidate answer text.
        context_docs: Retrieved documents; each must have a "content" key.

    Returns:
        *answer* unchanged when it is found in the context, otherwise a fixed
        "could not be verified" message.
    """
    combined_context = " ".join(doc["content"].lower() for doc in context_docs)
    if answer.lower() in combined_context:
        return answer
    return "The information could not be verified in the financial statement attached."
# ---------------- Preprocess ----------------
def preprocess_query(query: str, remove_stopwords: bool = True) -> str:
    """
    Normalise a query for retrieval.

    Lower-cases the text, replaces every character other than a-z, 0-9 and
    whitespace with a space, and optionally drops tokens found in STOPWORDS.

    Args:
        query: Raw query text.
        remove_stopwords: When True, tokens present in STOPWORDS are removed.

    Returns:
        The remaining tokens joined by single spaces.
    """
    query = query.lower()
    query = re.sub(r"[^a-z0-9\s]", " ", query)
    tokens = query.split()
    if remove_stopwords:
        tokens = [t for t in tokens if t not in STOPWORDS]
    return " ".join(tokens)


# ---------------- Hybrid Retrieval ----------------
def hybrid_candidates(query: str, candidate_k: int = 50, alpha: float = 0.5) -> List[int]:
    """
    Retrieve candidate chunk ids by blending dense (FAISS) and sparse (BM25) scores.

    Both retrievers contribute their top ``max(candidate_k, 50)`` ids. The
    union is scored as ``alpha * norm(faiss) + (1 - alpha) * norm(bm25)``
    with min-max normalisation, and the best *candidate_k* ids are returned.

    Args:
        query: Raw user query.
        candidate_k: Number of ids to return.
        alpha: Weight of the dense score; ``1 - alpha`` weights BM25.

    Returns:
        Chunk ids ordered from highest to lowest combined score.
    """
    top_n = max(candidate_k, 50)

    q_emb = embed_model.encode(
        [preprocess_query(query, remove_stopwords=False)],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    faiss_scores, faiss_ids = faiss_index.search(q_emb, top_n)
    faiss_ids = faiss_ids[0]
    faiss_scores = faiss_scores[0]

    # FAISS pads the result with id -1 when the index holds fewer than top_n
    # vectors. Left in, -1 would silently address the LAST chunk.
    found = faiss_ids >= 0
    faiss_ids = faiss_ids[found]
    faiss_scores = faiss_scores[found]

    tokenized_query = preprocess_query(query).split()
    bm25_scores = bm25.get_scores(tokenized_query)
    bm25_top = np.argsort(bm25_scores)[::-1][:top_n]

    union_ids = np.unique(np.concatenate([bm25_top, faiss_ids]))
    if union_ids.size == 0:
        return []

    faiss_score_map = {int(i): float(s) for i, s in zip(faiss_ids, faiss_scores)}

    # Ids found only by BM25 have no dense score. Mark them with NaN and fill
    # with the lowest score FAISS actually returned. The earlier -1.0 sentinel
    # was itself included in the minimum, so the fill was a no-op, and real
    # negative similarities were overwritten.
    # NOTE(review): assumes higher FAISS score = more similar (inner-product
    # index over normalised vectors) - confirm against the index builder.
    f_arr = np.array([faiss_score_map.get(int(i), np.nan) for i in union_ids], dtype=float)
    missing = np.isnan(f_arr)
    if missing.any():
        f_arr[missing] = f_arr[~missing].min() if (~missing).any() else 0.0

    b_arr = np.array([bm25_scores[int(i)] for i in union_ids], dtype=float)

    def _norm(x):
        # Min-max scale to [0, 1]; the epsilon avoids 0/0 on constant input.
        return (x - np.min(x)) / (np.ptp(x) + 1e-9)

    combined = alpha * _norm(f_arr) + (1 - alpha) * _norm(b_arr)
    order = np.argsort(combined)[::-1]
    return union_ids[order][:candidate_k].tolist()


# ---------------- Cross-Encoder Rerank ----------------
def rerank_cross_encoder(query: str, cand_ids: List[int], top_k: int = 10) -> List[Dict]:
    """
    Re-score candidate chunks with the cross-encoder and keep the best *top_k*.

    Args:
        query: Raw user query.
        cand_ids: Chunk ids indexing into ``meta``.
        top_k: Maximum number of chunks to return.

    Returns:
        Dicts with keys "id", "chunk_size", "content" and "rerank_score",
        ordered from highest to lowest score. Empty if *cand_ids* is empty.
    """
    if not cand_ids:
        # Nothing to score; do not call the model with an empty batch.
        return []

    pairs = [(query, meta[i]["content"]) for i in cand_ids]
    scores = reranker.predict(pairs)
    order = np.argsort(scores)[::-1][:top_k]
    return [
        {
            "id": cand_ids[i],
            "chunk_size": meta[cand_ids[i]]["chunk_size"],
            "content": meta[cand_ids[i]]["content"],
            "rerank_score": float(scores[i]),
        }
        for i in order
    ]


# ---------------- Extract Numeric ----------------
# A table cell that is just a number: optional parentheses or minus sign,
# digits with thousands separators, optional decimal part.
_NUMERIC_CELL = re.compile(r"^\(?-?\d[\d,]*(?:\.\d+)?\)?$")


def extract_value_for_year_and_concept(year: str, concept: str, context_docs: List[Dict]) -> str:
    """
    Look up a figure for *concept* in the *year* column of a text table.

    For each document, the first digit-bearing line that mentions a 20xx year
    is taken as the header, and the order of years in it defines the column
    order. The first later line containing *concept* (case-insensitive) is
    split on tabs or runs of two or more spaces, and the n-th numeric cell is
    returned for the n-th header year.

    Args:
        year: Four-digit year to look up.
        concept: Text that must occur in the row.
        context_docs: Documents whose "content" may hold a table.

    Returns:
        The matching cell with thousands separators removed, or "" if no
        document yields a value.
    """
    target_year = str(year)
    concept_lower = concept.lower()

    for doc in context_docs:
        text = doc.get("content", "")
        lines = [
            line for line in text.split("\n")
            if line.strip() and any(c.isdigit() for c in line)
        ]

        header_idx = None
        year_to_col = {}
        for idx, line in enumerate(lines):
            years_in_line = re.findall(r"20\d{2}", line)
            if years_in_line:
                for col_idx, y in enumerate(years_in_line):
                    year_to_col[y] = col_idx
                header_idx = idx
                break

        if target_year not in year_to_col or header_idx is None:
            continue

        col_idx = year_to_col[target_year]
        for line in lines[header_idx + 1:]:
            if concept_lower not in line.lower():
                continue
            cells = [cell.strip() for cell in re.split(r"\s{2,}|\t", line)]
            # Count numeric cells only. Indexing the raw cells made the first
            # year resolve to the row label instead of its value.
            values = [cell for cell in cells if _NUMERIC_CELL.match(cell)]
            if col_idx < len(values):
                return values[col_idx].replace(",", "")

    return ""


# ---------------- RAG Pipeline ----------------
def generate_answer(query: str, top_k: int = 5, candidate_k: int = 50, alpha: float = 0.6):
    """
    Run the full pipeline: guardrail, hybrid retrieval, rerank, then answer.

    If the query names a 20xx year and a table value can be extracted for it,
    that value is the answer. Otherwise the reranked chunks are sent to the
    LLM as context.

    Args:
        query: Raw user query.
        top_k: Number of reranked chunks kept as context.
        candidate_k: Number of hybrid candidates passed to the reranker.
        alpha: Dense-vs-BM25 weight forwarded to ``hybrid_candidates``.

    Returns:
        On success, the answer string. On guardrail rejection or on any
        exception, a ``(message, [])`` tuple.
        NOTE(review): the two shapes are inconsistent. They are left as-is so
        existing callers keep working; confirm which one callers expect.
    """
    logger.info(f"Received query: {query}")
    try:
        if not validate_query(query):
            logger.warning("Query rejected: Not finance-related.")
            return "Query rejected: Please ask finance-related questions.", []

        cand_ids = hybrid_candidates(query, candidate_k=candidate_k, alpha=alpha)
        logger.info(f"Hybrid candidates retrieved: {cand_ids}")

        reranked = rerank_cross_encoder(query, cand_ids, top_k=top_k)
        logger.info(f"Reranked top docs: {[d['id'] for d in reranked]}")

        year_match = re.search(r"(20\d{2})", query)
        year = year_match.group(0) if year_match else None
        # The "concept" is the query with a "for the year 20xx" phrase removed.
        concept = re.sub(r"for the year 20\d{2}", "", query, flags=re.IGNORECASE).strip()

        year_specific_answer = None
        if year and concept:
            year_specific_answer = extract_value_for_year_and_concept(year, concept, reranked)
            logger.info(f"Year-specific answer: {year_specific_answer}")

        if year_specific_answer:
            answer = year_specific_answer
        else:
            # Pass the reranked chunks as context
            context_text = "\n".join(d["content"] for d in reranked)
            answer = get_mistral_answer(query, context_text)

        # The output guardrail validate_output(answer, reranked) is not applied here.
        final_answer = answer
        logger.info(f"Final Answer: {final_answer}")
        return final_answer
    except Exception as e:
        logger.error(f"Error in RAG pipeline: {e}")
        return f"Error in RAG pipeline: {e}", []