tugrulkaya commited on
Commit
ceb4795
·
verified ·
1 Parent(s): 1eeadfa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -107
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
4
  from peft import PeftModel
5
  from qwen_vl_utils import process_vision_info
6
  from PIL import Image
7
  import os
8
 
9
- # EuroSAT sınıfları
10
  CLASS_DESCRIPTIONS = {
11
  "AnnualCrop": "🌾 Yıllık Tarım Alanı",
12
  "Forest": "🌲 Orman",
@@ -20,63 +20,75 @@ CLASS_DESCRIPTIONS = {
20
  "SeaLake": "🏞️ Deniz/Göl"
21
  }
22
 
23
- # GPU kontrolü
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
- print(f"Kullanılan cihaz: {DEVICE}")
26
 
27
- # Model yükleme
28
- print("Model yükleniyor...")
29
-
30
- if DEVICE == "cuda":
31
- # GPU varsa 4-bit quantization
32
- from transformers import BitsAndBytesConfig
33
- bnb_config = BitsAndBytesConfig(
34
- load_in_4bit=True,
35
- bnb_4bit_quant_type="nf4",
36
- bnb_4bit_compute_dtype=torch.float16
37
- )
38
- base_model = Qwen2VLForConditionalGeneration.from_pretrained(
39
- "Qwen/Qwen2-VL-2B-Instruct",
40
- quantization_config=bnb_config,
41
- device_map="auto",
42
- trust_remote_code=True
43
- )
44
- else:
45
- # CPU mode - offload ile
46
- os.makedirs("offload", exist_ok=True)
47
- base_model = Qwen2VLForConditionalGeneration.from_pretrained(
48
- "Qwen/Qwen2-VL-2B-Instruct",
49
- torch_dtype=torch.float32,
50
- device_map="auto",
51
- offload_folder="offload",
52
- trust_remote_code=True
53
- )
54
-
55
- model = PeftModel.from_pretrained(
56
- base_model,
57
- "tugrulkaya/GeoQwen-VL-2B-EuroSAT",
58
- offload_folder="offload" if DEVICE == "cpu" else None
59
- )
60
- model.eval()
 
61
 
62
- processor = AutoProcessor.from_pretrained(
63
- "Qwen/Qwen2-VL-2B-Instruct",
64
- trust_remote_code=True
65
- )
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- print("Model başarıyla yüklendi!")
 
68
 
 
69
  def classify_satellite_image(image):
70
- """Uydu görüntüsünü sınıflandır"""
71
  if image is None:
72
  return "⚠️ Lütfen bir görüntü yükleyin.", ""
73
 
74
  try:
75
- # Görüntüyü PIL formatına çevir
76
  if not isinstance(image, Image.Image):
77
  image = Image.fromarray(image)
78
 
79
- # Mesaj hazırla
80
  messages = [
81
  {
82
  "role": "user",
@@ -87,7 +99,7 @@ def classify_satellite_image(image):
87
  }
88
  ]
89
 
90
- # İşleme
91
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
92
  image_inputs, video_inputs = process_vision_info(messages)
93
 
@@ -103,7 +115,7 @@ def classify_satellite_image(image):
103
  with torch.no_grad():
104
  generated_ids = model.generate(
105
  **inputs,
106
- max_new_tokens=20,
107
  do_sample=False
108
  )
109
 
@@ -118,87 +130,56 @@ def classify_satellite_image(image):
118
  clean_up_tokenization_spaces=False
119
  )[0].strip()
120
 
121
- # Sonucu formatla
122
- if result in CLASS_DESCRIPTIONS:
123
- formatted_result = CLASS_DESCRIPTIONS[result]
124
- confidence_text = f"**Tahmin:** {formatted_result}\n\n**Sınıf:** `{result}`"
 
 
125
  else:
126
- confidence_text = f"**Tahmin:** {result}"
 
127
 
128
- return confidence_text, result
129
 
130
  except Exception as e:
131
- return f"❌ Hata oluştu: {str(e)}", ""
132
 
133
- # Gradio arayüzü
134
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
135
 
136
  gr.HTML("""
137
  <div style="text-align: center; padding: 20px;">
138
- <h1 style="font-size: 2.5em; margin-bottom: 10px;">
139
- 🛰️ GeoQwen-VL-2B-EuroSAT
140
- </h1>
141
- <p style="font-size: 1.2em; color: #666;">
142
- Uydu Görüntülerinden Arazi Sınıflandırma
143
- </p>
144
  </div>
145
  """)
