File size: 8,764 Bytes
eda40d5
 
 
 
 
 
 
 
 
 
d120439
eda40d5
 
 
d120439
eda40d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d120439
 
eda40d5
 
 
 
 
d120439
eda40d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d120439
 
eda40d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d120439
 
eda40d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
# ------------------------------------------------------------------------------
# Talk2DINO
# ------------------------------------------------------------------------------
import copy
from collections import OrderedDict
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from .us import normalize
from einops import rearrange, repeat

# from models.dinotext.gumbel import gumbel_sigmoid
from .modules import FeatureEncoder
from omegaconf import OmegaConf


def build_model(config):
    """Convert an OmegaConf config into plain Python containers.

    Interpolations are resolved during the conversion.

    NOTE(review): despite its name, this does not instantiate a model; it only
    returns the resolved config containers.

    Args:
        config: an OmegaConf config object.

    Returns:
        The result of ``OmegaConf.to_container(config, resolve=True)``.
    """
    return OmegaConf.to_container(config, resolve=True)

class Sim2Mask(nn.Module):
    """Turn similarity scores into hard and soft masks through an affine sigmoid.

    ``logits = x * w + b`` and the soft mask is ``sigmoid(logits)``. The hard
    mask is either the soft mask thresholded at 0.5 (``deterministic=True``) or
    a straight-through Gumbel-sigmoid sample of ``logits``.
    """

    def __init__(self, init_w=1.0, init_b=0.0, gumbel_tau=1.0, learnable=True):
        """
        Args:
            init_w: initial (or, when not learnable, fixed) scale ``w``.
            init_b: initial (or, when not learnable, fixed) bias ``b``.
            gumbel_tau: temperature of the stochastic hard mask.
            learnable: when True, ``w`` and ``b`` are registered as trainable
                scalar parameters; otherwise they are stored as plain attributes.
        """
        super().__init__()
        self.init_w = init_w
        self.init_b = init_b
        self.gumbel_tau = gumbel_tau
        self.learnable = learnable

        # init_w and init_b must be either both None or both given.
        assert not ((init_w is None) ^ (init_b is None))
        if learnable:
            self.w = nn.Parameter(torch.full([], float(init_w)))
            self.b = nn.Parameter(torch.full([], float(init_b)))
        else:
            self.w = init_w
            self.b = init_b

    @staticmethod
    def _gumbel_sigmoid(logits, tau=1.0, hard=False):
        """Sample ``sigmoid((logits + noise) / tau)`` with Logistic(0, 1) noise.

        The difference of two Gumbel(0, 1) variables is Logistic(0, 1), so this
        is the two-class analogue of ``F.gumbel_softmax``.

        Args:
            logits: floating-point tensor of logits.
            tau: temperature (> 0).
            hard: if True, return the 0/1 sample with a straight-through
                gradient (forward value is the hard sample, gradient is that of
                the soft sample).

        Returns:
            A tensor shaped like ``logits``.
        """
        # u ~ U[tiny, 1): clamping avoids log(0) = -inf when rand_like returns 0.
        u = torch.rand_like(logits).clamp_min_(torch.finfo(logits.dtype).tiny)
        noise = torch.log(u) - torch.log1p(-u)  # Logistic(0, 1)
        y_soft = torch.sigmoid((logits + noise) / tau)
        if not hard:
            return y_soft
        y_hard = y_soft.gt(0.5).type(y_soft.dtype)
        # Straight-through estimator.
        return y_hard - y_soft.detach() + y_soft

    def forward(self, x, deterministic=False):
        """Compute masks from similarity scores.

        Args:
            x: similarity tensor (any shape).
            deterministic: threshold the soft mask at 0.5 instead of sampling.

        Returns:
            (hard_mask, soft_mask), both shaped like ``x``.
        """
        logits = x * self.w + self.b

        soft_mask = torch.sigmoid(logits)
        if deterministic:
            hard_mask = soft_mask.gt(0.5).type(logits.dtype)
        else:
            hard_mask = self._gumbel_sigmoid(logits, tau=self.gumbel_tau, hard=True)

        return hard_mask, soft_mask

    def extra_repr(self):
        return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'


