Commit 3c68320

do not use esm from transformers (#343)
* do not use esm from transformers
* minor
1 parent 690c2b8 commit 3c68320

File tree

5 files changed: +85 -25 lines changed

chai_lab/data/dataset/embeddings/esm.py

Lines changed: 82 additions & 22 deletions
@@ -6,37 +6,43 @@
 from contextlib import contextmanager
 
 import torch
-from transformers import logging as tr_logging
 
 from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
 from chai_lab.data.dataset.structure.chain import Chain
 from chai_lab.data.parsing.structure.entity_type import EntityType
-from chai_lab.utils.paths import downloads_path
-from chai_lab.utils.tensor_utils import move_data_to_device
+from chai_lab.utils.paths import download_if_not_exists, downloads_path
 from chai_lab.utils.typing import typecheck
 
 _esm_model: list = []  # persistent in-process container
 
 os.register_at_fork(after_in_child=lambda: _esm_model.clear())
 
 
-# unfortunately huggingface complains on pooler layer in ESM being non-initialized.
-# Did not find a way to filter specifically that logging message :/
-tr_logging.set_verbosity_error()
+ESM_URL = "https://chaiassets.com/chai1-inference-depencencies/esm2/traced_sdpa_esm2_t36_3B_UR50D_fp16.pt"
+
 
 esm_cache_folder = downloads_path.joinpath("esm")
 
 
 @contextmanager
-def esm_model(model_name: str, device):
+def esm_model(device):
     """Context transiently keeps ESM model on specified device."""
-    from transformers import EsmModel
+
+    local_esm_path = downloads_path.joinpath(
+        "esm/traced_sdpa_esm2_t36_3B_UR50D_fp16.pt"
+    )
+    download_if_not_exists(ESM_URL, local_esm_path)
 
     if len(_esm_model) == 0:
         # lazy loading of the model
-        _esm_model.append(
-            EsmModel.from_pretrained(model_name, cache_dir=esm_cache_folder)
-        )
+        if device != torch.device("cuda:0"):
+            # load on cpu first, then move to device
+            model = torch.jit.load(local_esm_path, map_location="cpu").to(device)
+        else:
+            # skip loading on CPU.
+            model = torch.jit.load(local_esm_path).to(device)
+
+        _esm_model.append(model)
 
     [model] = _esm_model
     model.to(device)
@@ -45,28 +51,82 @@ def esm_model(model_name: str, device):
         model.to("cpu")  # move model back to CPU when done
 
 
+token_map = {
+    "<cls>": 0,
+    "<pad>": 1,
+    "<eos>": 2,
+    "<unk>": 3,
+    "L": 4,
+    "A": 5,
+    "G": 6,
+    "V": 7,
+    "S": 8,
+    "E": 9,
+    "R": 10,
+    "T": 11,
+    "I": 12,
+    "D": 13,
+    "P": 14,
+    "K": 15,
+    "Q": 16,
+    "N": 17,
+    "F": 18,
+    "Y": 19,
+    "M": 20,
+    "H": 21,
+    "W": 22,
+    "C": 23,
+    "X": 24,
+    "B": 25,
+    "U": 26,
+    "Z": 27,
+    "O": 28,
+    ".": 29,
+    "-": 30,
+    "<null_1>": 31,
+    "<mask>": 32,
+}
+
+
+class DumbTokenizer:
+    def __init__(self, token_map: dict[str, int]):
+        self.token_map = token_map
+
+    def tokenize(self, text: str) -> list[int]:
+        tokens = []
+        i = 0
+        while i < len(text):
+            for token in self.token_map:
+                if text.startswith(token, i):
+                    tokens.append(self.token_map[token])
+                    i += len(token)
+                    break
+            else:
+                raise RuntimeError("Unknown token: " + text[i:])
+        return tokens
+
+
+esm_tokenizer = DumbTokenizer(token_map=token_map)
+
+
 def _get_esm_contexts_for_sequences(
     prot_sequences: set[str], device
 ) -> dict[str, EmbeddingContext]:
     if len(prot_sequences) == 0:
         return {}  # skip loading ESM
 
-    # local import, requires huggingface transformers
-    from transformers import EsmTokenizer
-
-    model_name = "facebook/esm2_t36_3B_UR50D"
-    tokenizer = EsmTokenizer.from_pretrained(model_name, cache_dir=esm_cache_folder)
-
     seq2embedding_context = {}
 
     with torch.no_grad():
-        with esm_model(model_name=model_name, device=device) as model:
+        with esm_model(device=device) as model:
             for seq in prot_sequences:
-                inputs = tokenizer(seq, return_tensors="pt")
-                inputs = move_data_to_device(dict(**inputs), device=device)
-                outputs = model(**inputs)
+                # add bos/eos, tokenize
+                token_ids = torch.asarray(esm_tokenizer.tokenize(f"<cls>{seq}<eos>"))
+                token_ids = token_ids[None, :].to(device)
+
+                last_hidden_state = model(tokens=token_ids)
                 # remove BOS/EOS, back to CPU
-                esm_embeddings = outputs.last_hidden_state[0, 1:-1].to("cpu")
+                esm_embeddings = last_hidden_state[0, 1:-1].float().to("cpu")
                 seq_len, _emb_dim = esm_embeddings.shape
                 assert seq_len == len(seq)
 
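For orientation, here is a minimal usage sketch of how the replacement pieces fit together: the hand-rolled DumbTokenizer plus the traced TorchScript ESM2 checkpoint, mirroring the flow inside _get_esm_contexts_for_sequences above. It assumes chai_lab at this commit is importable and the checkpoint download succeeds; the example sequence is made up.

# Usage sketch (not part of the commit): per-residue ESM embeddings for one sequence.
import torch

from chai_lab.data.dataset.embeddings.esm import esm_model, esm_tokenizer

seq = "MKTAYIAKQR"  # made-up protein sequence, for illustration only
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with torch.no_grad():
    with esm_model(device=device) as model:
        # same flow as _get_esm_contexts_for_sequences: wrap in <cls>/<eos>,
        # tokenize with DumbTokenizer, run the traced model, then drop BOS/EOS.
        token_ids = torch.asarray(esm_tokenizer.tokenize(f"<cls>{seq}<eos>"))
        last_hidden_state = model(tokens=token_ids[None, :].to(device))
        embeddings = last_hidden_state[0, 1:-1].float().to("cpu")

assert embeddings.shape[0] == len(seq)  # one embedding row per residue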
chai_lab/data/parsing/templates/m8.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 from chai_lab.data.parsing.templates.template_hit import TemplateHit
 from chai_lab.tools.kalign import kalign_query_to_reference
 
-logger = logging.getLogger(name=__name__)
+logger = logging.getLogger(__name__)
 
 
 def parse_m8_file(fname: Path) -> pd.DataFrame:

examples/predict_structure.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@
 
 from chai_lab.chai1 import run_inference
 
+logging.basicConfig(level=logging.INFO)  # control verbosity
+
 # We use fasta-like format for inputs.
 # - each entity encodes protein, ligand, RNA or DNA
 # - each entity is labeled with unique name;
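Since the example now configures the standard library logger, verbosity can be dialed up or down by changing the level; a small sketch (the specific levels are just examples, not part of the commit):

import logging

# quieter run: hide INFO progress messages, keep warnings and errors
logging.basicConfig(level=logging.WARNING)

# or, for more detail while debugging:
# logging.basicConfig(level=logging.DEBUG)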

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -45,7 +45,6 @@ module = [
     "biotite.*",
     "DockQ.*",
     "boto3.*",
-    "transformers.*",
     "modelcif.*",
     "ihm.*",
 ]

requirements.in

Lines changed: 0 additions & 1 deletion
@@ -29,4 +29,3 @@ einops~=0.8
 jaxtyping>=0.2.25  # versions <0.2.25 do not easily support runtime typechecking
 beartype>=0.18  # compatible typechecker to use with jaxtyping
 torch>=2.3.1,<2.7  # 2.2 is broken, 2.3.1 is confirmed to work correctly
-transformers~=4.44  # for esm inference

0 commit comments
