# bebechien's picture
# Upload folder using huggingface_hub
# 75684b1 verified
import os
import gradio as gr
import pickle
import torch
from tqdm import tqdm
from web_helper import get_html, find_wiki_links, get_markdown_from_html, get_markdown_from_url
# --- Hugging Face & Model Configuration ---
HF_TOKEN = os.getenv('HF_TOKEN')  # read from the environment; None when unset
EMBEDDING_MODEL_ID = "google/embeddinggemma-300M"
LLM_MODEL_ID = "google/gemma-3-12B-it"
# --- Data Source Configuration ---
BASE_URL = "https://hollowknight.wiki"
# Scrape configuration, one dict per game:
#   title        -> key used for this game in the knowledge-base dict
#   cache_folder -> directory that holds this game's per-category pickle caches
#   category_list entries:
#     entry -> wiki path (appended to BASE_URL) used as the scrape starting point
#     cache -> pickle filename for the processed data of this category
#     label -> category tag stored in each data entry's metadata
GAME_KNOWLEDGE_DATA = [
    {
        "title": "Hollow Knight",
        "cache_folder": "1_cache",
        "category_list": [
            {
                "entry": "/w/Category:Bosses_(Hollow_Knight)",
                "cache": "hollow_knight_bosses.pkl",
                "label": "Bosses",
            },
        ],
    },
    {
        "title": "Silksong",
        "cache_folder": "2_cache",
        "category_list": [
            {
                "entry": "/w/Hornet_(Silksong)",
                "cache": "silksong_hornet.pkl",
                "label": "General",
            },
            {
                "entry": "/w/Hollow_Knight:_Silksong",
                "cache": "silksong_game.pkl",
                "label": "General",
            },
            {
                "entry": "/w/Category:Areas_(Silksong)",
                "cache": "silksong_areas.pkl",
                "label": "Areas",
            },
            {
                "entry": "/w/Category:Bosses_(Silksong)",
                "cache": "silksong_bosses.pkl",
                "label": "Bosses",
            },
            {
                "entry": "/w/Category:Items_(Silksong)",
                "cache": "silksong_items.pkl",
                "label": "Items",
            },
            {
                "entry": "/w/Category:NPCs_(Silksong)",
                "cache": "silksong_npcs.pkl",
                "label": "NPCs",
            },
            {
                "entry": "/w/Tasks",
                "cache": "silksong_tasks.pkl",
                "label": "Tasks",
            },
            {
                "entry": "/w/Category:Crests_and_Skills",
                "cache": "silksong_crests_and_skills.pkl",
                "label": "Crests and Skills",
            },
            {
                "entry": "/w/Category:Tools",
                "cache": "silksong_tools.pkl",
                "label": "Tools",
            },
            {
                "entry": "/w/Category:Abilities_(Silksong)",
                "cache": "silksong_abilities.pkl",
                "label": "Abilities",
            },
        ],
    },
]
def get_all_game_data(embedding_model):
    """Build the full knowledge base by loading or scraping every configured source.

    Args:
        embedding_model: model exposing an ``encode`` method, forwarded to the
            per-source processing helper.

    Returns:
        dict mapping each game title to a flat list of processed data entries.
    """
    print("\n--- Processing Game Data ---")
    knowledge_base = {}
    for game in GAME_KNOWLEDGE_DATA:
        entries = []
        for source in game['category_list']:
            cache_path = f"{game['cache_folder']}/{source['cache']}"
            entries.extend(
                _load_or_process_source(
                    source['entry'],
                    cache_path,
                    source['label'],
                    embedding_model,
                )
            )
        knowledge_base[game['title']] = entries
    return knowledge_base