class MaskerBackbone(nn.Module):
    """Masker image encoder backbone.

    Holds private (deep-copied) versions of the CLIP visual transformer's
    residual blocks from ``freeze_idx`` onward, plus its ``ln_post`` and
    ``proj``.
    """

    def __init__(self, clip_visual, freeze_idx):
        super().__init__()
        transformer = copy.deepcopy(clip_visual.transformer)
        transformer.resblocks = transformer.resblocks[freeze_idx:]
        self.transformer = transformer

        # Copied blocks may carry a `hook_handler` attribute; remove those handles.
        for resblock in transformer.resblocks:
            if not hasattr(resblock, "hook_handler"):
                continue
            resblock.hook_handler.remove()

        self.ln_post = copy.deepcopy(clip_visual.ln_post)
        self.proj = copy.deepcopy(clip_visual.proj)

        self.layers = len(transformer.resblocks)
        self.patch_size = clip_visual.patch_size

        # With a projection the output width is the projected size; otherwise the raw width.
        if self.proj is None:
            self.output_dim = clip_visual.width
        else:
            self.output_dim = clip_visual.output_dim

    def forward(self, x, spatial=True, ignore_last_attn=True):
        """Run the remaining blocks, then post-norm and optional projection.

        Args:
            x: token features in LND layout (see the permute below).
            spatial: keep every token when True; otherwise keep only token 0.
            ignore_last_attn: forwarded to the transformer.

        Returns:
            [N, L, D'] tokens when ``spatial``; [N, D'] otherwise.
        """
        if self.layers:
            x = self.transformer(x, ignore_last_attn=ignore_last_attn)

        tokens = x.permute(1, 0, 2)  # LND -> NLD
        tokens = self.ln_post(tokens) if spatial else self.ln_post(tokens[:, 0, :])

        if self.proj is None:
            return tokens
        return tokens @ self.proj

