macwiatrak commited on
Commit
d5bdc18
·
verified ·
1 Parent(s): 7f332d8

feat: move protein embeddings input to the right dtype

Browse files
Files changed (1) hide show
  1. modeling_bacformer.py +1 -1
modeling_bacformer.py CHANGED
@@ -328,7 +328,7 @@ class BacformerEmbeddings(nn.Module):
328
  bs, seq_length, dim = protein_embeddings.shape
329
 
330
  # pass the pooled ESM protein embeddings through a linear layer
331
- protein_embeddings = self.linear(protein_embeddings)
332
  protein_embeddings = torch.where(
333
  special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
334
  protein_embeddings,
 
328
  bs, seq_length, dim = protein_embeddings.shape
329
 
330
  # pass the pooled ESM protein embeddings through a linear layer
331
+ protein_embeddings = self.linear(protein_embeddings.type_as(self.linear.weight))
332
  protein_embeddings = torch.where(
333
  special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
334
  protein_embeddings,