146
 
147
  with gr.Row():
148
- with gr.Column(scale=1):
149
- input_image = gr.Image(
150
- label="📷 Uydu Görüntüsü Yükle",
151
- type="pil",
152
- height=350
153
- )
154
-
155
- classify_btn = gr.Button(
156
- "🔍 Sınıflandır",
157
- variant="primary",
158
- size="lg"
159
- )
160
 
161
- gr.HTML("""
162
- <div style="margin-top: 15px; padding: 15px; background: #e8f4f8; border-radius: 8px;">
163
- <h4 style="margin: 0 0 10px 0;">📌 Desteklenen Sınıflar:</h4>
164
- <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 5px; font-size: 0.9em;">
165
- <span>🌾 Yıllık Tarım</span>
166
- <span>🌲 Orman</span>
167
- <span>🌿 Otsu Bitki</span>
168
- <span>🛣️ Otoyol</span>
169
- <span>🏭 Endüstriyel</span>
170
- <span>🐄 Mera</span>
171
- <span>🍇 Kalıcı Tarım</span>
172
- <span>🏘️ Yerleşim</span>
173
- <span>🌊 Nehir</span>
174
- <span>🏞️ Deniz/Göl</span>
175
- </div>
176
- </div>
177
- """)
178
-
179
- with gr.Column(scale=1):
180
- output_text = gr.Markdown(
181
- label="Sonuç",
182
- value="*Bir görüntü yükleyip 'Sınıflandır' butonuna tıklayın...*"
183
- )
184
-
185
- output_class = gr.Textbox(
186
- label="Ham Çıktı (Raw Output)",
187
- interactive=False
188
  )
 
 
 
 
189
 
190
  gr.HTML("""
191
- <div style="text-align: center; margin-top: 30px; padding: 20px; border-top: 1px solid #eee;">
192
- <p style="color: #888; font-size: 0.9em;">
193
- 🤗 <a href="https://huggingface.co/tugrulkaya/GeoQwen-VL-2B-EuroSAT" target="_blank">tugrulkaya/GeoQwen-VL-2B-EuroSAT</a>
194
- </p>
195
  </div>
196
  """)
197
-
198
  classify_btn.click(
199
  fn=classify_satellite_image,
200
  inputs=[input_image],
201
- outputs=[output_text, output_class]
202
  )
203
 
204
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
4
  from peft import PeftModel
5
  from qwen_vl_utils import process_vision_info
6
  from PIL import Image
7
  import os
8
 
9
+ # --- AYARLAR ---
10
  CLASS_DESCRIPTIONS = {
11
  "AnnualCrop": "🌾 Yıllık Tarım Alanı",
12
  "Forest": "🌲 Orman",
 
20
  "SeaLake": "🏞️ Deniz/Göl"
21
  }
22
 
 
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
+ print(f"🚀 Kullanılan cihaz: {DEVICE}")
25
 
26
+ # --- MODEL YÜKLEME ---
27
+ def load_model():
28
+ print("⏳ Model yükleniyor...")
29
+ try:
30
+ model_id = "Qwen/Qwen2-VL-2B-Instruct"
31
+ adapter_id = "tugrulkaya/GeoQwen-VL-2B-EuroSAT"
32
+
33
+ if DEVICE == "cuda":
34
+ # GPU Ayarları
35
+ bnb_config = BitsAndBytesConfig(
36
+ load_in_4bit=True,
37
+ bnb_4bit_quant_type="nf4",
38
+ bnb_4bit_compute_dtype=torch.float16
39
+ )
40
+ base_model = Qwen2VLForConditionalGeneration.from_pretrained(
41
+ model_id,
42
+ quantization_config=bnb_config,
43
+ device_map="auto",
44
+ trust_remote_code=True,
45
+ _attn_implementation="flash_attention_2" # Sadece GPU varsa
46
+ )
47
+ else:
48
+ # CPU Ayarları (Spaces Free Tier için Kritik)
49
+ # Offload klasörü oluştur
50
+ os.makedirs("offload", exist_ok=True)
51
+
52
+ base_model = Qwen2VLForConditionalGeneration.from_pretrained(
53
+ model_id,
54
+ torch_dtype=torch.float32,
55
+ device_map="auto",
56
+ offload_folder="offload",
57
+ trust_remote_code=True,
58
+ low_cpu_mem_usage=True,
59
+ _attn_implementation="eager" # <--- BU ÇOK ÖNEMLİ: CPU'da flash attn çalışmaz
60
+ )
61
 
