prithivMLmods committed
Commit d1a91c5 · verified · 1 parent: c7a140a

update app

Files changed (1): app.py (+118, -112)
app.py CHANGED
@@ -20,7 +20,7 @@ DTYPE = "auto"
 CATEGORIES = ["Query", "Caption", "Point", "Detect"]
 PLACEHOLDERS = {
     "Query": "What's in this image?",
-    "Caption": "Select caption length: short, normal, or long",
+    "Caption": "Enter caption length: short, normal, or long",
     "Point": "Select an object from suggestions or enter manually",
     "Detect": "Select an object from suggestions or enter manually",
 }
@@ -39,9 +39,7 @@ qwen_processor = Qwen3VLProcessor.from_pretrained(
 
 # --- Utility Functions ---
 def safe_parse_json(text: str):
-    """Safely parse a string that may be JSON or a Python literal."""
     text = text.strip()
-    # Remove markdown code blocks
     text = re.sub(r"^```(json)?", "", text)
     text = re.sub(r"```$", "", text)
     text = text.strip()
@@ -50,127 +48,142 @@ def safe_parse_json(text: str):
     except json.JSONDecodeError:
         pass
     try:
-        # Fallback to literal_eval for Python-like dictionary/list strings
         return ast.literal_eval(text)
     except Exception:
         return {}
 
-# --- Inference Functions ---
-def run_qwen_inference(image: Image.Image, prompt: str):
-    """Core function to run inference with the Qwen model."""
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {"type": "image", "image": image},
-                {"type": "text", "text": prompt},
-            ],
-        }
-    ]
-    inputs = qwen_processor.apply_chat_template(
-        messages,
-        tokenize=True,
-        add_generation_prompt=True,
-        return_dict=True,
-        return_tensors="pt",
-    ).to(DEVICE)
-
-    with torch.inference_mode():
-        generated_ids = qwen_model.generate(
-            **inputs,
-            max_new_tokens=512,
-        )
-
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :]
-        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
-    output_text = qwen_processor.batch_decode(
-        generated_ids_trimmed,
-        skip_special_tokens=True,
-        clean_up_tokenization_spaces=False,
-    )[0]
-    return output_text
-
 
 @GPU
 def get_suggested_objects(image: Image.Image):
-    """Get suggested objects in the image using Qwen."""
+    """Get suggested objects in the image using Qwen"""
     if image is None:
         return []
+
     try:
-        # Resize image for faster suggestion generation
-        suggest_image = image.copy()
-        suggest_image.thumbnail((512, 512))
-
-        prompt = "List the main objects in the image in a Python list format. For example: ['cat', 'dog', 'table']"
-        result_text = run_qwen_inference(suggest_image, prompt)
-
-        # Clean up the output to find the list
-        match = re.search(r'\[.*?\]', result_text)
-        if match:
-            suggested_objects = ast.literal_eval(match.group())
-            if isinstance(suggested_objects, list):
-                # Return up to 3 suggestions
-                return suggested_objects[:3]
+        prompt = "List the objects in the image in python list format."
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": image},
+                    {"type": "text", "text": prompt},
+                ],
+            }
+        ]
+        inputs = qwen_processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(DEVICE)
+
+        with torch.inference_mode():
+            generated_ids = qwen_model.generate(
+                **inputs,
+                max_new_tokens=128,
+            )
+
+        generated_ids_trimmed = [
+            out_ids[len(in_ids) :]
+            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = qwen_processor.batch_decode(
+            generated_ids_trimmed,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False,
+        )[0]
+
+        suggested_objects = ast.literal_eval(output_text)
+        if isinstance(suggested_objects, list):
+            return suggested_objects[:3] if len(suggested_objects) > 3 else suggested_objects
         return []
     except Exception as e:
-        print(f"Error getting suggestions with Qwen: {e}")
+        print(f"Error getting suggestions: {e}")
         return []
 
 
 def annotate_image(image: Image.Image, result: dict):
-    """Annotates the image with points or bounding boxes based on model output."""
     if not isinstance(image, Image.Image) or not isinstance(result, dict):
         return image
 
     original_width, original_height = image.size
-    scene_np = np.array(image.copy())
 
     # Handle Point annotations
     if "points" in result and result["points"]:
-        points_list = []
-        for point in result.get("points", []):
-            x = int(point["x"] * original_width)
-            y = int(point["y"] * original_height)
-            points_list.append([x, y])
-
+        points_list = [
+            [int(p["x"] * original_width), int(p["y"] * original_height)]
+            for p in result.get("points", [])
+        ]
         if not points_list:
             return image
 
-        points_array = np.array(points_list).reshape(-1, 2)
+        points_array = np.array(points_list).reshape(1, -1, 2)
         key_points = sv.KeyPoints(xy=points_array)
         vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
-        annotated_image_np = vertex_annotator.annotate(
-            scene=scene_np, key_points=key_points
-        )
-        return Image.fromarray(annotated_image_np)
+        return vertex_annotator.annotate(scene=image.copy(), key_points=key_points)
 
     # Handle Detection annotations
     if "objects" in result and result["objects"]:
+        # Manually create detections from the Qwen output format
        boxes = []
         for obj in result["objects"]:
-            x_min = obj["x_min"] * original_width
-            y_min = obj["y_min"] * original_height
-            x_max = obj["x_max"] * original_width
-            y_max = obj["y_max"] * original_height
+            x_min = obj.get("x_min", 0.0) * original_width
+            y_min = obj.get("y_min", 0.0) * original_height
+            x_max = obj.get("x_max", 0.0) * original_width
+            y_max = obj.get("y_max", 0.0) * original_height
             boxes.append([x_min, y_min, x_max, y_max])
 
         if not boxes:
             return image
 
         detections = sv.Detections(xyxy=np.array(boxes))
-        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=4)
-        label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
-
-        annotated_image_np = box_annotator.annotate(
-            scene=scene_np, detections=detections
-        )
-        return Image.fromarray(annotated_image_np)
+
+        if len(detections) == 0:
+            return image
+
+        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=5)
+        return box_annotator.annotate(scene=image.copy(), detections=detections)
 
     return image
 
 
+# --- Inference Functions ---
+def run_qwen_inference(image: Image.Image, prompt: str):
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": prompt},
+            ],
+        }
+    ]
+    inputs = qwen_processor.apply_chat_template(
+        messages,
+        tokenize=True,
+        add_generation_prompt=True,
+        return_dict=True,
+        return_tensors="pt",
+    ).to(DEVICE)
+
+    with torch.inference_mode():
+        generated_ids = qwen_model.generate(
+            **inputs,
+            max_new_tokens=512,
+        )
+
+    generated_ids_trimmed = [
+        out_ids[len(in_ids) :]
+        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    return qwen_processor.batch_decode(
+        generated_ids_trimmed,
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=False,
+    )[0]
+
+
 @GPU
 def process_qwen(image: Image.Image, category: str, prompt: str):
     if category == "Query":
@@ -216,59 +229,56 @@ def process_qwen(image: Image.Image, category: str, prompt: str):
 
 # --- Gradio Interface Logic ---
 def on_category_and_image_change(image, category):
-    """Generate suggestions when category or image changes."""
+    """Generate suggestions when category changes"""
     text_box = gr.Textbox(value="", placeholder=PLACEHOLDERS.get(category, ""), interactive=True)
 
     if category == "Caption":
-        return gr.Radio(choices=["short", "normal", "long"], label="Caption Length", value="normal", visible=True), text_box
-
+        return gr.Radio(choices=["short", "normal", "long"], visible=True, label="Caption Length"), text_box
+
     if image is None or category not in ["Point", "Detect"]:
         return gr.Radio(choices=[], visible=False), text_box
 
     suggestions = get_suggested_objects(image)
     if suggestions:
-        return gr.Radio(choices=suggestions, label="Suggestions", visible=True, interactive=True), text_box
+        return gr.Radio(choices=suggestions, visible=True, interactive=True, label="Suggestions"), text_box
     else:
         return gr.Radio(choices=[], visible=False), text_box
 
 
 def update_prompt_from_radio(selected_object):
-    """Update prompt textbox when a radio option is selected."""
-    if selected_object:
-        return gr.Textbox(value=selected_object)
-    return gr.Textbox(value="")
+    """Update prompt textbox when a radio option is selected"""
+    return gr.Textbox(value=selected_object) if selected_object else gr.Textbox(value="")
 
 
 def process_inputs(image, category, prompt):
-    """Main function to handle the user's request."""
     if image is None:
         raise gr.Error("Please upload an image.")
-    if not prompt and category not in ["Caption"]:
-        # Caption can have an empty prompt if a length is selected
-        if category == "Caption" and not prompt:
-            prompt = "normal" # default
-        else:
-            raise gr.Error("Please provide a prompt or select a suggestion.")
+    if not prompt:
+        raise gr.Error("Please provide a prompt.")
 
-    # Resize the image to make inference quicker
-    image.thumbnail((1024, 1024))
+    image.thumbnail((512, 512))
 
-    # Process with Qwen
     qwen_text, qwen_data = process_qwen(image, category, prompt)
-    qwen_annotated_image = annotate_image(image, qwen_data)
+    qwen_annotated_image = annotate_image(image.copy(), qwen_data)
 
     return qwen_annotated_image, qwen_text
 
 
+css_hide_share = """
+button#gradio-share-link-button-0 {
+    display: none !important;
+}
+"""
+
 # --- Gradio UI Layout ---
-with gr.Blocks(theme=Ocean()) as demo:
+with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
     gr.Markdown("# 👓 Object Understanding with Qwen3-VL")
     gr.Markdown(
         "### Explore object detection, visual grounding, and keypoint detection through natural language prompts."
     )
-    gr.Markdown("""
-    *Powered by [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
-    """)
+    gr.Markdown(
+        "*Powered by [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).*"
+    )
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -293,7 +303,6 @@ with gr.Blocks(theme=Ocean()) as demo:
             submit_btn = gr.Button("Process Image", variant="primary")
 
         with gr.Column(scale=2):
-            gr.Markdown("### Qwen/Qwen3-VL-4B-Instruct Output")
             qwen_img_output = gr.Image(label="Annotated Image")
             qwen_text_output = gr.Textbox(
                 label="Text Output", lines=10, interactive=False
@@ -302,15 +311,14 @@ with gr.Blocks(theme=Ocean()) as demo:
     gr.Examples(
         examples=[
             ["examples/example_1.jpg", "Query", "How many cars are in the image?"],
-            ["examples/example_1.jpg", "Detect", "car"],
+            ["examples/example_1.jpg", "Caption", "short"],
             ["examples/example_2.JPG", "Point", "the person's face"],
-            ["examples/example_2.JPG", "Caption", "short"],
+            ["examples/example_2.JPG", "Detect", "the person"],
         ],
         inputs=[image_input, category_select, prompt_input],
     )
 
     # --- Event Listeners ---
-    # When image or category changes, update suggestions
     category_select.change(
         fn=on_category_and_image_change,
         inputs=[image_input, category_select],
@@ -322,14 +330,12 @@ with gr.Blocks(theme=Ocean()) as demo:
         outputs=[suggestions_radio, prompt_input],
     )
 
-    # When a suggestion is clicked, update the prompt box
     suggestions_radio.change(
         fn=update_prompt_from_radio,
         inputs=[suggestions_radio],
         outputs=[prompt_input],
     )
 
-    # Main submission action
     submit_btn.click(
         fn=process_inputs,
         inputs=[image_input, category_select, prompt_input],
@@ -337,4 +343,4 @@ with gr.Blocks(theme=Ocean()) as demo:
     )
 
 if __name__ == "__main__":
-    demo.launch(debug=True)
+    demo.launch()
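
For reference, a minimal standalone sketch (not part of the commit) that exercises the two supervision code paths the updated annotate_image relies on, using a blank canvas and hand-written normalized results in place of real Qwen3-VL output; it assumes supervision, numpy, and Pillow are installed:

import numpy as np
import supervision as sv
from PIL import Image

# Stand-in for an uploaded photo
image = Image.new("RGB", (640, 480), "white")
w, h = image.size

# "Point"-style result: normalized x/y in [0, 1], scaled to pixels and
# reshaped to (1, -1, 2) so all points form one keypoint set
points = [{"x": 0.25, "y": 0.5}, {"x": 0.75, "y": 0.5}]
xy = np.array([[int(p["x"] * w), int(p["y"] * h)] for p in points]).reshape(1, -1, 2)
pointed = sv.VertexAnnotator(radius=8, color=sv.Color.RED).annotate(
    scene=image.copy(), key_points=sv.KeyPoints(xy=xy)
)

# "Detect"-style result: normalized x_min/y_min/x_max/y_max scaled to pixel xyxy
obj = {"x_min": 0.1, "y_min": 0.2, "x_max": 0.6, "y_max": 0.9}
boxes = np.array([[obj["x_min"] * w, obj["y_min"] * h, obj["x_max"] * w, obj["y_max"] * h]])
boxed = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=5).annotate(
    scene=image.copy(), detections=sv.Detections(xyxy=boxes)
)

The reshape to (1, -1, 2), which replaces the earlier reshape(-1, 2), wraps the flat point list as a single object's keypoints, matching how sv.KeyPoints groups keypoints per object along the leading axis.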