import gradio as gr
from gradio.themes.ocean import Ocean
import torch
import numpy as np
import supervision as sv
from transformers import (
    Qwen3VLForConditionalGeneration,
    Qwen3VLProcessor,
)
import json
import ast
import re
from PIL import Image
from spaces import GPU

# --- Constants and Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = "auto"

CATEGORIES = ["Query", "Caption", "Point", "Detect"]
PLACEHOLDERS = {
    "Query": "What's in this image?",
    "Caption": "Enter caption length: short, normal, or long",
    "Point": "Select an object from suggestions or enter manually",
    "Detect": "Select an object from suggestions or enter manually",
}

# --- Model Loading ---
# Load Qwen3-VL
qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
    dtype=DTYPE,
    device_map=DEVICE,
).eval()
qwen_processor = Qwen3VLProcessor.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
)


# --- Utility Functions ---
def safe_parse_json(text: str):
    """Parse model output as JSON, falling back to a Python literal, else return {}."""
    text = text.strip()
    text = re.sub(r"^```(json)?", "", text)
    text = re.sub(r"```$", "", text)
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(text)
    except Exception:
        return {}


@GPU
def get_suggested_objects(image: Image.Image):
    """Get suggested objects in the image using Qwen"""
    if image is None:
        return []
    try:
        prompt = "List the objects in the image in python list format."
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        inputs = qwen_processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(DEVICE)
        with torch.inference_mode():
            generated_ids = qwen_model.generate(
                **inputs,
                max_new_tokens=128,
            )
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = qwen_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        suggested_objects = ast.literal_eval(output_text)
        if isinstance(suggested_objects, list):
            # Keep at most three suggestions.
            return suggested_objects[:3]
        return []
    except Exception as e:
        print(f"Error getting suggestions: {e}")
        return []


def annotate_image(image: Image.Image, result: dict):
    """Draw keypoints or bounding boxes from a parsed result dict onto the image."""
    if not isinstance(image, Image.Image) or not isinstance(result, dict):
        return image
    original_width, original_height = image.size

    # Handle Point annotations
    if "points" in result and result["points"]:
        points_list = [
            [int(p["x"] * original_width), int(p["y"] * original_height)]
            for p in result.get("points", [])
        ]
        if not points_list:
            return image
        points_array = np.array(points_list).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)
        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
        return vertex_annotator.annotate(scene=image.copy(), key_points=key_points)

    # Handle Detection annotations
    if "objects" in result and result["objects"]:
        # Manually create detections from the Qwen output format
        boxes = []
        for obj in result["objects"]:
            x_min = obj.get("x_min", 0.0) * original_width
            y_min = obj.get("y_min", 0.0) * original_height
            x_max = obj.get("x_max", 0.0) * original_width
            y_max = obj.get("y_max", 0.0) * original_height
            boxes.append([x_min, y_min, x_max, y_max])
        if not boxes:
            return image
        detections = sv.Detections(xyxy=np.array(boxes))
        if len(detections) == 0:
            return image
        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=5)
        return box_annotator.annotate(scene=image.copy(), detections=detections)

    return image
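
# Illustrative only: annotate_image expects result dicts with coordinates already
# normalized to the 0-1 range (the example values below are made up), e.g.:
#   {"points": [{"x": 0.42, "y": 0.17}]}
#   {"objects": [{"x_min": 0.10, "y_min": 0.25, "x_max": 0.55, "y_max": 0.90}]}
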
# --- Inference Functions ---
def run_qwen_inference(image: Image.Image, prompt: str):
    """Run a single image + text prompt through Qwen3-VL and return the decoded text."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = qwen_processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)
    with torch.inference_mode():
        generated_ids = qwen_model.generate(
            **inputs,
            max_new_tokens=512,
        )
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]


@GPU
def process_qwen(image: Image.Image, category: str, prompt: str):
    """Build the task-specific prompt, run inference, and parse the model output."""
    if category == "Query":
        return run_qwen_inference(image, prompt), {}
    elif category == "Caption":
        full_prompt = f"Provide a {prompt} length caption for the image."
        return run_qwen_inference(image, full_prompt), {}
    elif category == "Point":
        full_prompt = (
            f"Provide 2d point coordinates for {prompt}. Report in JSON format."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        points_result = {"points": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "point_2d" in item and len(item["point_2d"]) == 2:
                    x, y = item["point_2d"]
                    points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
        return json.dumps(points_result, indent=2), points_result
    elif category == "Detect":
        full_prompt = (
            f"Provide bounding box coordinates for {prompt}. Report in JSON format."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        objects_result = {"objects": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
                    xmin, ymin, xmax, ymax = item["bbox_2d"]
                    objects_result["objects"].append(
                        {
                            "x_min": xmin / 1000.0,
                            "y_min": ymin / 1000.0,
                            "x_max": xmax / 1000.0,
                            "y_max": ymax / 1000.0,
                        }
                    )
        return json.dumps(objects_result, indent=2), objects_result
    return "Invalid category", {}
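
# Illustrative sketch (values made up) of the raw JSON that process_qwen expects
# from the model; the parser assumes coordinates on a 0-1000 scale, hence the
# division by 1000.0. Any extra keys in each item are ignored:
#   Detect: [{"bbox_2d": [100, 250, 550, 900]}]
#   Point:  [{"point_2d": [420, 170]}]
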
# --- Gradio Interface Logic ---
def on_category_and_image_change(image, category):
    """Refresh suggestions and reset the prompt when the image or category changes."""
    text_box = gr.Textbox(
        value="", placeholder=PLACEHOLDERS.get(category, ""), interactive=True
    )
    if category == "Caption":
        return gr.Radio(
            choices=["short", "normal", "long"], visible=True, label="Caption Length"
        ), text_box
    if image is None or category not in ["Point", "Detect"]:
        return gr.Radio(choices=[], visible=False), text_box
    suggestions = get_suggested_objects(image)
    if suggestions:
        return gr.Radio(
            choices=suggestions, visible=True, interactive=True, label="Suggestions"
        ), text_box
    return gr.Radio(choices=[], visible=False), text_box


def update_prompt_from_radio(selected_object):
    """Update prompt textbox when a radio option is selected"""
    return gr.Textbox(value=selected_object) if selected_object else gr.Textbox(value="")


def process_inputs(image, category, prompt):
    """Validate inputs, run the selected task, and return the annotated image and text."""
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt:
        raise gr.Error("Please provide a prompt.")
    image.thumbnail((512, 512))
    qwen_text, qwen_data = process_qwen(image, category, prompt)
    qwen_annotated_image = annotate_image(image.copy(), qwen_data)
    return qwen_annotated_image, qwen_text


css_hide_share = """
button#gradio-share-link-button-0 {
    display: none !important;
}
"""

# --- Gradio UI Layout ---
with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
    gr.Markdown("# 👓 Object Understanding with Qwen3-VL")
    gr.Markdown(
        "### Explore object detection, visual grounding, and keypoint detection through natural language prompts."
    )
    gr.Markdown(
        "*Powered by [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).*"
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input Image")
            category_select = gr.Radio(
                choices=CATEGORIES,
                value=CATEGORIES[0],
                label="Select Task Category",
                interactive=True,
            )
            suggestions_radio = gr.Radio(
                choices=[],
                label="Suggestions",
                visible=False,
                interactive=True,
            )
            prompt_input = gr.Textbox(
                placeholder=PLACEHOLDERS[CATEGORIES[0]],
                label="Prompt",
                lines=2,
            )
            submit_btn = gr.Button("Process Image", variant="primary")
        with gr.Column(scale=2):
            qwen_img_output = gr.Image(label="Annotated Image")
            qwen_text_output = gr.Textbox(
                label="Text Output", lines=10, interactive=False
            )
    gr.Examples(
        examples=[
            ["examples/example_1.jpg", "Query", "How many cars are in the image?"],
            ["examples/example_1.jpg", "Caption", "short"],
            ["examples/example_2.JPG", "Point", "the person's face"],
            ["examples/example_2.JPG", "Detect", "the person"],
        ],
        inputs=[image_input, category_select, prompt_input],
    )

    # --- Event Listeners ---
    category_select.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )
    image_input.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )
    suggestions_radio.change(
        fn=update_prompt_from_radio,
        inputs=[suggestions_radio],
        outputs=[prompt_input],
    )
    submit_btn.click(
        fn=process_inputs,
        inputs=[image_input, category_select, prompt_input],
        outputs=[qwen_img_output, qwen_text_output],
    )


if __name__ == "__main__":
    demo.launch()