
Commit 0161a4c

Evaluate native ResNet PyTorch models from HuggingFace
1 parent 2378e96 commit 0161a4c

2 files changed: +57 −20 lines

tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher.py

Lines changed: 55 additions & 20 deletions
@@ -16,15 +16,18 @@

 from contextlib import contextmanager
 import sys
+import os
 import importlib
 import urllib
 import re
+import transformers
 from collections import OrderedDict
-
+from transformers import AutoConfig
 import numpy as np
 from ..config import PathField, StringField, DictField, NumberField, ListField, BoolField
 from .launcher import Launcher

+CLASS_REGEX = r'(?:\w+)'
 MODULE_REGEX = r'(?:\w+)(?:(?:.\w+)*)'
 DEVICE_REGEX = r'(?P<device>cpu$|cuda)?'
 CHECKPOINT_URL_REGEX = r'^https?://.*\.pth(\?.*)?(#.*)?$'
@@ -67,6 +70,9 @@ def parameters(cls):
             'torch_compile_kwargs': DictField(
                 key_type=str, validate_values=False, optional=True, default={},
                 description="dictionary of keyword arguments passed to torch.compile"
+            ),
+            'transformers_class': StringField(
+                optional=True, regex=CLASS_REGEX, description='Transformers class name to load pre-trained module.'
             )
         })
         return parameters
@@ -84,6 +90,7 @@ def __init__(self, config_entry: dict, *args, **kwargs):
         self.validate_config(config_entry)
         self.use_torch_compile = config_entry.get('use_torch_compile', False)
         self.compile_kwargs = config_entry.get('torch_compile_kwargs', {})
+        self.transformers_class = config_entry.get('transformers_class', None)
         backend = self.compile_kwargs.get('backend', None)
         if self.use_torch_compile and backend == 'openvino':
             try:
@@ -96,17 +103,24 @@ def __init__(self, config_entry: dict, *args, **kwargs):
         self.device = self.get_value_from_config('device')
         self.cuda = 'cuda' in self.device
         checkpoint = config_entry.get('checkpoint')
-        if checkpoint is None:
-            checkpoint = config_entry.get('checkpoint_url')
-        self.module = self.load_module(
-            config_entry['module'],
-            module_args,
-            module_kwargs,
-            checkpoint,
-            config_entry.get('state_key'),
-            config_entry.get("python_path"),
-            config_entry.get("init_method")
-        )
+        if self.transformers_class:
+            self.module = self.load_transformers_module(
+                config_entry['module']
+            )
+        else:
+            if checkpoint is None:
+                checkpoint = config_entry.get('checkpoint_url')
+
+            self.module = self.load_module(
+                config_entry['module'],
+                module_args,
+                module_kwargs,
+                checkpoint,
+                config_entry.get('state_key'),
+                config_entry.get("python_path"),
+                config_entry.get("init_method")
+            )
+
         self._batch = self.get_value_from_config('batch')
         # torch modules do not have input information
@@ -161,15 +175,27 @@ def load_module(self, model_cls, module_args, module_kwargs, checkpoint=None, st
         if all(key.startswith('module.') for key in state):
             module = self._torch.nn.DataParallel(module)
         module.load_state_dict(state, strict=False)
-        module.to('cuda' if self.cuda else 'cpu')
-        module.eval()

-        if self.use_torch_compile:
-            if hasattr(model_cls, 'compile'):
-                module.compile()
-            module = self._torch.compile(module, **self.compile_kwargs)
+        return self.prepare_module(module, model_cls)
+
+    def load_transformers_module(self, pretrained_name):
+        model_class = getattr(transformers, self.transformers_class)
+        module = model_class.from_pretrained(pretrained_name)
+
+        return self.prepare_module(module, model_class)
+
+    def prepare_module(self, module, model_class):
+        module.to('cuda' if self.cuda else 'cpu')
+        module.eval()
+
+        if self.use_torch_compile:
+            if hasattr(model_class, 'compile'):
+                module.compile()
+            module = self._torch.compile(module, **self.compile_kwargs)
+        return module

-        return module

     def _convert_to_tensor(self, value, precision):
         if isinstance(value, self._torch.Tensor):
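
For reference, the new loading path (`load_transformers_module` followed by `prepare_module`) boils down to roughly the standalone sketch below. The class and model id shown (`ResNetForImageClassification`, `microsoft/resnet-50`) are illustrative assumptions consistent with the commit message, not values fixed by this commit:

# Standalone sketch of the new transformers loading path; the
# transformers_class and module values below are illustrative only.
import torch
import transformers

model_class = getattr(transformers, 'ResNetForImageClassification')
module = model_class.from_pretrained('microsoft/resnet-50')  # downloads from the HuggingFace Hub
module.to('cpu')
module.eval()

with torch.no_grad():
    out = module(torch.randn(1, 3, 224, 224))  # forward takes pixel_values
print(out.logits.shape)  # torch.Size([1, 1000]) for the ImageNet-1k checkpoint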
@@ -193,6 +219,7 @@ def fit_to_input(self, data, layer_name, layout, precision, template=None):

         if layout is not None:
             data = np.transpose(data, layout)
+
         tensor = self._torch.from_numpy(data.astype(np.float32 if not precision else precision))
         tensor = tensor.to(self.device)
         return tensor
@@ -213,7 +240,15 @@ def predict(self, inputs, metadata=None, **kwargs):
             if metadata[0].get('input_is_dict_type'):
                 outputs = self.module(batch_input['input'])
             else:
-                outputs = list(self.module(*batch_input.values()))
+                output = self.module(*batch_input.values())
+
+                if 'logits' in self.output_names:
+                    result_dict = {'logits': output.logits.detach().cpu().numpy()}
+                    results.append(result_dict)
+                    continue
+
+                outputs = list(output)
+
             for meta_ in metadata:
                 meta_['input_shape'] = {key: list(data.shape) for key, data in batch_input.items()}

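The `logits` special case exists because a transformers model returns a ModelOutput dataclass rather than a plain tensor or tuple, so iterating it yields field names instead of values. A minimal sketch, assuming transformers is installed:

# A transformers model output is dict-like: list(output) yields keys,
# not tensors, which is why predict() pulls out output.logits explicitly.
import torch
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention

output = ImageClassifierOutputWithNoAttention(logits=torch.zeros(1, 1000))
print(list(output))                                # ['logits'] - field names, not tensors
print(output.logits.detach().cpu().numpy().shape)  # (1, 1000)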

tools/accuracy_checker/accuracy_checker/launcher/pytorch_launcher_readme.md

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@ For enabling PyTorch launcher you need to add `framework: pytorch` in launchers
 * `batch` - batch size for running model (Optional, default 1).
 * `use_torch_compile` - boolean, use torch.compile to optimize the module code (Optional, default `False`)
 * `torch_compile_kwargs` - dictionary of keyword arguments to pass to torch.compile (Optional, default `{}`)
+* `transformers_class` - name of the `transformers` class used to load the pre-trained model given by `module` (Optional).
+

 In turn, if your model has several inputs, you need to specify them in the config using the `inputs` parameter.
 Each input description should have the following info:
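
Taken together, a hypothetical launcher entry using the new option might look like the following YAML (the model id, class name, and adapter are illustrative, not part of this commit):

launchers:
  - framework: pytorch
    device: cpu
    module: microsoft/resnet-50            # HuggingFace Hub model id, passed to from_pretrained (illustrative)
    transformers_class: ResNetForImageClassification
    adapter: classification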
