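"""Gradio Space for Qwen3-VL (Qwen/Qwen3-VL-4B-Instruct): free-form visual
query, captioning, 2D pointing, and object detection on a single image."""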
import gradio as gr
import torch
import numpy as np
import supervision as sv
from typing import Iterable
from transformers import (
Qwen3VLForConditionalGeneration,
Qwen3VLProcessor,
)
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
import json
import ast
import re
from PIL import Image
from spaces import GPU
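# Register a custom steel-blue palette for the Soft-based theme below.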
colors.steel_blue = colors.Color(
name="steel_blue",
c50="#EBF3F8",
c100="#D3E5F0",
c200="#A8CCE1",
c300="#7DB3D2",
c400="#529AC3",
c500="#4682B4",
c600="#3E72A0",
c700="#36638C",
c800="#2E5378",
c900="#264364",
c950="#1E3450",
)
class SteelBlueTheme(Soft):
def __init__(
self,
*,
primary_hue: colors.Color | str = colors.gray,
secondary_hue: colors.Color | str = colors.steel_blue,
neutral_hue: colors.Color | str = colors.slate,
text_size: sizes.Size | str = sizes.text_lg,
font: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
),
font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
text_size=text_size,
font=font,
font_mono=font_mono,
)
super().set(
background_fill_primary="*primary_50",
background_fill_primary_dark="*primary_900",
body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
button_primary_text_color="white",
button_primary_text_color_hover="white",
button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
button_secondary_text_color="black",
button_secondary_text_color_hover="white",
button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
slider_color="*secondary_500",
slider_color_dark="*secondary_600",
block_title_text_weight="600",
block_border_width="3px",
block_shadow="*shadow_drop_lg",
button_primary_shadow="*shadow_drop_lg",
button_large_padding="11px",
color_accent_soft="*primary_100",
block_label_background_fill="*primary_200",
)
steel_blue_theme = SteelBlueTheme()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = "auto"
CATEGORIES = ["Query", "Caption", "Point", "Detect"]
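# Load the model and processor once at startup; dtype="auto" lets transformers
# choose an appropriate precision for the detected device.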
qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-VL-4B-Instruct",
dtype=DTYPE,
device_map=DEVICE,
).eval()
qwen_processor = Qwen3VLProcessor.from_pretrained(
"Qwen/Qwen3-VL-4B-Instruct",
)
def safe_parse_json(text: str):
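    """Best-effort parser for model output: strip Markdown code fences, try
    strict JSON, then fall back to a Python literal; return {} on failure."""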
text = text.strip()
text = re.sub(r"^```(json)?", "", text)
text = re.sub(r"```$", "", text)
text = text.strip()
try:
return json.loads(text)
except json.JSONDecodeError:
pass
try:
return ast.literal_eval(text)
except Exception:
return {}
def annotate_image(image: Image.Image, result: dict):
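    """Draw model results onto the image: red keypoints for "points" payloads,
    color-indexed boxes for "objects" payloads. Coordinates are expected as
    0-1 fractions of the image width and height."""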
if not isinstance(image, Image.Image) or not isinstance(result, dict):
return image
    # Convert to RGB so annotators receive a standard 3-channel array.
    image = image.convert("RGB")
original_width, original_height = image.size
if "points" in result and result["points"]:
points_list = [
[int(p["x"] * original_width), int(p["y"] * original_height)]
for p in result.get("points", [])
]
if not points_list:
return image
        # supervision expects keypoints shaped (num_detections, num_keypoints, 2).
        points_array = np.array(points_list).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)
        vertex_annotator = sv.VertexAnnotator(radius=4, color=sv.Color.RED)
        annotated_image = vertex_annotator.annotate(scene=np.array(image), key_points=key_points)
return Image.fromarray(annotated_image)
if "objects" in result and result["objects"]:
boxes = []
for obj in result["objects"]:
x_min = obj.get("x_min", 0.0) * original_width
y_min = obj.get("y_min", 0.0) * original_height
x_max = obj.get("x_max", 0.0) * original_width
y_max = obj.get("y_max", 0.0) * original_height
boxes.append([x_min, y_min, x_max, y_max])
if not boxes:
return image
        detections = sv.Detections(xyxy=np.array(boxes))
        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=2)
        annotated_image = box_annotator.annotate(scene=np.array(image), detections=detections)
return Image.fromarray(annotated_image)
return image
def run_qwen_inference(image: Image.Image, prompt: str):
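    """Run one image+text chat turn through Qwen3-VL and return the decoded reply."""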
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
}
]
inputs = qwen_processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
).to(DEVICE)
with torch.inference_mode():
generated_ids = qwen_model.generate(
**inputs,
max_new_tokens=512,
)
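    # Trim each prompt prefix so only newly generated tokens are decoded.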
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
return qwen_processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0]
@GPU
def process_qwen(image: Image.Image, category: str, prompt: str):
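    """Route a request by task category. Query/Caption return free text;
    Point/Detect prompt the model for JSON and rescale its coordinates
    (a 0-1000 grid, per the /1000 scaling below) to 0-1 fractions."""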
if category == "Query":
return run_qwen_inference(image, prompt), {}
elif category == "Caption":
full_prompt = f"Provide a {prompt} length caption for the image."
return run_qwen_inference(image, full_prompt), {}
elif category == "Point":
full_prompt = (
f"Provide 2d point coordinates for {prompt}. Report in JSON format."
)
output_text = run_qwen_inference(image, full_prompt)
parsed_json = safe_parse_json(output_text)
points_result = {"points": []}
        # Rescale the model's 0-1000 grid points to 0-1 fractions for annotation.
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if isinstance(item, dict) and "point_2d" in item and len(item["point_2d"]) == 2:
                    x, y = item["point_2d"]
                    points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
return json.dumps(points_result, indent=2), points_result
elif category == "Detect":
full_prompt = (
f"Provide bounding box coordinates for {prompt}. Report in JSON format."
)
output_text = run_qwen_inference(image, full_prompt)
parsed_json = safe_parse_json(output_text)
objects_result = {"objects": []}
        # Boxes arrive as [x_min, y_min, x_max, y_max] on the same 0-1000 grid.
        if isinstance(parsed_json, list):
            for item in parsed_json:
                if isinstance(item, dict) and "bbox_2d" in item and len(item["bbox_2d"]) == 4:
xmin, ymin, xmax, ymax = item["bbox_2d"]
objects_result["objects"].append(
{
"x_min": xmin / 1000.0,
"y_min": ymin / 1000.0,
"x_max": xmax / 1000.0,
"y_max": ymax / 1000.0,
}
)
return json.dumps(objects_result, indent=2), objects_result
return "Invalid category", {}
def process_inputs(image, category, prompt):
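    """Validate inputs, downscale the image, run inference, and annotate the result."""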
if image is None:
raise gr.Error("Please upload an image.")
if not prompt:
raise gr.Error("Please provide a prompt.")
image.thumbnail((512, 512))
qwen_text, qwen_data = process_qwen(image, category, prompt)
qwen_annotated_image = annotate_image(image.copy(), qwen_data)
return qwen_annotated_image, qwen_text
def on_category_change(category: str):
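    """Swap the prompt placeholder for a task-appropriate example."""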
if category == "Query":
return gr.Textbox(placeholder="e.g., Count the total number of boats and describe the environment.")
elif category == "Caption":
return gr.Textbox(placeholder="e.g., short, normal, detailed")
elif category == "Point":
return gr.Textbox(placeholder="e.g., The gun held by the person.")
elif category == "Detect":
return gr.Textbox(placeholder="e.g., The headlight of the car.")
return gr.Textbox(placeholder="e.g., detect the object.")
css = """
#main-title h1 {
font-size: 2.3em !important;
}
#output-title h2 {
font-size: 2.1em !important;
}
"""
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("# **Qwen-3VL: Multimodal Understanding**", elem_id="main-title")
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(type="pil", label="Upload Image")
category_select = gr.Radio(
choices=CATEGORIES,
value="Query",
label="Select Task Category",
interactive=True,
)
prompt_input = gr.Textbox(
placeholder="e.g., Count the total number of boats and describe the environment.",
label="Prompt",
lines=1,
)
submit_btn = gr.Button("Process Image", variant="primary")
with gr.Column(scale=2):
qwen_img_output = gr.Image(label="Output Image")
                qwen_text_output = gr.Textbox(
                    label="Text Output", lines=10, interactive=True,
                )
gr.Examples(
examples=[
["examples/5.jpg", "Point", "Detect the children who are out of focus and wearing a white T-shirt."],
["examples/5.jpg", "Detect", "Point out the out-of-focus (all) children."],
["examples/4.jpg", "Detect", "Headlight"],
["examples/3.jpg", "Point", "Gun"],
["examples/1.jpg", "Query", "Count the total number of boats and describe the environment."],
["examples/2.jpg", "Caption", "a brief"],
],
inputs=[image_input, category_select, prompt_input],
)
category_select.change(
fn=on_category_change,
inputs=[category_select],
outputs=[prompt_input],
)
submit_btn.click(
fn=process_inputs,
inputs=[image_input, category_select, prompt_input],
outputs=[qwen_img_output, qwen_text_output],
)
if __name__ == "__main__":
    demo.launch(mcp_server=True, ssr_mode=False, show_error=True)