Description
When running ESM models on CPU with flash_attn installed, inference fails with CUDA-related errors despite explicitly setting the device to CPU.
Steps to reproduce:
```python
# On a CPU machine with flash-attention installed
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig

protein = ESMProtein(sequence="AAAAA")
client = ESMC.from_pretrained("esmc_300m").to("cpu")
protein_tensor = client.encode(protein)
logits_output = client.logits(protein_tensor, LogitsConfig(sequence=True))
```
Error
RuntimeError: invalid argument to exchangeDevice
I tried setting `client._use_flash_attn = False` manually after loading, but inference then fails with a different error related to tensor dimensions, which suggests that some Flash Attention code paths remain active.
Suggested Fix
Add an automatic check that disables Flash Attention when running on CPU, or expose a `use_flash_attn` parameter in `from_pretrained()`.
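A minimal sketch of what the suggested check could look like, in plain Python. `resolve_use_flash_attn` is a hypothetical helper, not part of the esm package; it assumes the decision is made once, at model construction, from the target device plus an optional user override (the proposed `use_flash_attn` argument), and it only uses the standard library so it does not depend on esm internals.

```python
import importlib.util
from typing import Optional


def resolve_use_flash_attn(device, requested: Optional[bool] = None) -> bool:
    """Decide whether Flash Attention should be enabled (hypothetical helper).

    device: a torch.device or a string such as "cpu" or "cuda:0".
    requested: explicit user choice (e.g. a use_flash_attn argument);
               None means "enable automatically when possible".
    """
    # "cuda:0" -> "cuda", "cpu" -> "cpu"; str(torch.device(...)) has the same form.
    on_cuda = str(device).split(":", 1)[0] == "cuda"
    # True only if the flash_attn package can be found on the import path.
    available = importlib.util.find_spec("flash_attn") is not None

    if requested is None:
        # Automatic mode: enable only on CUDA and only if flash_attn is installed.
        return on_cuda and available
    if requested and not on_cuda:
        # Flash Attention kernels are CUDA-only, so an explicit request on CPU is an error.
        raise ValueError("use_flash_attn=True requires a CUDA device")
    if requested and not available:
        raise ImportError("use_flash_attn=True but flash_attn is not installed")
    return requested
```

A `from_pretrained()` implementation could call such a helper with the target device and pass the result to the model constructor, so the attention path is fixed before inference rather than toggled afterwards through a private attribute.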