25chatbot / app.py
messy092's picture
Update app.py
edd4f90 verified
raw
history blame
3.75 kB
import os
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import gradio as gr
from google import generativeai as genai
# Read the Gemini API key from the environment. A missing or empty key aborts
# start-up with a ValueError; otherwise the google.generativeai client is
# configured with it and a confirmation line is printed.
API_KEY = os.getenv("GOOGLE_API_KEY")
if not API_KEY:
    raise ValueError("API ν‚€κ°€ μ„€μ •λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€. Hugging Face Spaces의 Repository secrets에 'GOOGLE_API_KEY'λ₯Ό μ„€μ •ν•΄μ£Όμ„Έμš”.")
genai.configure(api_key=API_KEY)
print("API ν‚€κ°€ μ„±κ³΅μ μœΌλ‘œ μ„€μ •λ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
# 1. Load and preprocess the original dataset.
original_df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv')
# Remove the 'Unnamed: 3' column when the CSV has one; errors='ignore' makes
# this a no-op otherwise.
original_df = original_df.drop(columns=['Unnamed: 3'], errors='ignore')
# Keep only rows that have both a question and an answer.
original_df = original_df.dropna(subset=['μœ μ €', '챗봇'])

# 2. Load and preprocess the new dataset.
new_data_url = 'https://gist.githubusercontent.com/kimminchear/469d84e61bad0334b34a58a030e4a27a/raw/260bde0f335b2bb365a9837e6a6105a93b0b957d/2025_gpdba.csv'
new_df = pd.read_csv(new_data_url)
# If the new dataset names its question/answer columns differently from
# 'μœ μ €' and '챗봇', rename them here (e.g. with new_df.rename(columns={...}))
# before the dropna below, which expects exactly those two column names.
new_df = new_df.dropna(subset=['μœ μ €', '챗봇'])  # drop rows with missing values

# 3. Concatenate the two data frames into one with a fresh 0..n-1 index.
df = pd.concat([original_df, new_df], ignore_index=True)

model = SentenceTransformer('jhgan/ko-sbert-nli')
print(f"총 {len(df)}개의 질문-λ‹΅λ³€ μŒμ— λŒ€ν•΄ 데이터셋 μž„λ² λ”©μ„ λ‹€μ‹œ 계산 μ€‘μž…λ‹ˆλ‹€. 이 과정은 μ‹œκ°„μ΄ μ†Œμš”λ©λ‹ˆλ‹€...")
# Encode every question in a single batched call instead of one encode() call
# per row; the encoder batches list input internally, which is much faster.
# list(...) splits the resulting 2-D array into one 1-D vector per row, so the
# 'embedding' column still holds one vector per question.
df['embedding'] = list(model.encode(df['μœ μ €'].tolist()))
print("μž„λ² λ”© 계산이 μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€! 이제 챗봇 응닡이 훨씬 λΉ¨λΌμ§‘λ‹ˆλ‹€.")
# The functions below are unchanged.
def call_gemini_api(question):
    """Send *question* to the 'gemini-2.0-flash' model and return its reply text.

    Any exception raised while creating the model, generating content or
    reading the response text is printed and converted into an apology
    string that embeds the exception message, so a string is always returned.
    """
    try:
        reply = genai.GenerativeModel('gemini-2.0-flash').generate_content(question)
        return reply.text
    except Exception as err:
        print(f"API 호좜 쀑 였λ₯˜ λ°œμƒ: {err}")
        return f"μ£„μ†‘ν•©λ‹ˆλ‹€. API 호좜 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {err}"
# Minimum cosine similarity for a dataset answer to be returned directly;
# anything scoring below this is forwarded to the Gemini model instead.
COSINE_SIMILARITY_THRESHOLD = 0.7


def chatbot(user_question):
    """Answer *user_question* from the dataset, or fall back to Gemini.

    The question is embedded with the module-level sentence encoder and
    compared against every stored question embedding. When the best cosine
    similarity reaches COSINE_SIMILARITY_THRESHOLD the stored answer of that
    row is returned; otherwise the question is passed to call_gemini_api.

    Any exception is printed and converted into an apology string that
    embeds the exception message, so a string is always returned.
    """
    try:
        user_embedding = model.encode(user_question)
        # Score the whole dataset with one vectorised call rather than one
        # cosine_similarity call per row.
        embedding_matrix = np.vstack(df['embedding'].to_numpy())
        similarities = cosine_similarity([user_embedding], embedding_matrix)[0]
        # argmax yields a position (first maximum on ties), so the matching
        # row is fetched by position with iloc.
        best_position = int(np.argmax(similarities))
        best_score = similarities[best_position]
        if best_score >= COSINE_SIMILARITY_THRESHOLD:
            answer = df.iloc[best_position]['챗봇']
            print(f"μœ μ‚¬λ„ 기반 λ‹΅λ³€. 점수: {best_score}")
            return answer
        print(f"μœ μ‚¬λ„ μž„κ³„κ°’({COSINE_SIMILARITY_THRESHOLD}) 미만. Gemini λͺ¨λΈμ„ ν˜ΈμΆœν•©λ‹ˆλ‹€. 점수: {best_score}")
        return call_gemini_api(user_question)
    except Exception as e:
        print(f"챗봇 μ‹€ν–‰ 쀑 였λ₯˜ λ°œμƒ: {e}")
        return f"μ£„μ†‘ν•©λ‹ˆλ‹€. 챗봇 μ‹€ν–‰ 쀑 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {e}"
# Build the Gradio UI: a two-line question box feeding chatbot() and a
# five-line box showing its answer, then serve it on all interfaces at
# port 7860 without creating a public share link.
_question_box = gr.Textbox(
    lines=2,
    placeholder="μ§ˆλ¬Έμ„ μž…λ ₯ν•΄ μ£Όμ„Έμš”...",
    label="질문",
    elem_id="user_question_input",
)
_answer_box = gr.Textbox(lines=5, label="챗봇 λ‹΅λ³€")
demo = gr.Interface(
    fn=chatbot,
    inputs=_question_box,
    outputs=_answer_box,
    title="또래 상담 챗봇 2.0",
    description="데이터셋이 ν™•μž₯λ˜μ–΄ 더 λ§Žμ€ μ§ˆλ¬Έμ— λŒ€ν•΄ μ •ν™•ν•œ 닡변을 μ œκ³΅ν•©λ‹ˆλ‹€. 5λΆ„ λ™μ•ˆ λŒ€ν™”ν•˜μ—¬ μ£Όμ‹œκ³  λ‹€μŒμ˜ 링크λ₯Ό ν΄λ¦­ν•˜μ—¬ κΌ­ 섀문쑰사에 μ°Έμ—¬ν•΄μ£Όμ„Έμš”! https://forms.gle/eWtyejQaQntKbbxG8",
)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)