-
Notifications
You must be signed in to change notification settings - Fork 238
Open
Description
I wrote a script to utilize the ESMc model as follows:
import torch
import torch.nn as nn
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
from typing import List
class ESMCProteinEncoder(nn.Module):
def __init__(self, model_name="esmc_600m"):
"""
ESMC Protein Encoder that only encodes protein sequences and applies mean pooling
:param model_name: The pre-trained model to load from HuggingFace (default is "esmc_600m")
"""
super().__init__()
# Load the ESMC model, frozen by default
self.client = ESMC.from_pretrained(model_name)
# Freeze the parameters to avoid modifying during training
for param in self.client.parameters():
param.requires_grad = False
def forward(self, seq: List[str]):
"""
Encodes the input protein sequence and applies mean pooling on the embeddings
:param seq: Protein sequence as a string
:return: Mean-pooled protein embeddings from ESMC model
"""
# Ensure the input sequence is in the correct format
protein = ESMProtein(sequence=seq)
# Get the encoded tensor from the ESMC model
tensor = self.client.encode(protein)
# Get the embeddings (logits) from the ESMC model
logits_output = self.client.logits(tensor, LogitsConfig(sequence=True, return_embeddings=True))
# Extract the embeddings and drop the BOS/EOS tokens
embeddings = logits_output.embeddings[:, 1:-1, :].to(torch.float32) # Drop BOS/EOS
# Apply mean pooling to the embeddings (along the sequence dimension)
pooled_embeddings = torch.mean(embeddings, dim=1)
return pooled_embeddings
encoder = ESMCProteinEncoder(model_name="esmc_600m")
seq = ["AAAAAA", "GGGGGG", "CCCCCC"]
pooled_embeddings = encoder(seq)
print(pooled_embeddings.shape)
When I try to input a list, e.g., ["AAAAAA", "GGGGGG", "CCCCCC"], into the model, the following error occurred:
Traceback (most recent call last):
File "/home/yuliangyan/Code/Trust-App-AI-Lab/prot_learn/esm3_test.py", line 60, in <module>
pooled_embeddings = encoder(seq)
^^^^^^^^^^^^
File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yuliangyan/Code/Trust-App-AI-Lab/prot_learn/esm3_test.py", line 43, in forward
tensor = self.client.encode(protein)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/esm/models/esmc.py", line 180, in encode
sequence_tokens = self._tokenize([input.sequence])[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/esm/models/esmc.py", line 105, in _tokenize
encoding.tokenize_sequence(x, self.tokenizer, add_special_tokens=True)
File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/esm/utils/encoding.py", line 53, in tokenize_sequence
sequence = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token)
^^^^^^^^^^^^^^^^
AttributeError: 'list' object has no attribute 'replace'
How can I fix this bug? Thanks!
Metadata
Metadata
Assignees
Labels
No labels