messy092 commited on
Commit
2ee7fcd
ยท
verified ยท
1 Parent(s): 9257a99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -4
app.py CHANGED
@@ -6,6 +6,7 @@ from sklearn.metrics.pairwise import cosine_similarity
6
  import gradio as gr
7
  from google import generativeai as genai
8
 
 
9
  API_KEY = os.getenv("GOOGLE_API_KEY")
10
 
11
  if API_KEY:
@@ -14,51 +15,69 @@ if API_KEY:
14
  else:
15
  raise ValueError("API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. Hugging Face Spaces์˜ Repository secrets์— 'GOOGLE_API_KEY'๋ฅผ ์„ค์ •ํ•ด์ฃผ์„ธ์š”.")
16
 
 
17
  df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv')
18
  df = df.drop(columns=['Unnamed: 3'], errors='ignore')
19
  df = df.dropna(subset=['์œ ์ €', '์ฑ—๋ด‡'])
20
 
 
21
  model = SentenceTransformer('jhgan/ko-sbert-nli')
22
 
23
  print("๋ฐ์ดํ„ฐ์…‹ ์ž„๋ฒ ๋”ฉ์„ ๋ฏธ๋ฆฌ ๊ณ„์‚ฐ ์ค‘์ž…๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์€ ์‹œ๊ฐ„์ด ์†Œ์š”๋ฉ๋‹ˆ๋‹ค...")
 
24
  df['embedding'] = df['์œ ์ €'].apply(lambda x: model.encode(x))
25
  print("์ž„๋ฒ ๋”ฉ ๊ณ„์‚ฐ์ด ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค! ์ด์ œ ์ฑ—๋ด‡ ์‘๋‹ต์ด ํ›จ์”ฌ ๋นจ๋ผ์ง‘๋‹ˆ๋‹ค.")
26
 
 
27
  def call_gemini_api(question):
 
28
  try:
29
- llm_model = genai.GenerativeModel('gemini-2.5')
30
  response = llm_model.generate_content(question)
31
  return response.text
32
  except Exception as e:
