import clip
import torch
from transformers import PreTrainedModel

from .configuration_talk2dino import Talk2DINOConfig
from .dinotext import DINOText
|
|
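# Hugging Face wrapper around DINOText: mixing in PreTrainedModel (with
# `config_class = Talk2DINOConfig`) enables `from_pretrained` /
# `save_pretrained`, while DINOText provides the actual model logic.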
class Talk2DINO(DINOText, PreTrainedModel):
    config_class = Talk2DINOConfig

    def __init__(self, config: Talk2DINOConfig):
        # Store the HF config, then initialize DINOText, whose constructor
        # expects the configuration as flattened keyword arguments.
        self.config = config
        cfg_dict = config.to_dict()
        super().__init__(**cfg_dict)

    def encode_text(self, texts):
        """texts: string or list of strings
        returns: text embeddings (N, D) where N is the number of texts and D is the embedding dimension
        """
        device = next(self.parameters()).device
        # Tokenize with CLIP's tokenizer and embed with the CLIP text encoder.
        text_tokens = clip.tokenize(texts).to(device)
        txt_embed = self.clip_model.encode_text(text_tokens)
        # Project the CLIP text embedding into the DINO patch-embedding space.
        txt_embed = self.proj.project_clip_txt(txt_embed)
        return txt_embed

    def encode_image(self, images):
        """images: PIL image or list of PIL images
        returns: image embeddings (N, L, D) where N is the number of images, L is the number of patches, and D is the embedding dimension
        """
        if not isinstance(images, list):
            images = [images]
        device = next(self.parameters()).device
        img_preprocessed = torch.stack([self.image_transforms(img).to(device) for img in images])
        if 'dinov2' in self.model_name or 'dinov3' in self.model_name:
            # DINOv2/DINOv3 backbones return a dict; keep the normalized patch tokens.
            img_embed = self.model.forward_features(img_preprocessed)['x_norm_patchtokens']
        elif 'mae' in self.model_name or 'clip' in self.model_name or 'dino' in self.model_name:
            # These backbones return (N, 1 + L, D); drop the leading CLS token.
            img_embed = self.model.forward_features(img_preprocessed)[:, 1:, :]
        else:
            raise ValueError(f"Unsupported backbone: {self.model_name}")
        return img_embed
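

# Minimal usage sketch. Assumptions (not part of this file): the checkpoint id
# "lorebianchi98/Talk2DINO-ViTB" and the image path are illustrative
# placeholders; substitute a real checkpoint and image to run this.
if __name__ == "__main__":
    from PIL import Image
    import torch.nn.functional as F

    model = Talk2DINO.from_pretrained("lorebianchi98/Talk2DINO-ViTB").eval()
    with torch.no_grad():
        txt_embed = model.encode_text(["a photo of a cat", "a photo of a dog"])  # (2, D)
        img_embed = model.encode_image(Image.open("example.jpg"))                # (1, L, D)
    # Patch-to-prompt cosine similarities: a coarse open-vocabulary
    # segmentation signal over the L patch tokens.
    sim = torch.einsum('nld,md->nlm',
                       F.normalize(img_embed.float(), dim=-1),
                       F.normalize(txt_embed.float(), dim=-1))
    print(sim.shape)  # torch.Size([1, L, 2])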