class MaskerImageFeatureEncoder(FeatureEncoder):
    """Dense masker encoder: backbone tokens -> patch grid -> decoder.

    Every residual block of ``backbone.transformer`` gets a forward hook bound
    to ``self.hook`` (presumably provided by ``FeatureEncoder``; not visible
    here). The handle is stored on the block as ``hook_handler``.
    """

    def __init__(self, backbone: nn.Module, decoder: nn.Module, ignore_last_attn: bool = True):
        super().__init__()
        self.ignore_last_attn = ignore_last_attn
        self.patch_size = backbone.patch_size
        self.backbone = backbone
        self.decoder = decoder

        for blk in self.backbone.transformer.resblocks:
            blk.hook_handler = blk.register_forward_hook(self.hook)

    def _encode(self, image, image_feat):
        """Encode ``image_feat`` into a decoded spatial feature map.

        Args:
            image: image tensor; only its last two dims (H, W) are used, to
                derive the patch-grid size.
            image_feat: input tokens for the backbone.

        Returns:
            The decoder output for the [B, C, H // patch, W // patch] token grid.
        """
        grid_h, grid_w = (side // self.patch_size for side in image.shape[-2:])

        tokens = self.backbone(image_feat, spatial=True, ignore_last_attn=self.ignore_last_attn)  # BLC
        # Skip the first token (presumably the class token) and fold the rest into a grid.
        patch_map = rearrange(tokens[:, 1:], "B (H W) C -> B C H W", H=grid_h, W=grid_w)
        return self.decoder(patch_map)

class Masker(nn.Module):
    """Predict per-text masks from dense image embeddings.

    Builds an image encoder (backbone -> decoder -> ``image_proj``) and turns
    normalized image/text similarity maps into masks with :class:`Sim2Mask`.
    """

    def __init__(self, backbone, decoder, image_proj, sim2mask, ignore_last_attn, **kwargs):
        """
        Args:
            backbone: masker backbone; must expose ``output_dim`` and what
                ``MaskerImageFeatureEncoder`` reads (``patch_size``, ``transformer``).
            decoder: decoder config mapping for ``MODELS.build``. Its ``"C"``
                entry is set to ``backbone.output_dim`` on a shallow copy; the
                caller's object is not modified.
            image_proj: module applied after the decoder.
            sim2mask: keyword arguments for :class:`Sim2Mask`.
            ignore_last_attn: forwarded to the backbone.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.ignore_last_attn = ignore_last_attn

        # Shallow copy so the caller's config is not mutated.
        decoder_cfg = copy.copy(decoder)
        decoder_cfg["C"] = backbone.output_dim
        # NOTE(review): `MODELS` is not imported in the visible part of this file;
        # confirm it is provided (e.g. by a registry import).
        decoder = MODELS.build(decoder_cfg)
        decoder = nn.Sequential(OrderedDict([
            ("decoder", decoder),
            ("image_proj", image_proj),
        ]))

        self.image_encoder = MaskerImageFeatureEncoder(backbone, decoder, ignore_last_attn=ignore_last_attn)

        self.sim2mask = Sim2Mask(**sim2mask)

    def forward(self, image, image_feat, text_emb, deterministic=False):
        """Compute masks for every local image against the texts of all processes.

        Args:
            image: image batch; its first dim gives the local batch size B.
            image_feat: backbone input features for the image encoder.
            text_emb: text embeddings [*, C]. The negative-mask reshape assumes
                each process contributes B rows.
            deterministic: forwarded to :class:`Sim2Mask`.

        Returns:
            (masks, image_emb, text_emb, feats). ``masks`` has:
                "pos" [B, 1, H, W], "soft_pos" [B, 1, H, W],
                "soft_neg" [B, B*D-1, H, W], "soft_all" [B, B*D, H, W],
            where D is the number of processes (1 when torch.distributed is
            not initialized).
        """
        B = image.size(0)
        image_emb, feats = self.image_encoder(image, image_feat, ret_feats=True)  # [BCHW]

        image_emb_norm = normalize(image_emb, dim=1)
        text_emb_norm = normalize(text_emb, dim=-1)

        H, W = image_emb.shape[2:]

        # dist.get_world_size()/get_rank() raise when the default process group is
        # not initialized, so fall back to single-process behaviour in that case.
        if dist.is_available() and dist.is_initialized():
            D = dist.get_world_size()
            rank = dist.get_rank()
            # NOTE(review): `gather_cat` is not imported in the visible part of this file.
            all_text_emb_norm = gather_cat(text_emb_norm, grad=True, contiguous_grad=True)
        else:
            D, rank = 1, 0
            all_text_emb_norm = text_emb_norm

        # simmap [B, B*D, H, W] where D is #devices
        simmap = torch.einsum("bchw,nc->bnhw", image_emb_norm, all_text_emb_norm)
        mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)

        # Local image i is paired with global text index B * rank + i.
        rows = torch.arange(B, dtype=torch.long, device=mask.device)
        pos_indices = rows + B * rank
        pos_mask = mask[rows, pos_indices].unsqueeze(1)  # [B, 1, H, W]

        offdiag = torch.ones(B, B * D, dtype=torch.bool, device=mask.device)
        offdiag[rows, pos_indices] = False

        soft_pos_mask = soft_mask[rows, pos_indices].unsqueeze(1)
        soft_neg_mask = soft_mask.masked_select(offdiag[..., None, None]).view(B, B * D - 1, H, W)

        masks = {
            "pos": pos_mask,  # [B, 1, H, W]

            "soft_pos": soft_pos_mask,
            "soft_neg": soft_neg_mask,
            "soft_all": soft_mask,  # [B, N, H, W]
        }

        return masks, image_emb, text_emb, feats

    @torch.no_grad()
    def forward_seg(self, image, image_feat, text_emb, deterministic=True, hard=False):
        """Make mask by 1:N matching

        Args:
            image [B, 3, H, W]
            image_feat [L, B, C]: CLIP features
            text_emb [N, C]
            deterministic (bool): deterministic inference flag for gumbel noise
            hard (bool): decide hard or soft returning segmentation mask.
                Note that soft mask is required for proper evaluation

        Return:
            mask [B, N, H', W'] (H' and W' are downsampled H/W), and the raw
            similarity map of the same shape.
        """
        image_emb = self.image_encoder(image, image_feat)  # [BCHW]

        image_emb = normalize(image_emb, dim=1)  # BCHW
        text_emb = normalize(text_emb, dim=-1)  # NC

        simmap = torch.einsum("b c h w, n c -> b n h w", image_emb, text_emb)

        hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
        mask = hard_mask if hard else soft_mask

        return mask, simmap

class DINOTextMasker(nn.Module):
    """Segment dense image features by similarity to text embeddings.

    Only :meth:`forward_seg` is implemented; :meth:`forward` is a placeholder.
    """

    def __init__(self, similarity_type="cosine"):
        """
        Args:
            similarity_type: similarity used by ``forward_seg``; only "cosine"
                is supported.
        """
        super().__init__()
        # The mask head is switched to eval mode at construction.
        self.sim2mask = DINOTextSim2Mask().eval()
        self.similarity_type = similarity_type

    def forward(self, image, image_feat, text_emb, deterministic=False):
        """Placeholder: not implemented, returns None."""
        return None

    @torch.no_grad()
    def forward_seg(self, image_feat, text_emb, deterministic=True, hard=False):
        """Make masks by 1:N matching of image features against text embeddings.

        Args:
            image_feat [B, C, H, W]: dense image features; L2-normalized over C here.
            text_emb [N, C]: one embedding per text/class. NOT normalized here
                (the normalization is commented out), so the result is a true
                cosine similarity only if callers pass unit-norm embeddings.
            deterministic (bool): threshold the soft mask at 0.5 instead of
                sampling gumbel noise.
            hard (bool): return the hard mask instead of the soft one.
                Note that soft mask is required for proper evaluation.

        Returns:
            (mask, simmap), both [B, N, H, W] at image_feat's spatial size.

        Raises:
            ValueError: if ``image_feat`` is not 4-D or ``text_emb`` is not 2-D.
            NotImplementedError: for an unsupported ``similarity_type``.
        """
        if image_feat.dim() != 4:
            raise ValueError(
                f"image_feat must be 4-D [B, C, H, W], got shape {tuple(image_feat.shape)}")
        if text_emb.dim() != 2:
            raise ValueError(
                f"text_emb must be 2-D [N, C], got shape {tuple(text_emb.shape)}")

        if self.similarity_type == "cosine":
            image_feat = normalize(image_feat, dim=1)  # BCHW
            # text_emb = normalize(text_emb, dim=-1)  # NC
            simmap = torch.einsum("b c h w, n c -> b n h w", image_feat, text_emb)
        else:
            raise NotImplementedError("similarity type {} not implemented".format(self.similarity_type))

        hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
        mask = hard_mask if hard else soft_mask

        return mask, simmap


class DINOTextSim2Mask(nn.Module):
    """Turn raw similarity scores into hard and soft masks (no affine transform).

    The soft mask is ``sigmoid(x)``. The hard mask is either the soft mask
    thresholded at 0.5 (``deterministic=True``) or a straight-through
    Gumbel-sigmoid sample of ``x``.
    """

    def __init__(self, gumbel_tau=1.0):
        """
        Args:
            gumbel_tau: temperature of the stochastic hard mask.
        """
        super().__init__()
        self.gumbel_tau = gumbel_tau

    @staticmethod
    def _gumbel_sigmoid(logits, tau=1.0, hard=False):
        """Sample ``sigmoid((logits + noise) / tau)`` with Logistic(0, 1) noise.

        The difference of two Gumbel(0, 1) variables is Logistic(0, 1), so this
        is the two-class analogue of ``F.gumbel_softmax``. With ``hard=True``
        the forward value is the 0/1 sample and the gradient is that of the
        soft sample (straight-through).
        """
        # u ~ U[tiny, 1): clamping avoids log(0) = -inf when rand_like returns 0.
        u = torch.rand_like(logits).clamp_min_(torch.finfo(logits.dtype).tiny)
        noise = torch.log(u) - torch.log1p(-u)  # Logistic(0, 1)
        y_soft = torch.sigmoid((logits + noise) / tau)
        if not hard:
            return y_soft
        y_hard = y_soft.gt(0.5).type(y_soft.dtype)
        return y_hard - y_soft.detach() + y_soft

    def forward(self, x, deterministic=False):
        """Return (hard_mask, soft_mask), both shaped like ``x``."""
        soft_mask = torch.sigmoid(x)
        if deterministic:
            hard_mask = soft_mask.gt(0.5).type(x.dtype)
        else:
            hard_mask = self._gumbel_sigmoid(x, hard=True, tau=self.gumbel_tau)

        return hard_mask, soft_mask