from contextlib import contextmanager

import torch
-from transformers import logging as tr_logging

from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
from chai_lab.data.dataset.structure.chain import Chain
from chai_lab.data.parsing.structure.entity_type import EntityType
-from chai_lab.utils.paths import downloads_path
-from chai_lab.utils.tensor_utils import move_data_to_device
+from chai_lab.utils.paths import download_if_not_exists, downloads_path
from chai_lab.utils.typing import typecheck

_esm_model: list = []  # persistent in-process container

os.register_at_fork(after_in_child=lambda: _esm_model.clear())


-# unfortunately huggingface complains on pooler layer in ESM being non-initialized.
-# Did not find a way to filter specifically that logging message :/
-tr_logging.set_verbosity_error()
+ESM_URL = "https://chaiassets.com/chai1-inference-depencencies/esm2/traced_sdpa_esm2_t36_3B_UR50D_fp16.pt"
+

esm_cache_folder = downloads_path.joinpath("esm")


@contextmanager
-def esm_model(model_name: str, device):
+def esm_model(device):
    """Context transiently keeps ESM model on specified device."""
-    from transformers import EsmModel
+
+    local_esm_path = downloads_path.joinpath(
+        "esm/traced_sdpa_esm2_t36_3B_UR50D_fp16.pt"
+    )
+    download_if_not_exists(ESM_URL, local_esm_path)

    if len(_esm_model) == 0:
        # lazy loading of the model
-        _esm_model.append(
-            EsmModel.from_pretrained(model_name, cache_dir=esm_cache_folder)
-        )
+        if device != torch.device("cuda:0"):
+            # load on cpu first, then move to device
+            model = torch.jit.load(local_esm_path, map_location="cpu").to(device)
+        else:
+            # skip loading on CPU.
+            model = torch.jit.load(local_esm_path).to(device)
+
+        _esm_model.append(model)

    [model] = _esm_model
    model.to(device)
@@ -45,28 +51,82 @@ def esm_model(model_name: str, device):
    model.to("cpu")  # move model back to CPU when done


+token_map = {
+    "<cls>": 0,
+    "<pad>": 1,
+    "<eos>": 2,
+    "<unk>": 3,
+    "L": 4,
+    "A": 5,
+    "G": 6,
+    "V": 7,
+    "S": 8,
+    "E": 9,
+    "R": 10,
+    "T": 11,
+    "I": 12,
+    "D": 13,
+    "P": 14,
+    "K": 15,
+    "Q": 16,
+    "N": 17,
+    "F": 18,
+    "Y": 19,
+    "M": 20,
+    "H": 21,
+    "W": 22,
+    "C": 23,
+    "X": 24,
+    "B": 25,
+    "U": 26,
+    "Z": 27,
+    "O": 28,
+    ".": 29,
+    "-": 30,
+    "<null_1>": 31,
+    "<mask>": 32,
+}
+
+
+class DumbTokenizer:
+    def __init__(self, token_map: dict[str, int]):
+        self.token_map = token_map
+
+    def tokenize(self, text: str) -> list[int]:
+        tokens = []
+        i = 0
+        while i < len(text):
+            for token in self.token_map:
+                if text.startswith(token, i):
+                    tokens.append(self.token_map[token])
+                    i += len(token)
+                    break
+            else:
+                raise RuntimeError("Unknown token: " + text[i:])
+        return tokens
+
+
+esm_tokenizer = DumbTokenizer(token_map=token_map)
+
+
def _get_esm_contexts_for_sequences(
    prot_sequences: set[str], device
) -> dict[str, EmbeddingContext]:
    if len(prot_sequences) == 0:
        return {}  # skip loading ESM

-    # local import, requires huggingface transformers
-    from transformers import EsmTokenizer
-
-    model_name = "facebook/esm2_t36_3B_UR50D"
-    tokenizer = EsmTokenizer.from_pretrained(model_name, cache_dir=esm_cache_folder)
-
    seq2embedding_context = {}

    with torch.no_grad():
-        with esm_model(model_name=model_name, device=device) as model:
+        with esm_model(device=device) as model:
            for seq in prot_sequences:
-                inputs = tokenizer(seq, return_tensors="pt")
-                inputs = move_data_to_device(dict(**inputs), device=device)
-                outputs = model(**inputs)
+                # add bos/eos, tokenize
+                token_ids = torch.asarray(esm_tokenizer.tokenize(f"<cls>{seq}<eos>"))
+                token_ids = token_ids[None, :].to(device)
+
+                last_hidden_state = model(tokens=token_ids)
                # remove BOS/EOS, back to CPU
-                esm_embeddings = outputs.last_hidden_state[0, 1:-1].to("cpu")
+                esm_embeddings = last_hidden_state[0, 1:-1].float().to("cpu")
                seq_len, _emb_dim = esm_embeddings.shape
                assert seq_len == len(seq)
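For reference, a minimal standalone sketch (not part of the diff) of how the DumbTokenizer added above resolves tokens: it scans left to right and takes the first key in token_map that matches at the current position, relying on dict insertion order so that multi-character special tokens such as <cls> are tried before the single-letter residue codes. The trimmed token_map below is illustrative only; the ids match the full table in the diff.

# Standalone sketch of the greedy matcher used by DumbTokenizer above.
# token_map is trimmed to a few entries for brevity.
token_map = {"<cls>": 0, "<eos>": 2, "K": 15, "M": 20}


def tokenize(text: str, token_map: dict[str, int]) -> list[int]:
    tokens, i = [], 0
    while i < len(text):
        # dicts preserve insertion order, so special tokens are tried first
        for token in token_map:
            if text.startswith(token, i):
                tokens.append(token_map[token])
                i += len(token)
                break
        else:  # no key matched at position i
            raise RuntimeError("Unknown token: " + text[i:])
    return tokens


assert tokenize("<cls>MK<eos>", token_map) == [0, 20, 15, 2]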
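And a hedged end-to-end sketch of the new code path. The import path and the example sequence are assumptions (they are not part of the diff); it assumes this file lives at chai_lab/data/dataset/embeddings/esm.py. On first use the traced fp16 weights are fetched from ESM_URL into the downloads_path/"esm" folder by download_if_not_exists.

import torch

# Hypothetical module path; adjust to wherever this file sits in your checkout.
from chai_lab.data.dataset.embeddings.esm import _get_esm_contexts_for_sequences

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# One EmbeddingContext per unique protein sequence; per the diff, embeddings are
# computed in fp16 on `device`, then cast to fp32 and moved back to CPU.
contexts = _get_esm_contexts_for_sequences({"MKTAYIAKQR"}, device=device)
print(sorted(contexts))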