# Talk2DINO-ViTB / masker.py
# Author: lorebianchi98 (commit d120439: "Fixed error")
# ------------------------------------------------------------------------------
# Talk2DINO
# ------------------------------------------------------------------------------
import copy
from collections import OrderedDict
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from .us import normalize
from einops import rearrange, repeat
# from models.dinotext.gumbel import gumbel_sigmoid
from .modules import FeatureEncoder
from omegaconf import OmegaConf
def build_model(config):
    """Convert an OmegaConf config into plain Python containers.

    Despite the name, no model is instantiated here: the result is
    ``OmegaConf.to_container(config, resolve=True)``, i.e. nested dicts/lists
    with all interpolations resolved.
    """
    return OmegaConf.to_container(config, resolve=True)
class Sim2Mask(nn.Module):
    """Map a similarity map to (hard, soft) masks through a scalar affine transform.

    ``logits = x * w + b`` and ``soft_mask = sigmoid(logits)``. The hard mask is either
    the 0.5-threshold of the soft mask (deterministic) or a straight-through
    Gumbel-sigmoid sample of the logits.

    Args:
        init_w: initial scale ``w``.
        init_b: initial bias ``b``. ``init_w`` and ``init_b`` must be both None or both set.
        gumbel_tau: temperature used when sampling the hard mask stochastically.
        learnable: if True, ``w`` and ``b`` are trainable 0-dim parameters; otherwise the
            raw ``init_w`` / ``init_b`` values are used directly.

    NOTE(review): the "both None" case passes the assert but then fails
    (``float(None)`` when learnable, ``x * None`` in forward otherwise).
    """

    def __init__(self, init_w=1.0, init_b=0.0, gumbel_tau=1.0, learnable=True):
        super().__init__()
        self.init_w = init_w
        self.init_b = init_b
        self.gumbel_tau = gumbel_tau
        self.learnable = learnable
        assert not ((init_w is None) ^ (init_b is None))
        if learnable:
            self.w = nn.Parameter(torch.full([], float(init_w)))
            self.b = nn.Parameter(torch.full([], float(init_b)))
        else:
            self.w = init_w
            self.b = init_b

    @staticmethod
    def _gumbel_sigmoid(logits, tau=1.0, hard=False):
        """Sample a relaxed Bernoulli mask from ``logits``.

        With ``hard=True`` the returned values are exactly 0/1, while gradients flow
        through the underlying soft sample (straight-through estimator).
        """
        # -log(Exp(1)) ~ Gumbel(0, 1); the difference of two Gumbels is Logistic(0, 1) noise.
        gumbel_a = -torch.empty_like(logits).exponential_().log()
        gumbel_b = -torch.empty_like(logits).exponential_().log()
        y_soft = torch.sigmoid((logits + gumbel_a - gumbel_b) / tau)
        if not hard:
            return y_soft
        y_hard = (y_soft > 0.5).to(logits.dtype)
        # Forward value equals y_hard; backward uses the gradient of y_soft.
        return y_hard - y_soft.detach() + y_soft

    def forward(self, x, deterministic=False):
        """Return ``(hard_mask, soft_mask)`` computed from similarity map ``x``.

        Args:
            x: similarity tensor (any shape; the transform is elementwise).
            deterministic: if True, threshold the soft mask at 0.5 instead of sampling.
        """
        logits = x * self.w + self.b
        soft_mask = torch.sigmoid(logits)
        if deterministic:
            hard_mask = soft_mask.gt(0.5).type(logits.dtype)
        else:
            hard_mask = self._gumbel_sigmoid(logits, tau=self.gumbel_tau, hard=True)
        return hard_mask, soft_mask

    def extra_repr(self):
        return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'
class MaskerBackbone(nn.Module):
    """Trainable tail of a CLIP-style visual encoder used by the masker.

    Deep-copies ``clip_visual.transformer`` and keeps only the residual blocks from
    index ``freeze_idx`` onward, together with copies of ``ln_post`` and ``proj``.
    Forward hooks previously attached to the kept blocks (``hook_handler``) are removed
    from the copy.
    """

    def __init__(self, clip_visual, freeze_idx):
        super().__init__()
        self.transformer = copy.deepcopy(clip_visual.transformer)
        kept_blocks = self.transformer.resblocks[freeze_idx:]
        self.transformer.resblocks = kept_blocks
        # Drop any hook handles carried over from the source blocks.
        for blk in self.transformer.resblocks:
            if not hasattr(blk, "hook_handler"):
                continue
            blk.hook_handler.remove()
        self.ln_post = copy.deepcopy(clip_visual.ln_post)
        self.proj = copy.deepcopy(clip_visual.proj)
        self.layers = len(self.transformer.resblocks)
        self.patch_size = clip_visual.patch_size
        # Without a projection the output width is the transformer width.
        if self.proj is None:
            self.output_dim = clip_visual.width
        else:
            self.output_dim = clip_visual.output_dim

    def forward(self, x, spatial=True, ignore_last_attn=True):
        """Run the kept blocks, then layer-norm and (optionally) project.

        Args:
            x: token tensor in sequence-first layout (permuted to batch-first below).
            spatial: if True, return all tokens; otherwise only token 0.
            ignore_last_attn: forwarded to the transformer when it has blocks.
        """
        tokens = self.transformer(x, ignore_last_attn=ignore_last_attn) if self.layers else x
        tokens = tokens.permute(1, 0, 2)  # LND -> NLD
        picked = tokens if spatial else tokens[:, 0, :]
        normed = self.ln_post(picked)
        if self.proj is None:
            return normed
        return normed @ self.proj
class MaskerImageFeatureEncoder(FeatureEncoder):
    """Image encoder: backbone tokens -> spatial feature map -> decoder.

    A forward hook (``self.hook``, presumably provided by ``FeatureEncoder`` — confirm)
    is registered on every backbone residual block; its handle is stored on the block
    as ``hook_handler``.
    """

    def __init__(self, backbone: nn.Module, decoder: nn.Module, ignore_last_attn: bool = True):
        super().__init__()
        self.ignore_last_attn = ignore_last_attn
        self.patch_size = backbone.patch_size
        self.backbone = backbone
        self.decoder = decoder
        for blk in self.backbone.transformer.resblocks:
            blk.hook_handler = blk.register_forward_hook(self.hook)

    def _encode(self, image, image_feat):
        """Encode ``image_feat`` into a decoded feature map.

        ``image`` is only used for its trailing two dims (height, width), which,
        divided by the patch size, give the token grid.
        """
        grid_h, grid_w = (side // self.patch_size for side in image.shape[-2:])
        tokens = self.backbone(image_feat, spatial=True, ignore_last_attn=self.ignore_last_attn)  # BLC
        # Token 0 is skipped; the remaining tokens are laid out row-major on the grid.
        feature_map = rearrange(tokens[:, 1:], "B (H W) C -> B C H W", H=grid_h, W=grid_w)
        return self.decoder(feature_map)
class Masker(nn.Module):
    """Learned masker: encodes images, matches them against text embeddings and
    converts the cosine-similarity maps into (hard, soft) masks via ``Sim2Mask``.
    """

    def __init__(self, backbone, decoder, image_proj, sim2mask, ignore_last_attn, **kwargs):
        """
        Args:
            backbone: backbone module exposing ``output_dim``, ``patch_size`` and
                ``transformer.resblocks``.
            decoder: decoder config passed to ``MODELS.build``; its ``"C"`` entry is set to
                ``backbone.output_dim`` (NOTE: this mutates the caller's config object).
            image_proj: module applied after the decoder.
            sim2mask: keyword arguments for ``Sim2Mask``.
            ignore_last_attn: forwarded to the image encoder.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.ignore_last_attn = ignore_last_attn
        decoder["C"] = backbone.output_dim
        # NOTE(review): MODELS is not defined in the visible source of this module.
        decoder = MODELS.build(decoder)
        decoder = nn.Sequential(OrderedDict([
            ("decoder", decoder),
            ("image_proj", image_proj),
        ]))
        self.image_encoder = MaskerImageFeatureEncoder(backbone, decoder, ignore_last_attn=ignore_last_attn)
        self.sim2mask = Sim2Mask(**sim2mask)

    @staticmethod
    def _world_info():
        """Return ``(world_size, rank)``, or ``(1, 0)`` when torch.distributed is not in use."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size(), dist.get_rank()
        return 1, 0

    def forward(self, image, image_feat, text_emb, deterministic=False):
        """Compute masks of every local image against the text embeddings of all ranks.

        Image ``i`` on this rank is paired with text ``i + B * rank`` in the gathered
        text batch. This assumes every rank has the same batch size ``B`` — TODO confirm
        against ``gather_cat``.

        Returns:
            masks: dict with
                ``"pos"`` [B, 1, H, W] hard mask of the paired text,
                ``"soft_pos"`` [B, 1, H, W] its soft mask,
                ``"soft_neg"`` [B, B*D-1, H, W] soft masks of all non-paired texts,
                ``"soft_all"`` [B, B*D, H, W] all soft masks (D = world size).
            image_emb: encoder output [B, C, H, W] (not normalized).
            text_emb: the input ``text_emb`` unchanged.
            feats: auxiliary features returned by the image encoder.
        """
        B = image.size(0)
        image_emb, feats = self.image_encoder(image, image_feat, ret_feats=True)  # [BCHW]
        image_emb_norm = normalize(image_emb, dim=1)
        text_emb_norm = normalize(text_emb, dim=-1)
        H, W = image_emb.shape[2:]
        D, rank = self._world_info()

        # Gather texts from all devices only when there is more than one;
        # single-process runs use the local texts directly.
        if D > 1:
            all_text_emb_norm = gather_cat(text_emb_norm, grad=True, contiguous_grad=True)
        else:
            all_text_emb_norm = text_emb_norm

        # simmap / mask: [B, B*D, H, W]
        simmap = torch.einsum("bchw,nc->bnhw", image_emb_norm, all_text_emb_norm)
        mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)

        # Positive (paired) text index of each local image in the global text batch.
        batch_idx = torch.arange(B, dtype=torch.long, device=mask.device)
        pos_indices = batch_idx + B * rank
        pos_mask = mask[batch_idx, pos_indices].unsqueeze(1)  # [B, 1, H, W]

        offdiag = torch.ones(B, B * D, dtype=torch.bool, device=mask.device)
        offdiag[batch_idx, pos_indices] = False
        soft_pos_mask = soft_mask[batch_idx, pos_indices].unsqueeze(1)
        soft_neg_mask = soft_mask.masked_select(offdiag[..., None, None]).view(B, B * D - 1, H, W)

        masks = {
            "pos": pos_mask,  # [B, 1, H, W]
            "soft_pos": soft_pos_mask,
            "soft_neg": soft_neg_mask,
            "soft_all": soft_mask,  # [B, N, H, W]
        }
        return masks, image_emb, text_emb, feats

    @torch.no_grad()
    def forward_seg(self, image, image_feat, text_emb, deterministic=True, hard=False):
        """Make mask by 1:N matching
        Args:
            image [B, 3, H, W]
            image_feat [L, B, C]: CLIP features
            text_emb [N, C]
            deterministic (bool): deterministic inference flag for gumbel noise
            hard (bool): decide hard or soft returning segmentation mask.
                Note that soft mask is required for proper evaluation
        Return:
            mask [B, N, H', W'] (H' and W' are downsampled H/W)
            simmap [B, N, H', W']: cosine similarities before Sim2Mask
        """
        image_emb = self.image_encoder(image, image_feat)  # [BCHW]
        image_emb = normalize(image_emb, dim=1)  # BCHW
        text_emb = normalize(text_emb, dim=-1)  # NC
        simmap = torch.einsum("b c h w, n c -> b n h w", image_emb, text_emb)
        hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
        mask = hard_mask if hard else soft_mask
        return mask, simmap
