
How to enable batch training? #260

@yuliangyan0807

Description


I wrote a script that uses the ESMC model as follows:

import torch
import torch.nn as nn
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
from typing import List

class ESMCProteinEncoder(nn.Module):
    def __init__(self, model_name="esmc_600m"):
        """
        ESMC Protein Encoder that only encodes protein sequences and applies mean pooling
        :param model_name: The pre-trained model to load from HuggingFace (default is "esmc_600m")
        """
        super().__init__()
        # Load the pretrained ESMC model; its weights are frozen explicitly below
        self.client = ESMC.from_pretrained(model_name)
        
        # Freeze the parameters to avoid modifying during training
        for param in self.client.parameters():
            param.requires_grad = False

    def forward(self, seq: List[str]):
        """
        Encodes the input protein sequences and applies mean pooling on the embeddings
        :param seq: List of protein sequences
        :return: Mean-pooled protein embeddings from the ESMC model
        """
        # Wrap the input in an ESMProtein object
        protein = ESMProtein(sequence=seq)

        # Get the encoded tensor from the ESMC model
        tensor = self.client.encode(protein)

        # Run the model to get logits and per-token embeddings
        logits_output = self.client.logits(tensor, LogitsConfig(sequence=True, return_embeddings=True))
        
        # Extract the embeddings and drop the BOS/EOS tokens
        embeddings = logits_output.embeddings[:, 1:-1, :].to(torch.float32)  # Drop BOS/EOS

        # Apply mean pooling to the embeddings (along the sequence dimension)
        pooled_embeddings = torch.mean(embeddings, dim=1)
        
        return pooled_embeddings



encoder = ESMCProteinEncoder(model_name="esmc_600m")
seq = ["AAAAAA", "GGGGGG", "CCCCCC"]
pooled_embeddings = encoder(seq)
print(pooled_embeddings.shape)

When I pass a list, e.g., ["AAAAAA", "GGGGGG", "CCCCCC"], to the model, the following error occurs:

Traceback (most recent call last):
  File "/home/yuliangyan/Code/Trust-App-AI-Lab/prot_learn/esm3_test.py", line 60, in <module>
    pooled_embeddings = encoder(seq)
                        ^^^^^^^^^^^^
  File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuliangyan/Code/Trust-App-AI-Lab/prot_learn/esm3_test.py", line 43, in forward
    tensor = self.client.encode(protein)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/esm/models/esmc.py", line 180, in encode
    sequence_tokens = self._tokenize([input.sequence])[0]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/esm/models/esmc.py", line 105, in _tokenize
    encoding.tokenize_sequence(x, self.tokenizer, add_special_tokens=True)
  File "/home/yuliangyan/anaconda3/envs/yyl/lib/python3.12/site-packages/esm/utils/encoding.py", line 53, in tokenize_sequence
    sequence = sequence.replace(C.MASK_STR_SHORT, sequence_tokenizer.mask_token)
               ^^^^^^^^^^^^^^^^
AttributeError: 'list' object has no attribute 'replace'

How can I fix this bug? Thanks!
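
A note on the cause and a possible workaround: ESMProtein(sequence=...) expects a single string, and ESMC.encode tokenizes exactly one sequence, so the list reaches the tokenizer intact and triggers the AttributeError on sequence.replace. Below is a minimal per-sequence sketch using only the public encode/logits calls already shown above (the helper name embed_batch is hypothetical; looping sacrifices throughput but sidesteps padding, since each sequence is mean-pooled over its own length):

import torch
from typing import List
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

def embed_batch(client: ESMC, seqs: List[str]) -> torch.Tensor:
    """Encode sequences one at a time and stack their mean-pooled embeddings."""
    pooled = []
    for s in seqs:
        protein = ESMProtein(sequence=s)  # one string per ESMProtein
        tensor = client.encode(protein)   # tokenizes a single sequence
        out = client.logits(
            tensor, LogitsConfig(sequence=True, return_embeddings=True)
        )
        emb = out.embeddings[:, 1:-1, :].to(torch.float32)  # drop BOS/EOS
        pooled.append(emb.mean(dim=1).squeeze(0))  # average over residues
    return torch.stack(pooled)  # shape: (batch_size, hidden_dim)

A true batched forward pass would instead require padding the token tensors to a common length and masking the padded positions during mean pooling; whether ESMC's forward accepts a (batch, seq_len) token tensor directly may depend on the package version, so the loop above is the conservative option.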
