|
|
import os

import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import gradio as gr
from google import generativeai as genai
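
# Wellness counseling chatbot for Hugging Face Spaces: user questions are embedded with a
# Korean SBERT model, matched against a question-answer dataset by cosine similarity, and
# forwarded to the Gemini API when no sufficiently similar dataset entry is found.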
|
|
|
|
|
# Read the Gemini API key from the environment (a Hugging Face Spaces Repository secret).
API_KEY = os.getenv("GOOGLE_API_KEY")

if API_KEY:
    genai.configure(api_key=API_KEY)
    print("API key configured successfully.")
else:
    raise ValueError("API key is not set. Please add 'GOOGLE_API_KEY' to the Repository secrets of the Hugging Face Space.")
|
|
|
|
|
|
|
|
# Original wellness counseling dataset (Korean question-answer pairs).
original_df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv')
original_df = original_df.drop(columns=['Unnamed: 3'], errors='ignore')
original_df = original_df.dropna(subset=['유저', '챗봇'])
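# Both CSVs are expected to use the Korean column names '유저' (the user's utterance,
# which is embedded and matched) and '챗봇' (the reply returned for that utterance).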
|
|
|
|
|
|
|
|
# Additional question-answer data loaded from a GitHub gist.
new_data_url = 'https://gist.githubusercontent.com/kimminchear/469d84e61bad0334b34a58a030e4a27a/raw/260bde0f335b2bb365a9837e6a6105a93b0b957d/2025_gpdba.csv'
new_df = pd.read_csv(new_data_url)
new_df = new_df.dropna(subset=['유저', '챗봇'])
|
|
|
|
|
|
|
|
# Merge both datasets and embed every user question with a Korean SBERT model.
df = pd.concat([original_df, new_df], ignore_index=True)

model = SentenceTransformer('jhgan/ko-sbert-nli')

print(f"Computing dataset embeddings for {len(df)} question-answer pairs. This may take a while...")
df['embedding'] = df['유저'].apply(lambda x: model.encode(x))
print("Embedding computation finished! Chatbot responses will now be much faster.")
|
|
|
|
|
|
|
|
def call_gemini_api(question):
    """Generate a free-form answer with Gemini when the dataset has no close match."""
    try:
        llm_model = genai.GenerativeModel('gemini-2.0-flash')
        response = llm_model.generate_content(question)
        return response.text
    except Exception as e:
        print(f"Error during Gemini API call: {e}")
        return f"Sorry, an error occurred while calling the API: {e}"


COSINE_SIMILARITY_THRESHOLD = 0.7
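# Questions whose best dataset match scores below this threshold are not answered from the
# dataset; they are forwarded to call_gemini_api() instead.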
|
|
|
|
|
def chatbot(user_question):
    try:
        # Embed the incoming question and score it against every stored question.
        user_embedding = model.encode(user_question)
        similarities = df['embedding'].apply(lambda x: cosine_similarity([user_embedding], [x])[0][0])
        best_match_index = similarities.idxmax()
        best_score = similarities.loc[best_match_index]
        best_match_row = df.loc[best_match_index]

        if best_score >= COSINE_SIMILARITY_THRESHOLD:
            # Close enough: return the curated reply from the dataset.
            answer = best_match_row['챗봇']
            print(f"Answering from the dataset. Score: {best_score}")
            return answer
        else:
            print(f"Score {best_score} is below the similarity threshold ({COSINE_SIMILARITY_THRESHOLD}); calling the Gemini model.")
            return call_gemini_api(user_question)
    except Exception as e:
        print(f"Error while running the chatbot: {e}")
        return f"Sorry, an error occurred while running the chatbot: {e}"
|
|
|
|
|
demo = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question...", label="Question", elem_id="user_question_input"),
    outputs=gr.Textbox(lines=5, label="Chatbot Answer"),
    title="Counseling Chatbot 2.0",
    description="The dataset has been expanded, so the chatbot answers a wider range of questions accurately. Please chat with it for 5 minutes, then click the link and take part in the survey! https://forms.gle/eWtyejQaQntKbbxG8"
)

demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
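
# When run locally (python app.py), the Gradio UI is served at http://localhost:7860;
# server_port=7860 matches the port that Hugging Face Spaces expects a Gradio app to use.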