import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from transformers import BertTokenizerFast as BertTokenizer
from transformers.activations import ACT2FN

from .configuration_hformer import HformerConfig
from .configuration_projector import ProjectorConfig
from .modeling_projector import ProjectorModel
from .qformer_src import BertConfig, BertLMHeadModel

# Seed torch's RNG at import time so the random initialization below
# (e.g. the query tokens) is reproducible.
torch.manual_seed(1024)


class LayerNorm(nn.LayerNorm):
    """Thin wrapper around nn.LayerNorm; currently a straight passthrough."""

    def forward(self, x: torch.Tensor):
        ret = super().forward(x)
        return ret


class HformerModel(PreTrainedModel):
    """Hybrid vision-to-LLM connector combining a BLIP-2 style Q-Former branch
    with an MLP projector branch."""

    _auto_class = 'AutoModel'
    config_class = HformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False

    def __init__(self, config) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth

        # Build a BERT-based Q-Former: a 12-layer BERT whose layers gain
        # cross-attention blocks (every `cross_attention_freq` layers) over
        # the visual features.
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(
            bert, config=encoder_config
        )

        # Optionally strip the text-only parts of the Q-Former (disabled here).
        remove_text = False
        if remove_text:
            Qformer.cls = None
            Qformer.bert.embeddings.word_embeddings = None
            Qformer.bert.embeddings.position_embeddings = None
            for layer in Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None

        # Learnable query tokens that the Q-Former uses to pool visual features.
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        self.Qformer = Qformer
        self.query_tokens = query_tokens
        # Projection from the Q-Former hidden size into the LLM embedding space.
        self.llm_proj = nn.Linear(encoder_config.hidden_size, llm_hidden_size, bias=config.bias)
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        self.ln_llava = LayerNorm(encoder_config.encoder_width)  # not used in forward

        # Add the "[DEC]" BOS token expected by BLIP-2 style checkpoints and
        # resize the Q-Former token embeddings accordingly.
        tokenizer = BertTokenizer.from_pretrained(bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))
        # Optionally warm-start from a pretrained Q-Former checkpoint.
        if qformer_pth is not None:
            pretrained_state_dict = torch.load(qformer_pth, map_location='cpu')['model']
            print(f'Loading Q-Former weights from {qformer_pth}')
            self.load_state_dict(pretrained_state_dict, strict=False)
            print('Done.')
        # MLP projector that maps the raw visual features directly to the LLM space.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            projector_depth=2)
        self.connector = ProjectorModel(projector_config)

        # Bottleneck FFN (hidden size reduced 4x) used as a residual branch
        # on the fused features.
        modules = [
            nn.Linear(config.llm_hidden_size, config.llm_hidden_size // 4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size // 4, config.llm_hidden_size, bias=False)
        ]
        self.ffn = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        # Force the outputs of these submodules to require gradients so that
        # gradient checkpointing in the downstream LLM can backpropagate
        # through otherwise-frozen inputs.
        def make_inputs_require_grad(module, input, output):
            if isinstance(output, tuple):
                output[0].requires_grad_(True)
                output[1].requires_grad_(True)
            else:
                output.requires_grad_(True)

        self.Qformer.register_forward_hook(make_inputs_require_grad)
        self.llm_proj.register_forward_hook(make_inputs_require_grad)
        self.ln_vision.register_forward_hook(make_inputs_require_grad)
        self.connector.register_forward_hook(make_inputs_require_grad)
        self.ffn.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        # Gradient checkpointing is not supported by this connector.
        pass

    def forward(self, x_):
        if self.gradient_checkpointing and self.training:
            print('Gradient checkpointing is not supported')

        # Q-Former branch: normalize the visual features and let the learnable
        # query tokens cross-attend to them.
        x = self.ln_vision(x_)
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )
        q_feat = self.llm_proj(query_output.last_hidden_state)

        # MLP branch: project the raw visual features token by token.
        mlp_feat = self.connector(x_)

        # Fuse the two branches: add the query features (averaged over query
        # tokens) to every MLP token, then apply a residual bottleneck FFN.
        int_feat = mlp_feat + q_feat.mean(dim=1)[:, None]
        out = int_feat + self.ffn(int_feat)
        return out
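

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original model code).
# The config values, feature shapes, and the BERT checkpoint name below are
# assumptions for demonstration only; HformerConfig may expect different
# defaults, and running this requires the package context for the relative
# imports plus downloadable BERT weights.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = HformerConfig(
        visual_hidden_size=1024,   # assumed vision-encoder feature dim
        llm_hidden_size=4096,      # assumed LLM hidden size
        num_query_token=32,
        bert='bert-base-uncased',  # assumed BERT backbone
        cross_attention_freq=2,
        qformer_pth=None,          # or a path to pretrained Q-Former weights
        bias=True,
    )
    model = HformerModel(config).eval()

    # Fake batch of visual features: (batch, num_patches, visual_hidden_size).
    visual_feats = torch.randn(2, 576, config.visual_hidden_size)
    with torch.no_grad():
        llm_inputs = model(visual_feats)
    print(llm_inputs.shape)  # expected: (2, 576, llm_hidden_size)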