import gradio as gr
from gradio.themes.ocean import Ocean
import torch
import numpy as np
import supervision as sv
from transformers import (
    Qwen3VLForConditionalGeneration,
    Qwen3VLProcessor,
)
import json
import ast
import re
from PIL import Image
from spaces import GPU

# --- Constants and Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = "auto"

CATEGORIES = ["Query", "Caption", "Point", "Detect"]
PLACEHOLDERS = {
    "Query": "What's in this image?",
    "Caption": "Enter caption length: short, normal, or long",
    "Point": "Select an object from suggestions or enter manually",
    "Detect": "Select an object from suggestions or enter manually",
}

# --- Model Loading ---
# Load Qwen3-VL
qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
    dtype=DTYPE,
    device_map=DEVICE,
).eval()
qwen_processor = Qwen3VLProcessor.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
)

# --- Utility Functions ---
def safe_parse_json(text: str):
    """Parse model output into a Python object, stripping any ```json fence."""
    text = text.strip()
    text = re.sub(r"^```(json)?", "", text)
    text = re.sub(r"```$", "", text)
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(text)
    except Exception:
        return {}

# ZeroGPU: request a GPU for the duration of this call (the `GPU` import above
# is otherwise unused).
@GPU
def get_suggested_objects(image: Image.Image):
    """Get suggested objects in the image using Qwen."""
    if image is None:
        return []
    try:
        prompt = "List the objects in the image in python list format."
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        inputs = qwen_processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(DEVICE)
        with torch.inference_mode():
            generated_ids = qwen_model.generate(
                **inputs,
                max_new_tokens=128,
            )
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = qwen_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        suggested_objects = ast.literal_eval(output_text)
        if isinstance(suggested_objects, list):
            return suggested_objects[:3]
        return []
    except Exception as e:
        print(f"Error getting suggestions: {e}")
        return []

def annotate_image(image: Image.Image, result: dict):
    if not isinstance(image, Image.Image) or not isinstance(result, dict):
        return image
    original_width, original_height = image.size

    # Handle Point annotations
    if "points" in result and result["points"]:
        points_list = [
            [int(p["x"] * original_width), int(p["y"] * original_height)]
            for p in result.get("points", [])
        ]
        if not points_list:
            return image
        points_array = np.array(points_list).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)
        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
        return vertex_annotator.annotate(scene=image.copy(), key_points=key_points)

    # Handle Detection annotations
    if "objects" in result and result["objects"]:
        # Manually create detections from the Qwen output format
        boxes = []
        for obj in result["objects"]:
            x_min = obj.get("x_min", 0.0) * original_width
            y_min = obj.get("y_min", 0.0) * original_height
            x_max = obj.get("x_max", 0.0) * original_width
            y_max = obj.get("y_max", 0.0) * original_height
            boxes.append([x_min, y_min, x_max, y_max])
        if not boxes:
            return image
        detections = sv.Detections(xyxy=np.array(boxes))
        if len(detections) == 0:
            return image
        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=5)
        return box_annotator.annotate(scene=image.copy(), detections=detections)

    return image

# --- Inference Functions ---
# ZeroGPU: allocate a GPU while generation runs.
@GPU
def run_qwen_inference(image: Image.Image, prompt: str):
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    inputs = qwen_processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)
    with torch.inference_mode():
        generated_ids = qwen_model.generate(
            **inputs,
            max_new_tokens=512,
        )
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

def process_qwen(image: Image.Image, category: str, prompt: str):
    if category == "Query":
        return run_qwen_inference(image, prompt), {}
    elif category == "Caption":
        full_prompt = f"Provide a {prompt} length caption for the image."
        return run_qwen_inference(image, full_prompt), {}
    elif category == "Point":
        full_prompt = (
            f"Provide 2d point coordinates for {prompt}. Report in JSON format."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        points_result = {"points": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "point_2d" in item and len(item["point_2d"]) == 2:
                    x, y = item["point_2d"]
                    # Rescale the model's 0-1000 coordinates to [0, 1] fractions.
                    points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
        return json.dumps(points_result, indent=2), points_result
    elif category == "Detect":
        full_prompt = (
            f"Provide bounding box coordinates for {prompt}. Report in JSON format."
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
        objects_result = {"objects": []}
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
                    xmin, ymin, xmax, ymax = item["bbox_2d"]
                    objects_result["objects"].append(
                        {
                            "x_min": xmin / 1000.0,
                            "y_min": ymin / 1000.0,
                            "x_max": xmax / 1000.0,
                            "y_max": ymax / 1000.0,
                        }
                    )
        return json.dumps(objects_result, indent=2), objects_result
    return "Invalid category", {}

# --- Gradio Interface Logic ---
def on_category_and_image_change(image, category):
    """Refresh the suggestions radio and prompt when the category or image changes."""
    text_box = gr.Textbox(value="", placeholder=PLACEHOLDERS.get(category, ""), interactive=True)
    if category == "Caption":
        return gr.Radio(choices=["short", "normal", "long"], visible=True, label="Caption Length"), text_box
    if image is None or category not in ["Point", "Detect"]:
        return gr.Radio(choices=[], visible=False), text_box
    suggestions = get_suggested_objects(image)
    if suggestions:
        return gr.Radio(choices=suggestions, visible=True, interactive=True, label="Suggestions"), text_box
    else:
        return gr.Radio(choices=[], visible=False), text_box


def update_prompt_from_radio(selected_object):
    """Update the prompt textbox when a suggestion is selected."""
    return gr.Textbox(value=selected_object) if selected_object else gr.Textbox(value="")


def process_inputs(image, category, prompt):
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt:
        raise gr.Error("Please provide a prompt.")
    # Downscale in place to at most 512x512 before inference.
    image.thumbnail((512, 512))
    qwen_text, qwen_data = process_qwen(image, category, prompt)
    qwen_annotated_image = annotate_image(image.copy(), qwen_data)
    return qwen_annotated_image, qwen_text

css_hide_share = """
button#gradio-share-link-button-0 {
    display: none !important;
}
"""

# --- Gradio UI Layout ---
with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
    gr.Markdown("# 👓 Object Understanding with Qwen3-VL")
    gr.Markdown(
        "### Explore object detection, visual grounding, and keypoint detection through natural language prompts."
    )
    gr.Markdown(
        "*Powered by [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct).*"
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input Image")
            category_select = gr.Radio(
                choices=CATEGORIES,
                value=CATEGORIES[0],
                label="Select Task Category",
                interactive=True,
            )
            suggestions_radio = gr.Radio(
                choices=[],
                label="Suggestions",
                visible=False,
                interactive=True,
            )
            prompt_input = gr.Textbox(
                placeholder=PLACEHOLDERS[CATEGORIES[0]],
                label="Prompt",
                lines=2,
            )
            submit_btn = gr.Button("Process Image", variant="primary")
        with gr.Column(scale=2):
            qwen_img_output = gr.Image(label="Annotated Image")
            qwen_text_output = gr.Textbox(
                label="Text Output", lines=10, interactive=False
            )

    gr.Examples(
        examples=[
            ["examples/example_1.jpg", "Query", "How many cars are in the image?"],
            ["examples/example_1.jpg", "Caption", "short"],
            ["examples/example_2.JPG", "Point", "the person's face"],
            ["examples/example_2.JPG", "Detect", "the person"],
        ],
        inputs=[image_input, category_select, prompt_input],
    )

    # --- Event Listeners ---
    category_select.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )
    image_input.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )
    suggestions_radio.change(
        fn=update_prompt_from_radio,
        inputs=[suggestions_radio],
        outputs=[prompt_input],
    )
    submit_btn.click(
        fn=process_inputs,
        inputs=[image_input, category_select, prompt_input],
        outputs=[qwen_img_output, qwen_text_output],
    )

if __name__ == "__main__":
    demo.launch()
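
# A minimal sketch of how to run this demo outside of Spaces (assumptions: the
# file is saved as app.py and the packages below match the imports above; exact
# versions are not specified in the source). The `spaces` @GPU decorator is
# designed to be a no-op outside ZeroGPU, so it is harmless locally:
#
#   pip install gradio spaces torch transformers supervision pillow numpy
#   python app.py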