STAMP: Simultaneous Textual All-Mask Prediction
STAMP is a Multimodal Large Language Model (MLLM) that performs dialogue and segmentation simultaneously. It resolves the conflict between text generation and mask prediction, achieving both strong segmentation performance and fast inference.
Quick Start
Note: This model relies on the codebase and custom architecture defined in the GitHub repository. You must clone the repository to run inference.
1. Installation
Clone the repository and install the required dependencies:
git clone https://github.com/HKUST-LongGroup/STAMP.git
cd STAMP
# Create environment (Recommended)
conda create -n STAMP python=3.10
conda activate STAMP
# Install dependencies
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
# Download the SAM ViT-H checkpoint to YOUR_SAM_PATH
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
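After installation, a quick sanity check can confirm the GPU stack is importable before loading any weights. A minimal sketch (assumes a CUDA build of PyTorch; nothing here is specific to STAMP):
import torch
# STAMP inference needs a CUDA GPU (for the MLLM, flash-attn, and SAM)
assert torch.cuda.is_available(), "No CUDA GPU detected"
print(f"PyTorch {torch.__version__}, CUDA {torch.version.cuda}")
# flash-attn should import cleanly after the --no-build-isolation install
import flash_attn
print(f"flash-attn {flash_attn.__version__}")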
2. Run Inference
- Option A: Run via Command Line.
# Make sure you are in the STAMP directory
python run_seg_ref.py \
    --model-path "JiaZL/STAMP-2B-uni" \
    --image-file "images/horses.png" \
    --sam_path "HCMUE-Research/SAM-vit-h/sam_vit_h_4b8939.pth" \
    --query "Please segment the white horse in the image."
- Option B: Run Code Directly.
import torch
import torch.nn.functional as F
import numpy as np
import cv2
import os
from PIL import Image
# Import local modules
from segment_predictor_cache import GenerativeSegmenter
from model.segment_anything import sam_model_registry, SamPredictor
# [New] Import utility functions for SAM prompt generation
from eval.utils import compute_logits_from_mask, masks_sample_points
# --- Configuration ---
# Model paths
MODEL_PATH = "JiaZL/STAMP-2B-uni"
SAM_PATH = "HCMUE-Research/SAM-vit-h/sam_vit_h_4b8939.pth"
IMAGE_PATH = "images/horses.png"
QUERY = "Please segment the white horse in the image."
USE_SAM = True # Enable SAM refinement (Recommended: True)
# --- Load Models ---
print(f"Loading STAMP model from {MODEL_PATH}...")
segmenter = GenerativeSegmenter(
    MODEL_PATH,
    device_map="cuda",
    min_pixels=1024 * 28 * 28,
    max_pixels=1280 * 28 * 28
)
print(f"Loading SAM model from {SAM_PATH}...")
sam = sam_model_registry["vit_h"](checkpoint=SAM_PATH)
sam = sam.to(dtype=torch.float32, device='cuda')
predictor = SamPredictor(sam)
# --- Inference ---
image = Image.open(IMAGE_PATH).convert("RGB")
w_ori, h_ori = image.size
with torch.inference_mode():
    # 1. Set the SAM image embedding (computed once for efficiency)
    if USE_SAM:
        predictor.set_image(np.array(image))
    # 2. Generate a coarse mask using STAMP
    print("Generating coarse mask with STAMP...")
    segmentation_masks, response_text = segmenter.generate_with_segmentation(
        image, QUERY
    )
    print(f"Model Response: {response_text}")
    if not segmentation_masks:
        print("No mask generated.")
        exit()
    # Extract the first mask
    mask = segmentation_masks[0]
    # Resize the coarse mask to the original image size [H, W]
    mask_pred = F.interpolate(
        mask.unsqueeze(0).unsqueeze(0).double(),
        size=(h_ori, w_ori),
        mode='nearest'
    ).squeeze(0).squeeze(0)
    # --- SAM Refinement ---
    final_mask = np.zeros((h_ori, w_ori), dtype=np.float32)
    if USE_SAM:
        print("Refining mask with SAM...")
        # Get all unique class IDs (excluding background 0)
        unique_classes = torch.unique(mask_pred)
        for class_id in unique_classes:
            if class_id == 0:
                continue
            # Get the binary mask for the current class
            binary_mask = (mask_pred == class_id).double().cpu()
            try:
                # Generate prompts (logits and points) from the coarse mask
                logits = compute_logits_from_mask(binary_mask)
                point_coords, point_labels = masks_sample_points(binary_mask)
                # First-pass prediction
                sam_mask, _, logit = predictor.predict(
                    point_coords=point_coords,
                    point_labels=point_labels,
                    mask_input=logits,
                    multimask_output=False
                )
                # Iterative refinement (standard cascade: two passes)
                for _ in range(2):
                    sam_mask, _, logit = predictor.predict(
                        point_coords=point_coords,
                        point_labels=point_labels,
                        mask_input=logit,
                        multimask_output=False
                    )
                # Merge the result into the final mask
                current_refined_mask = sam_mask[0].astype(np.float32)
                final_mask = np.maximum(final_mask, current_refined_mask)
            except Exception as e:
                print(f"SAM Error for class {class_id}: {e}")
                # Fall back to the coarse mask if SAM fails
                final_mask = np.maximum(final_mask, binary_mask.numpy())
    else:
        # Use the coarse mask directly if SAM is disabled
        final_mask = mask_pred.cpu().numpy()
# --- Save Result ---
# Convert to 0-255 uint8 format for saving
mask_uint8 = (final_mask > 0).astype(np.uint8) * 255
base_name = os.path.basename(IMAGE_PATH).split(".")[0]
save_name = f"{base_name}_mask_refined.png"
cv2.imwrite(save_name, mask_uint8)
print(f"Saved refined mask to {save_name}")
Citation
If you find this work useful, please cite our paper:
@misc{liu2025betterstrongerfastertackling,
  title={Better, Stronger, Faster: Tackling the Trilemma in MLLM-based Segmentation with Simultaneous Textual Mask Prediction},
  author={Jiazhen Liu and Mingkuan Feng and Long Chen},
  year={2025},
  eprint={2512.00395},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2512.00395},
}