| \ | |
| import json, os | |
| import numpy as np, pandas as pd | |
| import faiss | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
class SloganSearcher:
    """Semantic slogan search backed by a prebuilt FAISS index.

    Assets are read from ``assets_dir``:

    * ``meta.json`` - must contain ``model_name``; may contain ``text_col``
      (default ``"description"``), ``fallback_col`` (default ``"tagline"``)
      and ``normalized`` (default ``True``).
    * ``slogans_clean.parquet`` - the slogan rows; FAISS labels are used as
      positional row indices into this table (presumably the index was built
      from these rows in order - confirm against the asset builder).
    * ``faiss.index`` - the vector index searched by ``search``.

    Optionally, FAISS candidates are re-ranked with a sentence-transformers
    ``CrossEncoder``.
    """

    def __init__(self, assets_dir="assets", use_rerank=False,
                 rerank_model="cross-encoder/stsb-roberta-base"):
        """Load metadata, slogan table, FAISS index and encoder model(s).

        Args:
            assets_dir: Directory holding the prebuilt assets.
            use_rerank: When true, also load a CrossEncoder used by ``search``
                to re-order candidates.
            rerank_model: Model name passed to ``CrossEncoder``.

        Raises:
            FileNotFoundError: If ``meta.json`` is missing from ``assets_dir``.
        """
        meta_path = os.path.join(assets_dir, "meta.json")
        if not os.path.exists(meta_path):
            raise FileNotFoundError(f"Missing {meta_path}. Build assets first.")
        with open(meta_path, "r", encoding="utf-8") as f:
            self.meta = json.load(f)

        self.df = pd.read_parquet(os.path.join(assets_dir, "slogans_clean.parquet"))
        self.index = faiss.read_index(os.path.join(assets_dir, "faiss.index"))
        # The query encoder must be the same model that produced the index vectors.
        self.encoder = SentenceTransformer(self.meta["model_name"])
        self.use_rerank = use_rerank
        self.reranker = CrossEncoder(rerank_model) if use_rerank else None
        self.text_col = self.meta.get("text_col", "description")
        self.fallback_col = self.meta.get("fallback_col", "tagline")
        # Whether query embeddings are L2-normalised before searching.
        self.norm = bool(self.meta.get("normalized", True))

    # Backward-compatible alias for the previously misspelled initializer name.
    _init_ = __init__

    def search(self, query: str, top_k=5, rerank_top_n=20):
        """Return up to ``top_k`` slogans most relevant to ``query``.

        Args:
            query: Free-text query. Non-strings and blank strings yield an
                empty result.
            top_k: Maximum number of rows to return; ``<= 0`` yields an empty
                result.
            rerank_top_n: Size of the FAISS candidate pool when re-ranking is
                enabled (the pool is never smaller than ``top_k``).

        Returns:
            DataFrame with columns ``display`` (the value of the fallback
            column) and ``score`` (the raw value FAISS returned - similarity or
            distance depending on the index type), plus ``rerank_score`` when
            re-ranking is enabled. Rows are in FAISS order, or sorted by
            descending ``rerank_score`` when re-ranking. An empty frame with
            those columns is returned when there is nothing to report.
        """
        cols = ["display", "score"] + (["rerank_score"] if self.use_rerank else [])
        empty = pd.DataFrame(columns=cols)
        if not isinstance(query, str) or not query.strip():
            return empty
        top_k = int(top_k)
        if top_k <= 0:
            return empty

        # Fetch a wider pool when re-ranking, but never ask FAISS for more
        # neighbours than the index actually holds.
        n_candidates = max(top_k, int(rerank_top_n)) if self.use_rerank else top_k
        n_candidates = min(n_candidates, int(self.index.ntotal))
        if n_candidates <= 0:
            return empty

        q = self.encoder.encode([query], convert_to_numpy=True,
                                normalize_embeddings=self.norm)
        # FAISS expects a C-contiguous float32 (n, d) matrix.
        q = np.ascontiguousarray(q, dtype="float32")
        sims, idxs = self.index.search(q, n_candidates)

        # FAISS pads with label -1 when it finds fewer than k neighbours;
        # df.iloc[-1] would silently return the last row, so drop those slots.
        hits = [(i, s) for i, s in zip(idxs[0].tolist(), sims[0].tolist()) if i >= 0]
        if not hits:
            return empty
        positions, scores = zip(*hits)

        results = self.df.iloc[list(positions)].copy()
        results["score"] = list(scores)
        if self.use_rerank:
            # Score the primary text, falling back per-row when it is missing.
            texts = (results[self.text_col]
                     .fillna(results[self.fallback_col])
                     .astype(str)
                     .tolist())
            results["rerank_score"] = self.reranker.predict([[query, t] for t in texts])
            results = results.sort_values("rerank_score", ascending=False, kind="stable")
        # Assign before slicing so we never write into a head() slice.
        results["display"] = results[self.fallback_col]
        return results.head(top_k)[cols]