import gradio as gr import torch import numpy as np import supervision as sv from typing import Iterable from transformers import ( Qwen3VLForConditionalGeneration, Qwen3VLProcessor, ) from gradio.themes import Soft from gradio.themes.utils import colors, fonts, sizes import json import ast import re from PIL import Image from spaces import GPU colors.steel_blue = colors.Color( name="steel_blue", c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1", c300="#7DB3D2", c400="#529AC3", c500="#4682B4", c600="#3E72A0", c700="#36638C", c800="#2E5378", c900="#264364", c950="#1E3450", ) class SteelBlueTheme(Soft): def __init__( self, *, primary_hue: colors.Color | str = colors.gray, secondary_hue: colors.Color | str = colors.steel_blue, neutral_hue: colors.Color | str = colors.slate, text_size: sizes.Size | str = sizes.text_lg, font: fonts.Font | str | Iterable[fonts.Font | str] = ( fonts.GoogleFont("Outfit"), "Arial", "sans-serif", ), font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace", ), ): super().__init__( primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue, text_size=text_size, font=font, font_mono=font_mono, ) super().set( background_fill_primary="*primary_50", background_fill_primary_dark="*primary_900", body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)", body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)", button_primary_text_color="white", button_primary_text_color_hover="white", button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)", button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)", button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)", button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)", button_secondary_text_color="black", button_secondary_text_color_hover="white", button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)", button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)", button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)", button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)", slider_color="*secondary_500", slider_color_dark="*secondary_600", block_title_text_weight="600", block_border_width="3px", block_shadow="*shadow_drop_lg", button_primary_shadow="*shadow_drop_lg", button_large_padding="11px", color_accent_soft="*primary_100", block_label_background_fill="*primary_200", ) steel_blue_theme = SteelBlueTheme() DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DTYPE = "auto" CATEGORIES = ["Query", "Caption", "Point", "Detect"] qwen_model = Qwen3VLForConditionalGeneration.from_pretrained( "Qwen/Qwen3-VL-4B-Instruct", dtype=DTYPE, device_map=DEVICE, ).eval() qwen_processor = Qwen3VLProcessor.from_pretrained( "Qwen/Qwen3-VL-4B-Instruct", ) def safe_parse_json(text: str): text = text.strip() text = re.sub(r"^```(json)?", "", text) text = re.sub(r"```$", "", text) text = text.strip() try: return json.loads(text) except json.JSONDecodeError: pass try: return ast.literal_eval(text) except Exception: return {} def annotate_image(image: Image.Image, result: dict): if not isinstance(image, Image.Image) or not isinstance(result, dict): return image # Ensure image is mutable image = image.convert("RGB") original_width, original_height = image.size if "points" in result and result["points"]: points_list = [ [int(p["x"] * original_width), int(p["y"] * original_height)] for p in result.get("points", []) ] if not points_list: return image points_array = np.array(points_list).reshape(1, -1, 2) key_points = sv.KeyPoints(xy=points_array) vertex_annotator = sv.VertexAnnotator(radius=4, color=sv.Color.RED) annotated_image = vertex_annotator.annotate(scene=np.array(image.copy()), key_points=key_points) return Image.fromarray(annotated_image) if "objects" in result and result["objects"]: boxes = [] for obj in result["objects"]: x_min = obj.get("x_min", 0.0) * original_width y_min = obj.get("y_min", 0.0) * original_height x_max = obj.get("x_max", 0.0) * original_width y_max = obj.get("y_max", 0.0) * original_height boxes.append([x_min, y_min, x_max, y_max]) if not boxes: return image detections = sv.Detections(xyxy=np.array(boxes)) if len(detections) == 0: return image box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=2) annotated_image = box_annotator.annotate(scene=np.array(image.copy()), detections=detections) return Image.fromarray(annotated_image) return image def run_qwen_inference(image: Image.Image, prompt: str): messages = [ { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": prompt}, ], } ] inputs = qwen_processor.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", ).to(DEVICE) with torch.inference_mode(): generated_ids = qwen_model.generate( **inputs, max_new_tokens=512, ) generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] return qwen_processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False, )[0] @GPU def process_qwen(image: Image.Image, category: str, prompt: str): if category == "Query": return run_qwen_inference(image, prompt), {} elif category == "Caption": full_prompt = f"Provide a {prompt} length caption for the image." return run_qwen_inference(image, full_prompt), {} elif category == "Point": full_prompt = ( f"Provide 2d point coordinates for {prompt}. Report in JSON format." ) output_text = run_qwen_inference(image, full_prompt) parsed_json = safe_parse_json(output_text) points_result = {"points": []} if isinstance(parsed_json, list): for item in parsed_json: if "point_2d" in item and len(item["point_2d"]) == 2: x, y = item["point_2d"] points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0}) return json.dumps(points_result, indent=2), points_result elif category == "Detect": full_prompt = ( f"Provide bounding box coordinates for {prompt}. Report in JSON format." ) output_text = run_qwen_inference(image, full_prompt) parsed_json = safe_parse_json(output_text) objects_result = {"objects": []} if isinstance(parsed_json, list): for item in parsed_json: if "bbox_2d" in item and len(item["bbox_2d"]) == 4: xmin, ymin, xmax, ymax = item["bbox_2d"] objects_result["objects"].append( { "x_min": xmin / 1000.0, "y_min": ymin / 1000.0, "x_max": xmax / 1000.0, "y_max": ymax / 1000.0, } ) return json.dumps(objects_result, indent=2), objects_result return "Invalid category", {} def process_inputs(image, category, prompt): if image is None: raise gr.Error("Please upload an image.") if not prompt: raise gr.Error("Please provide a prompt.") image.thumbnail((512, 512)) qwen_text, qwen_data = process_qwen(image, category, prompt) qwen_annotated_image = annotate_image(image.copy(), qwen_data) return qwen_annotated_image, qwen_text def on_category_change(category: str): if category == "Query": return gr.Textbox(placeholder="e.g., Count the total number of boats and describe the environment.") elif category == "Caption": return gr.Textbox(placeholder="e.g., short, normal, detailed") elif category == "Point": return gr.Textbox(placeholder="e.g., The gun held by the person.") elif category == "Detect": return gr.Textbox(placeholder="e.g., The headlight of the car.") return gr.Textbox(placeholder="e.g., detect the object.") css = """ #main-title h1 { font-size: 2.3em !important; } #output-title h2 { font-size: 2.1em !important; } """ with gr.Blocks() as demo: with gr.Column(elem_id="col-container"): gr.Markdown("# **Qwen-3VL: Multimodal Understanding**", elem_id="main-title") with gr.Row(): with gr.Column(scale=1): image_input = gr.Image(type="pil", label="Upload Image") category_select = gr.Radio( choices=CATEGORIES, value="Query", label="Select Task Category", interactive=True, ) prompt_input = gr.Textbox( placeholder="e.g., Count the total number of boats and describe the environment.", label="Prompt", lines=1, ) submit_btn = gr.Button("Process Image", variant="primary") with gr.Column(scale=2): qwen_img_output = gr.Image(label="Output Image") qwen_text_output = gr.Textbox( label="Text Output", lines=10, interactive=True) gr.Examples( examples=[ ["examples/5.jpg", "Point", "Detect the children who are out of focus and wearing a white T-shirt."], ["examples/5.jpg", "Detect", "Point out the out-of-focus (all) children."], ["examples/4.jpg", "Detect", "Headlight"], ["examples/3.jpg", "Point", "Gun"], ["examples/1.jpg", "Query", "Count the total number of boats and describe the environment."], ["examples/2.jpg", "Caption", "a brief"], ], inputs=[image_input, category_select, prompt_input], ) category_select.change( fn=on_category_change, inputs=[category_select], outputs=[prompt_input], ) submit_btn.click( fn=process_inputs, inputs=[image_input, category_select, prompt_input], outputs=[qwen_img_output, qwen_text_output], ) if __name__ == "__main__": demo.launch(css=css, theme=steel_blue_theme, mcp_server=True, ssr_mode=False, show_error=True)