# Talk2DINO-ViTB / modeling_talk2dino.py (1.82 kB, commit d120439 "Fixed error" by lorebianchi98)
from .configuration_talk2dino import Talk2DINOConfig
from .dinotext import DINOText
from transformers import PreTrainedModel
import clip
import torch
class Talk2DINO(DINOText, PreTrainedModel):
    """HuggingFace-compatible wrapper around DINOText.

    Exposes `encode_text` / `encode_image` helpers that project CLIP text
    features and backbone patch features into a shared embedding space.
    """

    config_class = Talk2DINOConfig

    def __init__(self, config: Talk2DINOConfig):
        # Keep the HF config so `PreTrainedModel` machinery (save/load) works.
        self.config = config
        # PretrainedConfig subclasses serialize to a plain dict; DINOText
        # accepts those fields as keyword arguments.
        cfg_dict = config.to_dict()
        # Initialize parent (DINOText) with unpacked kwargs.
        super().__init__(**cfg_dict)

    def encode_text(self, texts):
        """Encode a string or list of strings into text embeddings.

        Args:
            texts: a string or list of strings.

        Returns:
            Tensor of shape (N, D) — N texts, D embedding dimension.
        """
        # Tokenize on the same device the model parameters live on.
        device = next(self.parameters()).device
        text_tokens = clip.tokenize(texts).to(device)
        txt_embed = self.clip_model.encode_text(text_tokens)
        # Project CLIP text features into the DINO visual embedding space.
        txt_embed = self.proj.project_clip_txt(txt_embed)
        return txt_embed

    def encode_image(self, images):
        """Encode a PIL image or list of PIL images into patch embeddings.

        Args:
            images: a PIL image or list of PIL images.

        Returns:
            Tensor of shape (N, L, D) — N images, L patches, D embedding
            dimension.

        Raises:
            ValueError: if `self.model_name` does not match a supported
                backbone family.
        """
        if not isinstance(images, list):
            images = [images]
        device = next(self.parameters()).device
        img_preprocessed = torch.stack(
            [self.image_transforms(img).to(device) for img in images]
        )
        # NOTE: 'dinov2'/'dinov3' must be tested before the plain 'dino'
        # substring check, which would otherwise match them too.
        if 'dinov2' in self.model_name or 'dinov3' in self.model_name:
            # DINOv2/v3 expose normalized patch tokens directly.
            img_embed = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
        elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
            # timm-style backbones: drop the CLS token (index 0), keep patches.
            img_embed = self.model.forward_features(img_preprocessed)[:, 1:, :]
        else:
            # Previously this fell through and raised a confusing
            # UnboundLocalError at the return statement.
            raise ValueError(f"Unsupported backbone model_name: {self.model_name!r}")
        return img_embed