33
  print(f"API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
34
- return (f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}")
35
 
36
- COSINE_SIMILARITY_THRESHOLD = 0.75
 
37
 
38
  def chatbot(user_question):
 
39
  try:
 
40
  user_embedding = model.encode(user_question)
 
 
41
  similarities = df['embedding'].apply(lambda x: cosine_similarity([user_embedding], [x])[0][0])
 
 
42
  best_match_index = similarities.idxmax()
43
  best_score = similarities.loc[best_match_index]
44
  best_match_row = df.loc[best_match_index]
45
 
 
46
  if best_score >= COSINE_SIMILARITY_THRESHOLD:
 
47
  answer = best_match_row['์ฑ—๋ด‡']
48
  print(f"์œ ์‚ฌ๋„ ๊ธฐ๋ฐ˜ ๋‹ต๋ณ€. ์ ์ˆ˜: {best_score}")
49
  return answer
50
  else:
 
51
  print(f"์œ ์‚ฌ๋„ ์ž„๊ณ„๊ฐ’({COSINE_SIMILARITY_THRESHOLD}) ๋ฏธ๋งŒ. Gemini ๋ชจ๋ธ์„ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค. ์ ์ˆ˜: {best_score}")
52
  return call_gemini_api(user_question)
 
53
  except Exception as e:
54
  print(f"์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
55
  return f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
56
 
 
57
  demo = gr.Interface(
58
  fn=chatbot,
59
  inputs=gr.Textbox(lines=2, placeholder="์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”...", label="์งˆ๋ฌธ", elem_id="user_question_input"),
60
  outputs=gr.Textbox(lines=5, label="์ฑ—๋ด‡ ๋‹ต๋ณ€"),
61
  title="๋˜๋ž˜ ์ƒ๋‹ด ์ฑ—๋ด‡",
62
- description="5๋ถ„ ๋™์•ˆ ๋Œ€ํ™”ํ•˜์—ฌ ์ฃผ์‹œ๊ณ  ๋‹ค์Œ์˜ ๋งํฌ๋ฅผ ํด๋ฆญํ•˜์—ฌ ๊ผญ ์„ค๋ฌธ์กฐ์‚ฌ์— ์ฐธ์—ฌํ•ด์ฃผ์„ธ์š”! https://forms.gle/eWtyejQaQntKbbxG8")
 
63
 
64
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
6
  import gradio as gr
7
  from google import generativeai as genai
8
 
9
+ # --- 1. ํ™˜๊ฒฝ ์„ค์ • ๋ฐ ๋ฐ์ดํ„ฐ ๋กœ๋“œ ---
10
  API_KEY = os.getenv("GOOGLE_API_KEY")
11
 
12
  if API_KEY:
 
15
  else:
16
  raise ValueError("API ํ‚ค๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. Hugging Face Spaces์˜ Repository secrets์— 'GOOGLE_API_KEY'๋ฅผ ์„ค์ •ํ•ด์ฃผ์„ธ์š”.")
17
 
18
+ # ๋ฐ์ดํ„ฐ ๋กœ๋“œ ๋ฐ ์ „์ฒ˜๋ฆฌ
19
  df = pd.read_csv('https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv')
20
  df = df.drop(columns=['Unnamed: 3'], errors='ignore')
21
  df = df.dropna(subset=['์œ ์ €', '์ฑ—๋ด‡'])
22
 
23
+ # --- 2. ๋ชจ๋ธ ๋กœ๋“œ ๋ฐ ์ž„๋ฒ ๋”ฉ ๊ณ„์‚ฐ ---
24
  model = SentenceTransformer('jhgan/ko-sbert-nli')
25
 
26
  print("๋ฐ์ดํ„ฐ์…‹ ์ž„๋ฒ ๋”ฉ์„ ๋ฏธ๋ฆฌ ๊ณ„์‚ฐ ์ค‘์ž…๋‹ˆ๋‹ค. ์ด ๊ณผ์ •์€ ์‹œ๊ฐ„์ด ์†Œ์š”๋ฉ๋‹ˆ๋‹ค...")
27
+ # '์œ ์ €' ์ปฌ๋Ÿผ์˜ ํ…์ŠคํŠธ๋ฅผ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ์ €์žฅ
28
  df['embedding'] = df['์œ ์ €'].apply(lambda x: model.encode(x))
29
  print("์ž„๋ฒ ๋”ฉ ๊ณ„์‚ฐ์ด ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค! ์ด์ œ ์ฑ—๋ด‡ ์‘๋‹ต์ด ํ›จ์”ฌ ๋นจ๋ผ์ง‘๋‹ˆ๋‹ค.")
30
 
31
+ # --- 3. Gemini API ํ˜ธ์ถœ ํ•จ์ˆ˜ ---
32
  def call_gemini_api(question):
33
+ """์œ ์‚ฌ๋„ ์ž„๊ณ„๊ฐ’ ๋ฏธ๋งŒ ์‹œ Gemini ๋ชจ๋ธ์„ ํ˜ธ์ถœํ•˜์—ฌ ์‘๋‹ต์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค."""
34
  try:
35
+ llm_model = genai.GenerativeModel('gemini-2.0-flash')
36
  response = llm_model.generate_content(question)
37
  return response.text
38
  except Exception as e:
39
  print(f"API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
40
+ return f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. API ํ˜ธ์ถœ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
41
 
42
+ # --- 4. ์ฑ—๋ด‡ ํ•ต์‹ฌ ๋กœ์ง ํ•จ์ˆ˜ ---
43
+ COSINE_SIMILARITY_THRESHOLD = 0.7
44
 
45
  def chatbot(user_question):
46
+ """์‚ฌ์šฉ์ž ์งˆ๋ฌธ์— ๋Œ€ํ•ด ์œ ์‚ฌ๋„ ๊ฒ€์ƒ‰ ํ›„, ์—†์œผ๋ฉด Gemini ๋ชจ๋ธ์„ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค."""
47
  try:
48
+ # ์‚ฌ์šฉ์ž ์งˆ๋ฌธ ์ž„๋ฒ ๋”ฉ
49
  user_embedding = model.encode(user_question)
50
+
51
+ # ๋ฐ์ดํ„ฐ์…‹ ์ „์ฒด์™€ ์ฝ”์‚ฌ์ธ ์œ ์‚ฌ๋„ ๊ณ„์‚ฐ
52
  similarities = df['embedding'].apply(lambda x: cosine_similarity([user_embedding], [x])[0][0])
53
+
54
+ # ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ๋‹ต๋ณ€ ์ฐพ๊ธฐ
55
  best_match_index = similarities.idxmax()
56
  best_score = similarities.loc[best_match_index]
57
  best_match_row = df.loc[best_match_index]
58
 
59
+ # ์œ ์‚ฌ๋„ ์ž„๊ณ„๊ฐ’ ๋น„๊ต
60
  if best_score >= COSINE_SIMILARITY_THRESHOLD:
61
+ # ์œ ์‚ฌ๋„ ๊ธฐ๋ฐ˜ ๋‹ต๋ณ€ (Retrieval-based)
62
  answer = best_match_row['์ฑ—๋ด‡']
63
  print(f"์œ ์‚ฌ๋„ ๊ธฐ๋ฐ˜ ๋‹ต๋ณ€. ์ ์ˆ˜: {best_score}")
64
  return answer
65
  else:
66
+ # Gemini ๋ชจ๋ธ ํ˜ธ์ถœ (Generative-based)
67
  print(f"์œ ์‚ฌ๋„ ์ž„๊ณ„๊ฐ’({COSINE_SIMILARITY_THRESHOLD}) ๋ฏธ๋งŒ. Gemini ๋ชจ๋ธ์„ ํ˜ธ์ถœํ•ฉ๋‹ˆ๋‹ค. ์ ์ˆ˜: {best_score}")
68
  return call_gemini_api(user_question)
69
+
70
  except Exception as e:
71
  print(f"์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
72
  return f"์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์ฑ—๋ด‡ ์‹คํ–‰ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
73
 
74
+ # --- 5. Gradio ์ธํ„ฐํŽ˜์ด์Šค ์‹คํ–‰ ---
75
  demo = gr.Interface(
76
  fn=chatbot,
77
  inputs=gr.Textbox(lines=2, placeholder="์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”...", label="์งˆ๋ฌธ", elem_id="user_question_input"),
78
  outputs=gr.Textbox(lines=5, label="์ฑ—๋ด‡ ๋‹ต๋ณ€"),
79
  title="๋˜๋ž˜ ์ƒ๋‹ด ์ฑ—๋ด‡",
80
+ description="5๋ถ„ ๋™์•ˆ ๋Œ€ํ™”ํ•˜์—ฌ ์ฃผ์‹œ๊ณ  ๋‹ค์Œ์˜ ๋งํฌ๋ฅผ ํด๋ฆญํ•˜์—ฌ ๊ผญ ์„ค๋ฌธ์กฐ์‚ฌ์— ์ฐธ์—ฌํ•ด์ฃผ์„ธ์š”! https://forms.gle/eWtyejQaQntKbbxG8"
81
+ )
82
 
83
  demo.launch(server_name="0.0.0.0", server_port=7860, share=False)