class DINOTextMasker(nn.Module):
    """Parameter-free masker: patch/text similarity maps turned into masks.

    The similarity is passed unscaled to ``DINOTextSim2Mask``, so the soft mask is
    ``sigmoid(similarity)``.
    """

    def __init__(self, similarity_type="cosine"):
        super().__init__()
        self.sim2mask = DINOTextSim2Mask()
        self.sim2mask = self.sim2mask.eval()
        self.similarity_type = similarity_type

    def forward(self, image, image_feat, text_emb, deterministic=False):
        # Placeholder: this masker is used through ``forward_seg``; calling it returns None.
        pass

    @torch.no_grad()
    def forward_seg(self, image_feat, text_emb, deterministic=True, hard=False):
        """Make masks by 1:N matching of image features against text embeddings.

        Args:
            image_feat [B, C, H, W]: dense image features.
            text_emb [N, C]: text embeddings, used as given (they are NOT normalized here).
            deterministic (bool): threshold instead of sampling gumbel noise for the hard mask.
            hard (bool): return the hard mask instead of the soft mask.
                Note that the soft mask is required for proper evaluation.
        Returns:
            mask [B, N, H, W]: hard or soft mask.
            simmap [B, N, H, W]: similarity between each pixel feature and each text.
        Raises:
            ValueError: if ``image_feat`` is not 4-D or ``text_emb`` is not 2-D.
            NotImplementedError: for an unsupported ``similarity_type``.
        """
        if image_feat.dim() != 4:
            raise ValueError(f"image_feat must be [B, C, H, W], got shape {tuple(image_feat.shape)}")
        if text_emb.dim() != 2:
            raise ValueError(f"text_emb must be [N, C], got shape {tuple(text_emb.shape)}")
        if self.similarity_type == "cosine":
            # Only the image side is normalized (channel dim); text_emb is used as-is.
            image_feat = normalize(image_feat, dim=1)  # BCHW
            simmap = torch.einsum("b c h w, n c -> b n h w", image_feat, text_emb)
        else:
            raise NotImplementedError("similarity type {} not implemented".format(self.similarity_type))
        hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
        mask = hard_mask if hard else soft_mask
        return mask, simmap
class DINOTextSim2Mask(nn.Module):
    """Turn raw logits into (hard, soft) masks with no learnable transform.

    ``soft_mask = sigmoid(x)``. The hard mask is either ``soft_mask > 0.5``
    (deterministic) or a straight-through Gumbel-sigmoid sample with temperature
    ``gumbel_tau``.
    """

    def __init__(self, gumbel_tau=1.0):
        super().__init__()
        self.gumbel_tau = gumbel_tau

    @staticmethod
    def _gumbel_sigmoid(logits, tau=1.0, hard=False):
        """Sample a relaxed Bernoulli mask from ``logits``.

        With ``hard=True`` the returned values are exactly 0/1, while gradients flow
        through the underlying soft sample (straight-through estimator).
        """
        # -log(Exp(1)) ~ Gumbel(0, 1); the difference of two Gumbels is Logistic(0, 1) noise.
        gumbel_a = -torch.empty_like(logits).exponential_().log()
        gumbel_b = -torch.empty_like(logits).exponential_().log()
        y_soft = torch.sigmoid((logits + gumbel_a - gumbel_b) / tau)
        if not hard:
            return y_soft
        y_hard = (y_soft > 0.5).to(logits.dtype)
        # Forward value equals y_hard; backward uses the gradient of y_soft.
        return y_hard - y_soft.detach() + y_soft

    def forward(self, x, deterministic=False):
        """Return ``(hard_mask, soft_mask)`` for logits ``x`` (any shape, elementwise)."""
        soft_mask = torch.sigmoid(x)
        if deterministic:
            hard_mask = soft_mask.gt(0.5).type(x.dtype)
        else:
            hard_mask = self._gumbel_sigmoid(x, tau=self.gumbel_tau, hard=True)
        return hard_mask, soft_mask