
ESM models fail on CPU when flash_attn is installed #242

@karinazad

Description

When running ESM models on CPU with flash_attn installed, inference fails with CUDA-related errors even though the device is explicitly set to CPU.

Steps to reproduce:
On a CPU-only machine with flash_attn installed, run:

```python
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

protein = ESMProtein(sequence="AAAAA")
# Load the model and explicitly move it to CPU
client = ESMC.from_pretrained("esmc_300m").to("cpu")
protein_tensor = client.encode(protein)  # tokenize the sequence
# Run inference and request sequence logits
logits_output = client.logits(protein_tensor, LogitsConfig(sequence=True))
```

Error
RuntimeError: invalid argument to exchangeDevice
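To confirm that an environment matches the conditions in this report (flash_attn installed, no usable CUDA device), a quick check like the one below may help. It is a minimal sketch: it only tests whether `flash_attn` is importable via `importlib.util.find_spec` and, if PyTorch is installed, calls `torch.cuda.is_available()`. The function name `describe_environment` is hypothetical and not part of esm.

```python
import importlib.util


def describe_environment() -> dict:
    """Report whether flash_attn is importable and whether CUDA is usable.

    Hypothetical helper: find_spec locates the package without importing it,
    and torch is only imported if it is installed.
    """
    flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
    cuda_available = False
    if importlib.util.find_spec("torch") is not None:
        import torch

        cuda_available = bool(torch.cuda.is_available())
    return {
        "flash_attn_installed": flash_attn_installed,
        "cuda_available": cuda_available,
        # The failure in this issue is reported for this combination.
        "matches_issue_conditions": flash_attn_installed and not cuda_available,
    }


if __name__ == "__main__":
    print(describe_environment())
```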

I also tried setting client._use_flash_attn = False manually after loading, but inference then fails with a different error related to tensor dimensions, which suggests that Flash Attention code paths remain active.
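Until this is fixed, one possible workaround is to hide flash_attn from the Python process before esm is imported (uninstalling flash-attn from the environment does the same thing permanently). This sketch assumes esm detects Flash Attention with a guarded import, e.g. `try: import flash_attn ... except ImportError`. That assumption has not been verified against the esm source. If the import is unguarded, importing esm will fail instead. `block_module` is a hypothetical helper, not an esm function.

```python
import sys


def block_module(name: str) -> None:
    """Make any later `import name` raise ModuleNotFoundError.

    Python treats a None entry in sys.modules as "import blocked". This has no
    effect on modules that already imported `name` before this call.
    """
    sys.modules[name] = None


# Must run before `from esm.models.esmc import ESMC` (assumption: esm falls
# back to its non-flash attention path when flash_attn cannot be imported).
block_module("flash_attn")
```

After this call, the reproduction script above would run unchanged in the same process. Whether it then succeeds on CPU depends on the unverified assumption about how esm checks for flash_attn.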

Suggested Fix
Add an automatic check that disables Flash Attention when running on CPU, or provide a use_flash_attn parameter in from_pretrained().
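Below is a minimal sketch of the automatic check. It assumes the loader knows the target device type as a string, for example `torch.device("cpu").type`, which is `"cpu"`. The names `flash_attn_available` and `resolve_use_flash_attn` are hypothetical and not part of the esm package. A `use_flash_attn` argument to `from_pretrained()`, as suggested above, could be passed through as `requested`.

```python
import importlib.util
from typing import Optional


def flash_attn_available() -> bool:
    """True if the flash_attn package can be found (it is not imported)."""
    return importlib.util.find_spec("flash_attn") is not None


def resolve_use_flash_attn(
    device_type: str,
    requested: bool = True,
    installed: Optional[bool] = None,
) -> bool:
    """Decide whether Flash Attention should be enabled.

    flash-attn kernels only run on GPUs, so any non-CUDA device type
    (e.g. "cpu") disables it, whatever was requested or installed.
    `device_type` is the bare type ("cuda", "cpu"), not "cuda:0".
    """
    if installed is None:
        installed = flash_attn_available()
    return requested and installed and device_type == "cuda"
```

For example, `resolve_use_flash_attn("cpu", requested=True, installed=True)` returns `False`. Making this decision once at load time, before the attention layers are built, would likely avoid the inconsistent state that flipping a private flag after loading seems to cause.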

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests