Commit
Β·
d120439
1
Parent(s):
eda40d5
Fixed error
Browse files- src/dinotext.py β dinotext.py +7 -7
- hf_demo.ipynb +0 -0
- src/hooks.py β hooks.py +0 -0
- src/masker.py β masker.py +9 -9
- src/model.py β model.py +1 -1
- modeling_talk2dino.py +2 -2
- src/modules.py β modules.py +0 -0
- src/pamr.py β pamr.py +0 -0
- src/__init__.py +0 -0
- src/templates.py β templates.py +0 -0
- src/us.py β us.py +0 -0
src/dinotext.py β dinotext.py
RENAMED
|
@@ -16,14 +16,14 @@ from transformers import BertModel, AutoTokenizer
|
|
| 16 |
import torchvision.transforms as T
|
| 17 |
import clip
|
| 18 |
import importlib
|
| 19 |
-
|
| 20 |
|
| 21 |
-
from
|
| 22 |
-
from
|
| 23 |
-
from
|
| 24 |
|
| 25 |
-
from
|
| 26 |
-
from
|
| 27 |
|
| 28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
|
|
@@ -280,7 +280,7 @@ class DINOText(nn.Module):
|
|
| 280 |
text_embs = text_embs.mean(dim=1).float()
|
| 281 |
if type(self.proj) == ProjectionLayer or type(self.proj) == DoubleMLP:
|
| 282 |
text_embs = self.proj.project_clip_txt(text_embs)
|
| 283 |
-
text_embs =
|
| 284 |
|
| 285 |
return text_embs
|
| 286 |
|
|
|
|
| 16 |
import torchvision.transforms as T
|
| 17 |
import clip
|
| 18 |
import importlib
|
| 19 |
+
from .us import normalize
|
| 20 |
|
| 21 |
+
from .pamr import PAMR
|
| 22 |
+
from .masker import DINOTextMasker
|
| 23 |
+
from .templates import get_template
|
| 24 |
|
| 25 |
+
from .model import ProjectionLayer, VisualProjectionLayer, CLIPLastLayer, DoubleMLP
|
| 26 |
+
from .hooks import average_text_tokens, get_vit_out, feats
|
| 27 |
|
| 28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
|
|
|
|
| 280 |
text_embs = text_embs.mean(dim=1).float()
|
| 281 |
if type(self.proj) == ProjectionLayer or type(self.proj) == DoubleMLP:
|
| 282 |
text_embs = self.proj.project_clip_txt(text_embs)
|
| 283 |
+
text_embs = normalize(text_embs, dim=-1)
|
| 284 |
|
| 285 |
return text_embs
|
| 286 |
|
hf_demo.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/hooks.py β hooks.py
RENAMED
|
File without changes
|
src/masker.py β masker.py
RENAMED
|
@@ -8,11 +8,11 @@ import torch
|
|
| 8 |
import torch.distributed as dist
|
| 9 |
import torch.nn as nn
|
| 10 |
import torch.nn.functional as F
|
| 11 |
-
|
| 12 |
from einops import rearrange, repeat
|
| 13 |
|
| 14 |
# from models.dinotext.gumbel import gumbel_sigmoid
|
| 15 |
-
from
|
| 16 |
from omegaconf import OmegaConf
|
| 17 |
|
| 18 |
|
|
@@ -129,14 +129,14 @@ class Masker(nn.Module):
|
|
| 129 |
B = image.size(0)
|
| 130 |
image_emb, feats = self.image_encoder(image, image_feat, ret_feats=True) # [BCHW]
|
| 131 |
|
| 132 |
-
image_emb_norm =
|
| 133 |
-
text_emb_norm =
|
| 134 |
|
| 135 |
H, W = image_emb.shape[2:]
|
| 136 |
D = dist.get_world_size()
|
| 137 |
|
| 138 |
# simmap [B, B*D, H, W] where D is #devices
|
| 139 |
-
all_text_emb_norm =
|
| 140 |
simmap = torch.einsum("bchw,nc->bnhw", image_emb_norm, all_text_emb_norm)
|
| 141 |
mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
|
| 142 |
|
|
@@ -178,8 +178,8 @@ class Masker(nn.Module):
|
|
| 178 |
"""
|
| 179 |
image_emb = self.image_encoder(image, image_feat) # [BCHW]
|
| 180 |
|
| 181 |
-
image_emb =
|
| 182 |
-
text_emb =
|
| 183 |
|
| 184 |
simmap = torch.einsum("b c h w, n c -> b n h w", image_emb, text_emb)
|
| 185 |
|
|
@@ -219,8 +219,8 @@ class DINOTextMasker(nn.Module):
|
|
| 219 |
n, c = text_emb.shape
|
| 220 |
|
| 221 |
if self.similarity_type == "cosine":
|
| 222 |
-
image_feat =
|
| 223 |
-
# text_emb =
|
| 224 |
simmap = torch.einsum("b c h w, n c -> b n h w", image_feat, text_emb)
|
| 225 |
else:
|
| 226 |
raise NotImplementedError("similarity type {} not implemented".format(self.similarity_type))
|
|
|
|
| 8 |
import torch.distributed as dist
|
| 9 |
import torch.nn as nn
|
| 10 |
import torch.nn.functional as F
|
| 11 |
+
from .us import normalize
|
| 12 |
from einops import rearrange, repeat
|
| 13 |
|
| 14 |
# from models.dinotext.gumbel import gumbel_sigmoid
|
| 15 |
+
from .modules import FeatureEncoder
|
| 16 |
from omegaconf import OmegaConf
|
| 17 |
|
| 18 |
|
|
|
|
| 129 |
B = image.size(0)
|
| 130 |
image_emb, feats = self.image_encoder(image, image_feat, ret_feats=True) # [BCHW]
|
| 131 |
|
| 132 |
+
image_emb_norm = normalize(image_emb, dim=1)
|
| 133 |
+
text_emb_norm = normalize(text_emb, dim=-1)
|
| 134 |
|
| 135 |
H, W = image_emb.shape[2:]
|
| 136 |
D = dist.get_world_size()
|
| 137 |
|
| 138 |
# simmap [B, B*D, H, W] where D is #devices
|
| 139 |
+
all_text_emb_norm = gather_cat(text_emb_norm, grad=True, contiguous_grad=True)
|
| 140 |
simmap = torch.einsum("bchw,nc->bnhw", image_emb_norm, all_text_emb_norm)
|
| 141 |
mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
|
| 142 |
|
|
|
|
| 178 |
"""
|
| 179 |
image_emb = self.image_encoder(image, image_feat) # [BCHW]
|
| 180 |
|
| 181 |
+
image_emb = normalize(image_emb, dim=1) # BCHW
|
| 182 |
+
text_emb = normalize(text_emb, dim=-1) # NC
|
| 183 |
|
| 184 |
simmap = torch.einsum("b c h w, n c -> b n h w", image_emb, text_emb)
|
| 185 |
|
|
|
|
| 219 |
n, c = text_emb.shape
|
| 220 |
|
| 221 |
if self.similarity_type == "cosine":
|
| 222 |
+
image_feat = normalize(image_feat, dim=1) # BCHW
|
| 223 |
+
# text_emb = normalize(text_emb, dim=-1) # NKC
|
| 224 |
simmap = torch.einsum("b c h w, n c -> b n h w", image_feat, text_emb)
|
| 225 |
else:
|
| 226 |
raise NotImplementedError("similarity type {} not implemented".format(self.similarity_type))
|
src/model.py β model.py
RENAMED
|
@@ -4,7 +4,7 @@ import torch
|
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.nn.functional as F
|
| 6 |
|
| 7 |
-
from
|
| 8 |
|
| 9 |
class VisualProjectionLayer(nn.Module):
|
| 10 |
"""
|
|
|
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.nn.functional as F
|
| 6 |
|
| 7 |
+
from .hooks import get_self_attention, process_self_attention, feats
|
| 8 |
|
| 9 |
class VisualProjectionLayer(nn.Module):
|
| 10 |
"""
|
modeling_talk2dino.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
-
from
|
|
|
|
| 2 |
from transformers import PreTrainedModel
|
| 3 |
-
from configuration_talk2dino import Talk2DINOConfig
|
| 4 |
import clip
|
| 5 |
import torch
|
| 6 |
|
|
|
|
| 1 |
+
from .configuration_talk2dino import Talk2DINOConfig
|
| 2 |
+
from .dinotext import DINOText
|
| 3 |
from transformers import PreTrainedModel
|
|
|
|
| 4 |
import clip
|
| 5 |
import torch
|
| 6 |
|
src/modules.py β modules.py
RENAMED
|
File without changes
|
src/pamr.py β pamr.py
RENAMED
|
File without changes
|
src/__init__.py
DELETED
|
File without changes
|
src/templates.py β templates.py
RENAMED
|
File without changes
|
src/us.py β us.py
RENAMED
|
File without changes
|