@@ -16,15 +16,18 @@
 
 from contextlib import contextmanager
 import sys
+import os
 import importlib
 import urllib
 import re
+import transformers
 from collections import OrderedDict
-
+from transformers import AutoConfig
 import numpy as np
 from ..config import PathField, StringField, DictField, NumberField, ListField, BoolField
 from .launcher import Launcher
 
+CLASS_REGEX = r'(?:\w+)'
 MODULE_REGEX = r'(?:\w+)(?:(?:\.\w+)*)'
 DEVICE_REGEX = r'(?P<device>cpu$|cuda)?'
 CHECKPOINT_URL_REGEX = r'^https?://.*\.pth(\?.*)?(#.*)?$'
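For illustration (not part of the patch), a quick sketch of what the two validation patterns accept; runnable as-is:

    import re

    CLASS_REGEX = r'(?:\w+)'
    MODULE_REGEX = r'(?:\w+)(?:(?:\.\w+)*)'

    # A bare identifier, such as a transformers class name, satisfies CLASS_REGEX
    print(bool(re.match(CLASS_REGEX, 'AutoModelForCausalLM')))        # True
    # MODULE_REGEX additionally allows dotted module paths
    print(bool(re.match(MODULE_REGEX, 'torchvision.models.resnet')))  # True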
@@ -67,6 +70,9 @@ def parameters(cls):
             'torch_compile_kwargs': DictField(
                 key_type=str, validate_values=False, optional=True, default={},
                 description="dictionary of keyword arguments passed to torch.compile"
+            ),
+            'transformers_class': StringField(
+                optional=True, regex=CLASS_REGEX, description='Transformers class name to load pre-trained module.'
             )
         })
         return parameters
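A minimal sketch (not part of the patch) of how the new option might appear in a launcher config entry, assuming the dict form the accuracy checker passes to the launcher; the class and model names here are illustrative:

    config_entry = {
        'framework': 'pytorch',
        'module': 'bert-base-uncased',  # with transformers_class set, 'module' is a pretrained id or path
        'transformers_class': 'AutoModelForSequenceClassification',
        'device': 'cpu',
    }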
@@ -84,6 +90,7 @@ def __init__(self, config_entry: dict, *args, **kwargs):
         self.validate_config(config_entry)
         self.use_torch_compile = config_entry.get('use_torch_compile', False)
         self.compile_kwargs = config_entry.get('torch_compile_kwargs', {})
+        self.transformers_class = config_entry.get('transformers_class', None)
         backend = self.compile_kwargs.get('backend', None)
         if self.use_torch_compile and backend == 'openvino':
             try:
@@ -96,17 +103,24 @@ def __init__(self, config_entry: dict, *args, **kwargs):
96
103
self .device = self .get_value_from_config ('device' )
97
104
self .cuda = 'cuda' in self .device
98
105
checkpoint = config_entry .get ('checkpoint' )
99
- if checkpoint is None :
100
- checkpoint = config_entry .get ('checkpoint_url' )
101
- self .module = self .load_module (
102
- config_entry ['module' ],
103
- module_args ,
104
- module_kwargs ,
105
- checkpoint ,
106
- config_entry .get ('state_key' ),
107
- config_entry .get ("python_path" ),
108
- config_entry .get ("init_method" )
109
- )
106
+ if self .tranformers_class :
107
+ self .module = self .load_tranformers_module (
108
+ config_entry ['module' ]
109
+ )
110
+ else :
111
+ if checkpoint is None :
112
+ checkpoint = config_entry .get ('checkpoint_url' )
113
+
114
+ self .module = self .load_module (
115
+ config_entry ['module' ],
116
+ module_args ,
117
+ module_kwargs ,
118
+ checkpoint ,
119
+ config_entry .get ('state_key' ),
120
+ config_entry .get ("python_path" ),
121
+ config_entry .get ("init_method" )
122
+ )
123
+
110
124
111
125
self ._batch = self .get_value_from_config ('batch' )
112
126
# torch modules does not have input information
@@ -161,15 +175,27 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st
             if all(key.startswith('module.') for key in state):
                 module = self._torch.nn.DataParallel(module)
             module.load_state_dict(state, strict=False)
-            module.to('cuda' if self.cuda else 'cpu')
-            module.eval()
 
-            if self.use_torch_compile:
-                if hasattr(model_cls, 'compile'):
-                    module.compile()
-                module = self._torch.compile(module, **self.compile_kwargs)
+            return self.prepare_module(module, model_cls)
+
+    def load_transformers_module(self, pretrained_name):
+
+        model_class = getattr(transformers, self.transformers_class)
+        module = model_class.from_pretrained(pretrained_name)
+
+        return self.prepare_module(module, model_class)
+
+
+    def prepare_module(self, module, model_class):
+        module.to('cuda' if self.cuda else 'cpu')
+        module.eval()
+
+        if self.use_torch_compile:
+            if hasattr(model_class, 'compile'):
+                module.compile()
+            module = self._torch.compile(module, **self.compile_kwargs)
+        return module
 
-            return module
 
     def _convert_to_tensor(self, value, precision):
         if isinstance(value, self._torch.Tensor):
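Stripped of the launcher plumbing, the new loading path amounts to the following, a minimal sketch (not part of the patch) assuming the transformers package is installed; the class and model names are illustrative:

    import transformers

    def load_pretrained(class_name, pretrained_name, device='cpu'):
        # Resolve the class by name, e.g. 'AutoModelForSequenceClassification'
        model_class = getattr(transformers, class_name)
        # from_pretrained accepts a hub id or a local directory
        module = model_class.from_pretrained(pretrained_name)
        module.to(device)
        module.eval()
        return module

    model = load_pretrained('AutoModel', 'bert-base-uncased')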
@@ -193,6 +219,7 @@ def fit_to_input(self, data, layer_name, layout, precision, template=None):
 
         if layout is not None:
             data = np.transpose(data, layout)
+
         tensor = self._torch.from_numpy(data.astype(np.float32 if not precision else precision))
         tensor = tensor.to(self.device)
         return tensor
@@ -213,7 +240,15 @@ def predict(self, inputs, metadata=None, **kwargs):
                 if metadata[0].get('input_is_dict_type'):
                     outputs = self.module(batch_input['input'])
                 else:
-                    outputs = list(self.module(*batch_input.values()))
+                    output = self.module(*batch_input.values())
+
+                    if 'logits' in self.output_names:
+                        result_dict = {'logits': output.logits.detach().cpu().numpy()}
+                        results.append(result_dict)
+                        continue
+                    else:
+                        outputs = list(output)
+
 
                 for meta_ in metadata:
                     meta_['input_shape'] = {key: list(data.shape) for key, data in batch_input.items()}
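The logits branch assumes the module returns a transformers-style output object carrying a .logits tensor. The detach/cpu/numpy chain it applies is sketched below (not part of the patch) with a stand-in class, so it runs without downloading a model:

    import torch

    class FakeOutput:
        # Stand-in for a transformers ModelOutput with a logits field (illustrative)
        def __init__(self, logits):
            self.logits = logits

    output = FakeOutput(torch.randn(1, 2))
    # Detach from the autograd graph, move to host memory, convert to numpy
    result_dict = {'logits': output.logits.detach().cpu().numpy()}
    print(result_dict['logits'].shape)  # (1, 2)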