# --- DATA PROCESSING & CACHING ---
# Scrapes data and generates embeddings, using a cache to avoid re-running.
def _clean_text(text: str) -> str:
"""Removes the references section from the raw text."""
return text.split("References\n----------\n", 1)[0].strip()
@torch.no_grad()
def _create_data_entry(text: str, doc_path: str, label: str, embedding_model) -> dict | None:
    """Build one knowledge-base record (text, embedding, metadata) for a page.

    Args:
        text: raw markdown for the page; its references section is stripped.
        doc_path: wiki path of the page; the last segment becomes the title.
        label: category tag stored in the entry's metadata.
        embedding_model: model whose ``encode`` produces the page embedding.

    Returns:
        The structured entry dict, or None when the cleaned text is empty.
    """
    body = _clean_text(text)
    if not body:
        return None
    page_title = doc_path.rsplit('/', 1)[-1]
    vector = embedding_model.encode(body, prompt=f"title: {page_title} | text: ")
    return {
        "text": body,
        # Stored as a tensor so later code can stack embeddings directly.
        "embedding": torch.tensor(vector),
        "metadata": {
            "category": label,
            "source": BASE_URL + doc_path,
            "title": page_title,
        },
    }
def _load_or_process_source(entry_point: str, cache_file: str, label: str, embedding_model):
    """Return processed entries for one source, using the pickle cache when present.

    On a cache miss this scrapes the entry page plus every wiki page linked
    from it, embeds each one, saves the result to ``cache_file``, and returns it.
    """
    if os.path.exists(cache_file):
        print(f"✅ Found cache for {label}. Loading data from '{cache_file}'...")
        # NOTE(review): unpickling is acceptable here only because the cache is
        # produced by this same app, not untrusted input.
        with open(cache_file, 'rb') as f:
            return pickle.load(f)

    print(f"ℹ️ No cache for {label}. Starting data scraping and processing...")
    results = []
    entry_html = get_html(BASE_URL + entry_point)

    # The entry page itself is part of the corpus.
    root_entry = _create_data_entry(get_markdown_from_html(entry_html), entry_point, label, embedding_model)
    if root_entry:
        results.append(root_entry)

    # ...followed by every wiki page linked from it.
    for doc_path in tqdm(find_wiki_links(entry_html), desc=f"Processing {label} Pages"):
        entry = _create_data_entry(get_markdown_from_url(BASE_URL + doc_path), doc_path, label, embedding_model)
        if entry:
            results.append(entry)

    print(f"✅ {label} processing complete. Saving {len(results)} entries to '{cache_file}'...")
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'wb') as f:
        pickle.dump(results, f)
    return results
# --- App Logic Configuration ---
# Similarity thresholds consumed elsewhere in the app — presumably minimum
# retrieval scores for the base query and for follow-ups; confirm against caller.
BASE_SIMILARITY_THRESHOLD = 0.2
FOLLOWUP_SIMILARITY_THRESHOLD = 0.5
DEFAULT_MESSAGE_NO_MATCH = "I'm sorry, I can't find a relevant document to answer that question."
# --- Gradio UI Configuration ---
# Red/zinc palette with an antique serif Google font for the game's aesthetic.
silksong_theme = gr.themes.Default(
    primary_hue=gr.themes.colors.red,
    secondary_hue=gr.themes.colors.zinc,
    neutral_hue=gr.themes.colors.zinc,
    font=[gr.themes.GoogleFont("IM Fell English"), "ui-sans-serif", "system-ui", "sans-serif"],
)
# Custom CSS: background artwork with a light/dark gradient overlay, plus header,
# context, and disclaimer styling.
# NOTE(review): the /gradio_api/file= URL requires assets/ to be served by the
# app (e.g. via allowed_paths) — confirm at launch site.
silksong_css="""
.gradio-container {
background-image: linear-gradient(rgba(255,255,255, 0.5), rgba(255, 255, 255, 1.0)), url("/gradio_api/file=assets/background.jpg");
background-size: 100%;
background-repeat: no-repeat;
background-position: top center;
}
body.dark .gradio-container {
background-image: linear-gradient(rgba(0, 0, 0, 0.5), rgba(0, 0, 0, 1.0)), url("/gradio_api/file=assets/background.jpg");
}
.header-text { text-align: center; text-shadow: 2px 2px 5px #000; }
.header-text h1 { font-size: 2.5em; color: #dc2626; }
.dark .header-text { text-shadow: 2px 2px 5px #FFF; }
.context { text-align: center; color: var(--body-text-color-subdued); }
.context a { color: #dc2626; }
.disclaimer { text-align: center; color: var(--body-text-color-subdued); font-size: 0.9em; padding: 20px; }
.disclaimer ul { list-style: none; padding: 0; }
.disclaimer a { color: #dc2626; }
"""