import importlib
import itertools
import os
import pickle
import re
from math import sqrt

import clip
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import yaml
from einops import rearrange
from transformers import AutoTokenizer, BertModel

from .hooks import average_text_tokens, feats, get_vit_out
from .masker import DINOTextMasker
from .model import CLIPLastLayer, DoubleMLP, ProjectionLayer, VisualProjectionLayer
from .pamr import PAMR
from .templates import get_template
from .us import normalize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DINOText(nn.Module):

    def get_self_attention(self, module, input, output):
        # Forward hook: cache the raw qkv projection of the last attention block.
        self.feats['self_attn'] = output

    def get_clip_second_last_dense_out(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
        # Forward hook: cache the second-to-last CLIP dense output in float32.
        self.feats['clip_second_last_out'] = output.to(dtype=torch.float32)

    def get_all_out_tokens(self, model: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
        # Forward hook: cache all CLIP text tokens after ln_final.
        self.feats['clip_txt_out_tokens'] = output

    def __init__(
        self, model_name, resize_dim, clip_model_name, proj_class, proj_name, proj_model, avg_self_attn_token=False, disentangled_self_attn_token=True, loss=None, pre_trained=True,
        unfreeze_last_text_layer=False, unfreeze_last_image_layer=False, is_eval=True, use_avg_text_token=False, keep_cls=False, keep_end_seq=False, with_bg_clean=False, **kwargs
    ):
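        """Set up the frozen DINO backbone, the CLIP/BERT text encoder and the projection layer.

        The argument notes below are inferred from how the values are used in this constructor.

        Args:
            model_name: backbone identifier ('dinov2*'/'dinov3*' hub names, or a timm name containing 'mae', 'sam', 'clip' or 'dino').
            resize_dim: side of the square resize applied by `self.image_transforms`.
            clip_model_name: CLIP variant passed to `clip.load`, or a BERT checkpoint name containing 'bert'.
            proj_class: selects the projection config ('vitb_mlp_infonce' or 'vitl_mlp_infonce').
            avg_self_attn_token / disentangled_self_attn_token: how patch tokens are pooled in `encode_image`.
            with_bg_clean: enables the background-cleaning reweighting in `generate_masks`.
        """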
        nn.Module.__init__(self)

        self.feats = {}
        self.model_name = model_name

        if 'dinov2' in model_name:
            self.model_family = 'facebookresearch/dinov2' if 'dinov2' in model_name else 'facebookresearch/dino:main'
            self.model = torch.hub.load(self.model_family, model_name)
        elif 'dinov3' in model_name:
            def extract_dinov3_name(path, n_parts=2):
                filename = os.path.basename(path)
                parts = filename.split("_")
                return "_".join(parts[:n_parts])
            self.model = torch.hub.load('src/dinov3', extract_dinov3_name(model_name), source='local', weights=model_name)
        elif 'mae' in model_name or 'sam' in model_name or 'clip' in model_name or 'dino' in model_name:
            self.model = timm.create_model(
                model_name,
                pretrained=True,
                num_classes=0,
                img_size=resize_dim
            )

            if 'sam' in model_name:
                self.model.blocks[-1].register_forward_hook(get_vit_out)
        else:
            raise Exception("Unknown ViT model")

        mean = (0.485, 0.456, 0.406) if 'clip' not in model_name else (0.4815, 0.4578, 0.4082)
        std = (0.229, 0.224, 0.225) if 'clip' not in model_name else (0.2686, 0.2613, 0.2758)
        self.image_transforms = T.Compose([
            T.Resize((resize_dim, resize_dim)),
            lambda x: T.ToTensor()(x) if not isinstance(x, torch.Tensor) else x / 255.0,
            T.Normalize(mean, std),
        ])

        self.model.requires_grad_(False)

        self.clip_model_name = clip_model_name
        if 'bert' in self.clip_model_name:
            self.clip_model = BertModel.from_pretrained(self.clip_model_name, output_hidden_states=False)
            self.tokenizer = AutoTokenizer.from_pretrained(self.clip_model_name)
        else:
            self.clip_model, _ = clip.load(clip_model_name, device='meta')
        self.clip_model.eval()
        self.clip_model.requires_grad_(False)
        if unfreeze_last_text_layer:
            for param in self.clip_model.transformer.resblocks[-1].parameters():
                param.requires_grad = True
            for param in self.clip_model.ln_final.parameters():
                param.requires_grad = True
            self.clip_model.text_projection.requires_grad = True
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        if 'vitb_mlp_infonce' in proj_class:
            config = {
                'act': 'tanh',
                'hidden_layer': True,
                'dino_embed_dim': 768
            }
        elif 'vitl_mlp_infonce' in proj_class:
            config = {
                'act': 'tanh',
                'hidden_layer': True,
                'dino_embed_dim': 1024
            }
        else:
            raise Exception(f"Unknown projection class: {proj_class}")

        self.proj = ProjectionLayer.from_config(config)

        self.masker = DINOTextMasker(similarity_type="cosine")
        self.masker = self.masker.eval()

        self.pamr = None

        self.avg_self_attn_token = avg_self_attn_token
        self.disentangled_self_attn_token = disentangled_self_attn_token

        if self.avg_self_attn_token or self.disentangled_self_attn_token or is_eval:
            self.model.blocks[-1].attn.qkv.register_forward_hook(self.get_self_attention)
            self.num_global_tokens = 5 if 'reg' in model_name or 'dinov3' in model_name else 1
            if 'sam' in self.model_name:
                self.num_global_tokens = 0
            self.num_attn_heads = self.model.num_heads
            self.scale = 0.125

        self.use_avg_text_token = use_avg_text_token
        if self.use_avg_text_token:
            self.feats = {}
            self.clip_model.ln_final.register_forward_hook(self.get_all_out_tokens)
        self.keep_cls = keep_cls
        self.keep_end_seq = keep_end_seq

        self.with_bg_clean = with_bg_clean

    def process_self_attention(self, output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
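        """Turn the hooked qkv output of the last ViT block into CLS-to-patch attention.

        `output` has shape [B, num_tokens, 3 * embed_dim]. The method rebuilds q, k and v,
        takes the attention logits of the CLS token towards the patch tokens (global tokens
        excluded), averages them over heads and applies a softmax. With
        `ret_self_attn_maps=True` the per-head maps [B, heads, num_patches] are also returned
        (before the softmax over patches).
        """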
        qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0] * scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)
        self_attn_maps = attn[:, :, 0, num_global_tokens:]
        self_attn = self_attn_maps.mean(dim=1)
        self_attn = self_attn.softmax(dim=-1)
        if ret_self_attn_maps:
            return self_attn, self_attn_maps
        else:
            return self_attn

    def encode_text(self, tokenized_texts):
        x = self.clip_model.encode_text(tokenized_texts)
        return x

    def encode_image(self, images):
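        """Encode an image batch with the frozen DINO backbone.

        Returns a tuple `(x, self_attn_maps)`. Depending on the flags set at construction,
        `x` is either the self-attention-weighted average of the patch tokens
        (`avg_self_attn_token`) or one pooled embedding per attention head
        (`disentangled_self_attn_token`, shape [B, heads, C]); `self_attn_maps` is None when
        neither flag is active.
        """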
        batch_size, _, _, _ = images.shape
        self_attn_maps = None
        x = self.model(images, is_training=(self.avg_self_attn_token or self.disentangled_self_attn_token))
        batch_size, num_tokens, embed_dim = x['x_norm_patchtokens'].shape
        num_tokens = num_tokens + self.num_global_tokens
        if self.avg_self_attn_token or self.disentangled_self_attn_token:
            self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
            if self.avg_self_attn_token:
                x = (self_attn.unsqueeze(-1) * x['x_norm_patchtokens']).mean(dim=1)
            elif self.disentangled_self_attn_token:
                self_attn_maps = self_attn_maps.softmax(dim=-1)
                x = (x['x_norm_patchtokens'].unsqueeze(1) * self_attn_maps.unsqueeze(-1)).mean(dim=2)

        return x, self_attn_maps

    def forward(self, image, text, return_logit_scale=False):
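        """Encode the text under `torch.no_grad()`, encode the image with the frozen backbone,
        and map both into the shared space with `self.proj`.

        Returns `(txt_embed, img_embed)`, plus `self.logit_scale` when `return_logit_scale` is True.
        """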
        with torch.no_grad():
            txt_embed = self.encode_text(text)

        img_embed, self_attn_maps = self.encode_image(image)

        if type(self.proj) == CLIPLastLayer:
            img_embed, txt_embed = self.proj(img_embed, txt_embed, ret_embeds=True, self_attn_maps=self_attn_maps, text_argmax=text.argmax(dim=-1))
        else:
            img_embed, txt_embed = self.proj(img_embed, txt_embed, ret_embeds=True, self_attn_maps=self_attn_maps)

        if return_logit_scale:
            return txt_embed, img_embed, self.logit_scale

        return txt_embed, img_embed

    def compute_loss(self, image, text, cosine=True, ret_similarity_matrix=True):
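        """Compute the contrastive loss between the image and text embeddings of a batch.

        With `cosine=True` both embeddings are L2-normalized, so the similarity matrix holds
        cosine similarities; with `ret_similarity_matrix=False` only the diagonal (matched
        image-text pairs) is kept before the loss is computed.
        """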
        ret = {}
        # The embeddings come from the forward pass (text first, image second).
        txt_embed, img_embed = self.forward(image, text)
        if cosine:
            img_embed = F.normalize(img_embed, p=2, dim=1)
            txt_embed = F.normalize(txt_embed, p=2, dim=1)
        sim = img_embed @ txt_embed.transpose(1, 0)
        if not ret_similarity_matrix:
            sim = sim[torch.eye(len(sim), device=sim.device) > 0.5]  # keep only matched pairs

        ret['contrastive_loss'] = self.contrastive_loss.compute_contrastive_loss(sim)

        return ret

    @torch.no_grad()
    def build_dataset_class_tokens(self, template_set, classnames):
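        """Tokenize every (template, classname) combination.

        Returns a tensor of shape [num_classes, num_templates, context_length], using the CLIP
        tokenizer or the BERT tokenizer depending on `clip_model_name`.
        """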
        tokens = []
        templates = get_template(template_set)
        for classname in classnames:
            if 'bert' not in self.clip_model_name:
                tokens.append(
                    clip.tokenize([template.format(classname) for template in templates])
                )
            else:
                tokens.append(self.tokenizer([template.format(classname) for template in templates], return_tensors='pt', padding='max_length')['input_ids'])

        tokens = torch.stack(tokens)

        return tokens

    @torch.no_grad()
    def build_text_embedding(self, text):
        """
        Args:
            text (torch.Tensor): [NUM_CLASSES, NUM_TEMPLATES, CONTEXT_LENGTH] text tokens

        Returns:
            text_embs: [NUM_CLASSES, C] class embeddings, averaged over templates and normalized
        """
        text = text.to(next(self.parameters()).device)
        num_classes, num_templates = text.shape[:2]
        text_argmax = text.argmax(dim=-1)
        text_argmax = rearrange(text_argmax, 'n t -> (n t)', n=num_classes, t=num_templates)
        text = rearrange(text, 'n t l -> (n t) l', n=num_classes, t=num_templates)

        # Encode in chunks to bound memory usage.
        chunk_size = 32
        N = text.size(0)
        if type(self.proj) == CLIPLastLayer:
            text_embs = torch.cat([
                self.proj.project_clip_txt(self.encode_text(text[i:i + chunk_size]).permute(1, 0, 2), text_argmax=text_argmax[i:i + chunk_size])
                for i in range(0, N, chunk_size)
            ])
        else:
            if not self.use_avg_text_token:
                if 'bert' not in self.clip_model_name:
                    text_embs = torch.cat([
                        self.clip_model.encode_text(text[i:i + chunk_size])
                        for i in range(0, N, chunk_size)
                    ])
                else:
                    text_embs = []
                    for i in range(0, N, chunk_size):
                        outputs = self.clip_model(text[i:i + chunk_size])
                        text_embs.append(outputs['pooler_output'])
                    text_embs = torch.cat(text_embs)
            else:
                # The ln_final hook fills self.feats['clip_txt_out_tokens']; encode_text is
                # called only to trigger it.
                text_embs = []
                for i in range(0, N, chunk_size):
                    self.clip_model.encode_text(text[i:i + chunk_size])
                    text_embs.append(average_text_tokens(self.feats['clip_txt_out_tokens'] @ self.clip_model.text_projection, text[i:i + chunk_size] > 0, self.keep_cls, self.keep_end_seq))
                text_embs = torch.cat(text_embs)

        text_embs = rearrange(text_embs, '(n t) c -> n t c', n=num_classes, t=num_templates)

        text_embs = text_embs.mean(dim=1).float()
        if type(self.proj) == ProjectionLayer or type(self.proj) == DoubleMLP:
            text_embs = self.proj.project_clip_txt(text_embs)
        text_embs = normalize(text_embs, dim=-1)

        return text_embs

    def apply_pamr(self, image, mask):
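        """Refine soft masks with PAMR, instantiating it lazily on first use.

        The image is resized to the mask resolution before refinement.
        """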
        image = F.interpolate(image, mask.shape[-2:], mode="bilinear", align_corners=True)
        if self.pamr is None:
            pamr_iter = 10
            pamr_kernel = [1, 2, 4, 8, 12, 24]
            self.pamr = PAMR(pamr_iter, pamr_kernel)
            self.pamr.eval()
            self.pamr.to(next(self.parameters()).device)

        mask = self.pamr(image, mask)
        return mask

    def compute_padsize(self, H: int, W: int, patch_size: int):
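        """Return the (left, right, top, bottom) padding that makes W and H divisible by patch_size."""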
        l, r, t, b = 0, 0, 0, 0
        if W % patch_size:
            lr = patch_size - (W % patch_size)
            l = lr // 2
            r = lr - l

        if H % patch_size:
            tb = patch_size - (H % patch_size)
            t = tb // 2
            b = tb - t

        return l, r, t, b

    @torch.no_grad()
    def generate_masks(
        self, image, img_metas, text_emb, classnames, text_is_token=False, apply_pamr=False, background_func="weighted_average_sigmoid", lambda_bg=0.2,
    ):
        """Generate a soft mask for each text embedding.

        Args:
            image: [B, 3, H, W] input batch.
            text_emb: [N, C] text embeddings (e.g. from `build_text_embedding`).

        Returns:
            softmask [B, N, H, W]: soft masks for each text embedding, upsampled to the input
            resolution, together with the raw similarity map produced by the masker.
        """
        H, W = image.shape[2:]

        pH, pW = image.shape[2:]
        num_classes = text_emb.shape[0]
        batch_size = image.shape[0]

        image = image[:, [2, 1, 0], :, :]  # swap channel order (BGR <-> RGB)
        ori_image = image.clone()

        img_preprocessed = self.image_transforms(image).to(next(self.parameters()).device)
        if 'dinov2' in self.model_name or 'dinov3' in self.model_name:
            image_feat = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
        elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
            image_feat = self.model.forward_features(img_preprocessed)[:, 1:, :]
        elif 'sam' in self.model_name:
            self.model.forward_features(img_preprocessed)
            image_feat = feats['vit_out'].reshape(feats['vit_out'].shape[0], feats['vit_out'].shape[1]**2, feats['vit_out'].shape[-1])

        batch_size, num_tokens, embed_dim = image_feat.shape
        if type(self.proj) == VisualProjectionLayer:
            image_feat = self.proj.project_dino(image_feat.float())
        if type(self.proj) == DoubleMLP:
            image_feat = self.proj.project_visual(image_feat.float())
        b, n_patches, c = image_feat.shape
        np_h = np_w = int(sqrt(n_patches))
        image_feat = image_feat.reshape(b, np_h, np_w, c).permute(0, 3, 1, 2)

        self_attn, self_attn_maps = self.process_self_attention(self.feats['self_attn'], batch_size, num_tokens + self.num_global_tokens, self.num_attn_heads, embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True)
        mask, simmap = self.masker.forward_seg(image_feat, text_emb, hard=False)

        if self.with_bg_clean:
            mask = self.similarity_assignment_weighted(mask, image_feat, self_attn_maps, text_emb, lambda_bg)

        mask = F.interpolate(mask, (pH, pW), mode='bilinear', align_corners=True)

        if apply_pamr:
            # Refine in chunks of 30 classes to limit memory usage.
            for c in range(0, mask.shape[1], 30):
                mask[:, c:c + 30] = self.apply_pamr(ori_image, mask[:, c:c + 30])

        assert mask.shape[2] == H and mask.shape[3] == W, f"shape mismatch: ({H}, {W}) / {mask.shape}"

        return mask, simmap

    def similarity_assignment_weighted(self, mask, image_feat, self_attn_maps, text_emb, lambda_bg=0.2):
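        """Background-cleaning reweighting of the masks (used when `with_bg_clean` is set).

        Each attention head is summarized by an attention-weighted patch embedding; the
        similarity between these head embeddings and the class text embeddings is used to
        build a per-class self-attention map, which is rescaled to the value range of `mask`
        and blended in with weight `lambda_bg`.
        """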
        bs, c, h, w = image_feat.shape
        bs, num_classes, h, w = mask.shape
        bs, num_heads, hw = self_attn_maps.shape
        image_feat = image_feat.reshape(bs, c, hw)
        num_classes, c = text_emb.shape
        # Attention-weighted average patch embedding for each head: [bs, num_heads, c].
        avg_head_embed = (self_attn_maps.unsqueeze(2) * image_feat.unsqueeze(1)).mean(dim=-1)
        avg_head_embed = avg_head_embed / avg_head_embed.norm(dim=-1, keepdim=True)
        avg_head_embed = avg_head_embed.permute(0, 2, 1)
        # Class-to-head affinities: [bs, num_classes, num_heads], softmaxed over heads.
        head_text_sim = text_emb.unsqueeze(0) @ avg_head_embed
        head_text_sim = head_text_sim.softmax(dim=-1)
        head_text_sim_sum = head_text_sim.sum(dim=-1)

        # Per-class self-attention map as the affinity-weighted average of the head maps.
        self_attn_maps_repeat = self_attn_maps.unsqueeze(1).repeat(1, num_classes, 1, 1)
        head_text_sim_repeat = head_text_sim.unsqueeze(-1).repeat(1, 1, 1, hw)
        avg_self_attn_per_class = (self_attn_maps_repeat * head_text_sim_repeat).sum(dim=2) / head_text_sim_sum.unsqueeze(-1).repeat(1, 1, hw)
        avg_self_attn_per_class = avg_self_attn_per_class.softmax(dim=-1)

        # Rescale the attention maps to the value range of the masks before blending.
        min_self_attn = avg_self_attn_per_class.min().item()
        max_self_attn = avg_self_attn_per_class.max().item()
        max_self_attn = max(max_self_attn, max_self_attn - min_self_attn)
        avg_self_attn_per_class = avg_self_attn_per_class - min_self_attn
        avg_self_attn_per_class = avg_self_attn_per_class / max_self_attn
        avg_self_attn_per_class = avg_self_attn_per_class * (mask.max() - mask.min()) + mask.min()
        mask = mask.reshape(bs, num_classes, hw)
        mask_output = (mask + lambda_bg * avg_self_attn_per_class).reshape(bs, num_classes, h, w) / (1 + lambda_bg)
        return mask_output
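

# ---------------------------------------------------------------------------
# Minimal usage sketch. Illustrative only: the backbone name, CLIP variant,
# projection class, template set and class names below are assumed placeholders,
# not values guaranteed to ship with this repository, and the random tensor
# stands in for a real preprocessed image batch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DINOText(
        model_name='dinov2_vitb14_reg',   # assumed DINOv2 backbone with register tokens
        resize_dim=518,                   # assumed input resolution (37x37 patches at patch size 14)
        clip_model_name='ViT-B/16',       # assumed CLIP text encoder
        proj_class='vitb_mlp_infonce',    # selects the ViT-B projection config above
        proj_name=None,
        proj_model=None,
    ).to(device)
    model.eval()

    classnames = ['cat', 'dog', 'grass']  # placeholder class names
    template_set = 'imagenet'             # assumed key understood by get_template
    tokens = model.build_dataset_class_tokens(template_set, classnames)
    text_emb = model.build_text_embedding(tokens)

    image = torch.rand(1, 3, 518, 518, device=device)  # stand-in for a loaded image batch
    masks, simmap = model.generate_masks(image, img_metas=None, text_emb=text_emb,
                                         classnames=classnames, apply_pamr=False)
    print(masks.shape)  # expected: [1, len(classnames), 518, 518]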