messy092 commited on
Commit
2936872
ยท
verified ยท
1 Parent(s): 3d448e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -12
app.py CHANGED
@@ -1,49 +1,64 @@
 
 
 
 
 
1
  import gradio as gr
2
  from google import generativeai as genai
3
 
4
-
5
  API_KEY = os.getenv("GOOGLE_API_KEY")
6
 
7
  if API_KEY:
 
 
8
  else:
9
  raise ValueError("API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. Hugging Face Spaces์˜ Repository secrets์— 'GOOGLE_API_KEY'๋ฅผ ์„ค์ •ํ•ด์ฃผ์„ธ์š”.")
10
 
11
-
12
  df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv')
13
  df = df.drop(columns=['Unnamed: 3'], errors='ignore')
14
  df = df.dropna(subset=['์œ ์ €', '์ฑ—๋ด‡'])
15
 
16
-
17
  model = SentenceTransformer('jhgan/ko-sbert-nli')
18
 
19
-
20
  print("๋ฐ์ดํ„ฐ์…‹ ์ž„๋ฒ ๋”ฉ์„ ๋ฏธ๋ฆฌ ๊ณ„์‚ฐ ์ค‘์ž…๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์€ ์‹œ๊ฐ„์ด ์†Œ์š”๋ฉ๋‹ˆ๋‹ค...")
21
  df['embedding'] = df['์œ ์ €'].apply(lambda x: model.encode(x))
22
  print("์ž„๋ฒ ๋”ฉ ๊ณ„์‚ฐ์ด ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค! ์ด์ œ ์ฑ—๋ด‡ ์‘๋‹ต์ด ํ›จ์”ฌ ๋นจ๋ผ์ง‘๋‹ˆ๋‹ค.")
23
 
24
-
25
  def call_gemini_api(question):
26
  try:
27
  llm_model = genai.GenerativeModel('gemini-2.5')
 
 
 
28
  print(f"API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
29
  return f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
30
 
31
-
32
- COSINE_SIMILARITY_THRESHOLD = 0.8
33
-
34
 
35
  def chatbot(user_question):
36
  try:
37
  user_embedding = model.encode(user_question)
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  print(f"์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
39
  return f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
40
 
41
-
42
  demo = gr.Interface(
43
  fn=chatbot,
44
  inputs=gr.Textbox(lines=2, placeholder="์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”...", label="์งˆ๋ฌธ", elem_id="user_question_input"),
45
- description="5๋ถ„ ๋™์•ˆ ๋Œ€ํ™”ํ•˜์—ฌ ์ฃผ์‹œ๊ณ  ๋‹ค์Œ์˜ ๋งํฌ๋ฅผ ํด๋ฆญํ•˜์—ฌ ๊ผญ ์„ค๋ฌธ์กฐ์‚ฌ์— ์ฐธ์—ฌํ•ด์ฃผ์„ธ์š”! https://forms.gle/eWtyejQaQntKbbxG8"
46
- )
47
-
48
 
49
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sentence_transformers import SentenceTransformer
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
  import gradio as gr
7
  from google import generativeai as genai
8
 
 
9
  API_KEY = os.getenv("GOOGLE_API_KEY")
10
 
11
  if API_KEY:
12
+ genai.configure(api_key=API_KEY)
13
+ print("API ํ‚ค๊ฐ€ ์„ฑ๊ณต์ ์œผ๋กœ ์„ค์ •๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
14
  else:
15
  raise ValueError("API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. Hugging Face Spaces์˜ Repository secrets์— 'GOOGLE_API_KEY'๋ฅผ ์„ค์ •ํ•ด์ฃผ์„ธ์š”.")
16
 
 
17
  df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv')
18
  df = df.drop(columns=['Unnamed: 3'], errors='ignore')
19
  df = df.dropna(subset=['์œ ์ €', '์ฑ—๋ด‡'])
20
 
 
21
  model = SentenceTransformer('jhgan/ko-sbert-nli')
22
 
 
23
  print("๋ฐ์ดํ„ฐ์…‹ ์ž„๋ฒ ๋”ฉ์„ ๋ฏธ๋ฆฌ ๊ณ„์‚ฐ ์ค‘์ž…๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์€ ์‹œ๊ฐ„์ด ์†Œ์š”๋ฉ๋‹ˆ๋‹ค...")
24
  df['embedding'] = df['์œ ์ €'].apply(lambda x: model.encode(x))
25
  print("์ž„๋ฒ ๋”ฉ ๊ณ„์‚ฐ์ด ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค! ์ด์ œ ์ฑ—๋ด‡ ์‘๋‹ต์ด ํ›จ์”ฌ ๋นจ๋ผ์ง‘๋‹ˆ๋‹ค.")
26
 
 
27
  def call_gemini_api(question):
28
  try:
29
  llm_model = genai.GenerativeModel('gemini-2.5')
30
+ response = llm_model.generate_content(question)
31
+ return response.text
32
+ except Exception as e:
33
  print(f"API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
34
  return f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
35
 
36
+ COSINE_SIMILARITY_THRESHOLD = 0.75
 
 
37
 
38
  def chatbot(user_question):
39
  try:
40
  user_embedding = model.encode(user_question)
41
+ similarities = df['embedding'].apply(lambda x: cosine_similarity([user_embedding], [x])[0][0])
42
+ best_match_index = similarities.idxmax()
43
+ best_score = similarities.loc[best_match_index]
44
+ best_match_row = df.loc[best_match_index]
45
+
46
+ if best_score >= COSINE_SIMILARITY_THRESHOLD:
47
+ answer = best_match_row['์ฑ—๋ด‡']
48
+ print(f"์œ ์‚ฌ๋„ ๊ธฐ๋ฐ˜ ๋‹ต๋ณ€. ์ ์ˆ˜: {best_score}")
49
+ return answer
50
+ else:
51
+ print(f"์œ ์‚ฌ๋„ ์ž„๊ณ„๊ฐ’({COSINE_SIMILARITY_THRESHOLD}) ๋ฏธ๋งŒ. Gemini ๋ชจ๋ธ์„ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค. ์ ์ˆ˜: {best_score}")
52
+ return call_gemini_api(user_question)
53
+ except Exception as e:
54
  print(f"์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
55
  return f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
56
 
 
57
  demo = gr.Interface(
58
  fn=chatbot,
59
  inputs=gr.Textbox(lines=2, placeholder="์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”...", label="์งˆ๋ฌธ", elem_id="user_question_input"),
60
+ outputs=gr.Textbox(lines=5, label="์ฑ—๋ด‡ ๋‹ต๋ณ€"),
61
+ title="๋˜๋ž˜ ์ƒ๋‹ด ์ฑ—๋ด‡",
62
+ description="5๋ถ„ ๋™์•ˆ ๋Œ€ํ™”ํ•˜์—ฌ ์ฃผ์‹œ๊ณ  ๋‹ค์Œ์˜ ๋งํฌ๋ฅผ ํด๋ฆญํ•˜์—ฌ ๊ผญ ์„ค๋ฌธ์กฐ์‚ฌ์— ์ฐธ์—ฌํ•ด์ฃผ์„ธ์š”! https://forms.gle/eWtyejQaQntKbbxG8")
63
 
64
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)