feat: move protein embeddings input to the right dtype
Browse files- modeling_bacformer.py +1 -1
modeling_bacformer.py
CHANGED
|
@@ -328,7 +328,7 @@ class BacformerEmbeddings(nn.Module):
|
|
| 328 |
bs, seq_length, dim = protein_embeddings.shape
|
| 329 |
|
| 330 |
# pass the pooled ESM protein embeddings through a linear layer
|
| 331 |
-
protein_embeddings = self.linear(protein_embeddings)
|
| 332 |
protein_embeddings = torch.where(
|
| 333 |
special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
|
| 334 |
protein_embeddings,
|
|
|
|
| 328 |
bs, seq_length, dim = protein_embeddings.shape
|
| 329 |
|
| 330 |
# pass the pooled ESM protein embeddings through a linear layer
|
| 331 |
+
protein_embeddings = self.linear(protein_embeddings.type_as(self.linear.weight))
|
| 332 |
protein_embeddings = torch.where(
|
| 333 |
special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
|
| 334 |
protein_embeddings,
|