62
+ # LoRA Adaptörünü Yükle
63
+ model = PeftModel.from_pretrained(
64
+ base_model,
65
+ adapter_id,
66
+ offload_folder="offload" if DEVICE == "cpu" else None
67
+ )
68
+ model.eval()
69
+
70
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
71
+ print("✅ Model başarıyla yüklendi!")
72
+ return model, processor
73
+
74
+ except Exception as e:
75
+ print(f"❌ Model yükleme hatası: {str(e)}")
76
+ raise e
77
 
78
+ # Global değişkenler olarak yükle
79
+ model, processor = load_model()
80
 
81
+ # --- SINIFLANDIRMA FONKSİYONU ---
82
  def classify_satellite_image(image):
 
83
  if image is None:
84
  return "⚠️ Lütfen bir görüntü yükleyin.", ""
85
 
86
  try:
87
+ # Görüntü kontrolü
88
  if not isinstance(image, Image.Image):
89
  image = Image.fromarray(image)
90
 
91
+ # Prompt
92
  messages = [
93
  {
94
  "role": "user",
 
99
  }
100
  ]
101
 
102
+ # Hazırlık
103
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
104
  image_inputs, video_inputs = process_vision_info(messages)
105
 
 
115
  with torch.no_grad():
116
  generated_ids = model.generate(
117
  **inputs,
118
+ max_new_tokens=32,
119
  do_sample=False
120
  )
121
 
 
130
  clean_up_tokenization_spaces=False
131
  )[0].strip()
132
 
133
+ # Sonuç Temizleme (Bazen model nokta vb. ekleyebilir)
134
+ clean_result = result.replace('.', '').strip()
135
+
136
+ if clean_result in CLASS_DESCRIPTIONS:
137
+ formatted_result = CLASS_DESCRIPTIONS[clean_result]
138
+ display_text = f"### 🎯 Sonuç: {formatted_result}\n\n**Orijinal Sınıf:** `{clean_result}`"
139
  else:
140
+ display_text = f"### 🤖 Model Çıktısı: {result}"
141
+ clean_result = result
142
 
143
+ return display_text, clean_result
144
 
145
  except Exception as e:
146
+ return f"❌ Hata: {str(e)}", "Error"
147
 
148
+ # --- ARAYÜZ ---
149
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
150
 
151
  gr.HTML("""
152
  <div style="text-align: center; padding: 20px;">
153
+ <h1 style="font-size: 2.5em; margin-bottom: 10px;">🛰️ GeoQwen-VL-2B-EuroSAT</h1>
154
+ <p style="font-size: 1.2em; color: #666;">Uydu Görüntülerinden Arazi Sınıflandırma</p>
 
 
 
 
155
  </div>
156
  """)
157
 
158
  with gr.Row():
159
+ with gr.Column():
160
+ input_image = gr.Image(label="Uydu Görüntüsü Yükle", type="pil", height=300)
161
+ classify_btn = gr.Button("🔍 Sınıflandır", variant="primary", size="lg")
 
 
 
 
 
 
 
 
 
162
 
163
+ # Örnek sınıfları göster
164
+ gr.Examples(
165
+ examples=[], # Buraya örnek resim yolları eklenebilir
166
+ inputs=input_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  )
168
+
169
+ with gr.Column():
170
+ output_text = gr.Markdown(label="Analiz Sonucu", value="*Görüntü bekleniyor...*")
171
+ output_raw = gr.Textbox(label="Ham Çıktı", interactive=False)
172
 
173
  gr.HTML("""
174
+ <div style="margin-top: 20px; padding: 10px; background-color: #f0f0f0; border-radius: 5px;">
175
+ <p style="margin:0"><b>Not:</b> CPU üzerinde çalışıyorsa işlem 10-30 saniye sürebilir.</p>
 
 
176
  </div>
177
  """)
178
+
179
  classify_btn.click(
180
  fn=classify_satellite_image,
181
  inputs=[input_image],
182
+ outputs=[output_text, output_raw]
183
  )
184
 
185
  if __name__ == "__main__":