Commit d120439 by lorebianchi98 · Parent(s): eda40d5

Fixed error
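The renames and import rewrites below flatten the package (src/*.py → repo root) and switch every absolute `src.` import to a package-relative one. This is the usual fix for Hub repos that ship custom modeling code: transformers downloads the files and imports them as a dynamically created module, so sibling files resolve only through relative imports. A minimal loading sketch under that assumption (the Hub id below is hypothetical; it is not shown in this diff):

```python
from transformers import AutoModel

# Hypothetical repo id for illustration only.
model = AutoModel.from_pretrained("lorebianchi98/Talk2DINO", trust_remote_code=True)
```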
src/dinotext.py → dinotext.py RENAMED
@@ -16,14 +16,14 @@ from transformers import BertModel, AutoTokenizer
 import torchvision.transforms as T
 import clip
 import importlib
-import src.us as us
+from .us import normalize
 
-from src.pamr import PAMR
-from src.masker import DINOTextMasker
-from src.templates import get_template
+from .pamr import PAMR
+from .masker import DINOTextMasker
+from .templates import get_template
 
-from src.model import ProjectionLayer, VisualProjectionLayer, CLIPLastLayer, DoubleMLP
-from src.hooks import average_text_tokens, get_vit_out, feats
+from .model import ProjectionLayer, VisualProjectionLayer, CLIPLastLayer, DoubleMLP
+from .hooks import average_text_tokens, get_vit_out, feats
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -280,7 +280,7 @@ class DINOText(nn.Module):
         text_embs = text_embs.mean(dim=1).float()
         if type(self.proj) == ProjectionLayer or type(self.proj) == DoubleMLP:
             text_embs = self.proj.project_clip_txt(text_embs)
-        text_embs = us.normalize(text_embs, dim=-1)
+        text_embs = normalize(text_embs, dim=-1)
 
         return text_embs

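For context on the rewritten call: `us.normalize` is not shown in this diff, but its usage is consistent with plain L2 normalization of the averaged text embeddings. A minimal sketch under that assumption, with toy shapes:

```python
import torch
import torch.nn.functional as F

def normalize(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Assumed behavior of us.normalize: L2-normalize along `dim`.
    return F.normalize(t, p=2, dim=dim)

text_embs = torch.randn(4, 7, 512)         # [batch, templates, dim] (toy shapes)
text_embs = text_embs.mean(dim=1).float()  # average over prompt templates
text_embs = normalize(text_embs, dim=-1)   # unit-norm rows -> dot product = cosine
print(text_embs.norm(dim=-1))              # ~1.0 everywhere
```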
hf_demo.ipynb CHANGED
The diff for this file is too large to render.
src/hooks.py → hooks.py RENAMED
File without changes
src/masker.py → masker.py RENAMED
@@ -8,11 +8,11 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-import src.us as us
+from .us import normalize, gather_cat
 from einops import rearrange, repeat
 
 # from models.dinotext.gumbel import gumbel_sigmoid
-from src.modules import FeatureEncoder
+from .modules import FeatureEncoder
 from omegaconf import OmegaConf
 
@@ -129,14 +129,14 @@ class Masker(nn.Module):
         B = image.size(0)
         image_emb, feats = self.image_encoder(image, image_feat, ret_feats=True)  # [BCHW]
 
-        image_emb_norm = us.normalize(image_emb, dim=1)
-        text_emb_norm = us.normalize(text_emb, dim=-1)
+        image_emb_norm = normalize(image_emb, dim=1)
+        text_emb_norm = normalize(text_emb, dim=-1)
 
         H, W = image_emb.shape[2:]
         D = dist.get_world_size()
 
         # simmap [B, B*D, H, W] where D is #devices
-        all_text_emb_norm = us.gather_cat(text_emb_norm, grad=True, contiguous_grad=True)
+        all_text_emb_norm = gather_cat(text_emb_norm, grad=True, contiguous_grad=True)
         simmap = torch.einsum("bchw,nc->bnhw", image_emb_norm, all_text_emb_norm)
         mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
 
@@ -178,8 +178,8 @@ class Masker(nn.Module):
         """
         image_emb = self.image_encoder(image, image_feat)  # [BCHW]
 
-        image_emb = us.normalize(image_emb, dim=1)  # BCHW
-        text_emb = us.normalize(text_emb, dim=-1)  # NC
+        image_emb = normalize(image_emb, dim=1)  # BCHW
+        text_emb = normalize(text_emb, dim=-1)  # NC
 
         simmap = torch.einsum("b c h w, n c -> b n h w", image_emb, text_emb)
 
@@ -219,8 +219,8 @@ class DINOTextMasker(nn.Module):
         n, c = text_emb.shape
 
         if self.similarity_type == "cosine":
-            image_feat = us.normalize(image_feat, dim=1)  # BCHW
-            # text_emb = us.normalize(text_emb, dim=-1)  # NKC
+            image_feat = normalize(image_feat, dim=1)  # BCHW
+            # text_emb = normalize(text_emb, dim=-1)  # NKC
             simmap = torch.einsum("b c h w, n c -> b n h w", image_feat, text_emb)
         else:
             raise NotImplementedError("similarity type {} not implemented".format(self.similarity_type))
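Both maskers reduce to the same pattern: L2-normalize the per-pixel image embeddings along the channel dimension and the text embeddings along the feature dimension, then let an einsum produce a per-query cosine-similarity map (`gather_cat`, not shown in this diff, presumably all-gathers the text embeddings across devices first, hence the `[B, B*D, H, W]` comment). A single-device sketch with toy shapes:

```python
import torch
import torch.nn.functional as F

B, C, H, W, N = 2, 512, 14, 14, 5            # toy sizes (assumptions)
image_feat = torch.randn(B, C, H, W)         # per-patch image embeddings
text_emb = torch.randn(N, C)                 # one embedding per text query

image_feat = F.normalize(image_feat, dim=1)  # normalize the channel dim
text_emb = F.normalize(text_emb, dim=-1)

# simmap[b, n, h, w] = <image_feat[b, :, h, w], text_emb[n]>  (cosine similarity)
simmap = torch.einsum("b c h w, n c -> b n h w", image_feat, text_emb)
print(simmap.shape)                          # torch.Size([2, 5, 14, 14])
```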
src/model.py → model.py RENAMED
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from src.hooks import get_self_attention, process_self_attention, feats
+from .hooks import get_self_attention, process_self_attention, feats
 
 class VisualProjectionLayer(nn.Module):
     """
modeling_talk2dino.py CHANGED
@@ -1,6 +1,6 @@
-from src.dinotext import DINOText
+from .configuration_talk2dino import Talk2DINOConfig
+from .dinotext import DINOText
 from transformers import PreTrainedModel
-from configuration_talk2dino import Talk2DINOConfig
 import clip
 import torch
 
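Only the class names `Talk2DINOConfig` and `DINOText` appear in this diff; the wrapper around them is not shown. A sketch of the standard custom-code pattern such a file typically follows (everything beyond those two names is an assumption):

```python
from transformers import PretrainedConfig, PreTrainedModel

class Talk2DINOConfig(PretrainedConfig):   # normally lives in configuration_talk2dino.py
    model_type = "talk2dino"               # assumed model_type

class Talk2DINOModel(PreTrainedModel):     # hypothetical wrapper name
    config_class = Talk2DINOConfig

    def __init__(self, config):
        super().__init__(config)
        # self.model = DINOText(...)       # built from fields on `config`
```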
src/modules.py → modules.py RENAMED
File without changes
src/pamr.py → pamr.py RENAMED
File without changes
src/__init__.py DELETED
File without changes
src/templates.py → templates.py RENAMED
File without changes
src/us.py → us.py RENAMED
File without changes
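Taken together, the renames leave a flat repository layout; a reconstruction from the file list above (files untouched by this commit may be missing):

```
.
├── configuration_talk2dino.py
├── dinotext.py
├── hf_demo.ipynb
├── hooks.py
├── masker.py
├── model.py
├── modeling_talk2dino.py
├── modules.py
├── pamr.py
├── templates.py
└── us.py
```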