
Commit 9cbf0f1

add more onnx tests, optimize the handling of some attributes, update example model version
1 parent 583a8c2

File tree

5 files changed: +122 −11 lines

example-models (submodule pointer update: +1 −1; diff not rendered)
hls4ml/model/optimizer/passes/batchnorm_opt.py (+1 −1)
hls4ml/model/optimizer/passes/conv_to_convxd.py (+1 −1)
hls4ml/model/optimizer/passes/conv_to_depthwiseconvxd.py (+1 −1)
test/pytest/test_qonnx.py (+118 −7)

hls4ml/model/optimizer/passes/batchnorm_opt.py

Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ def transform(self, model, node):
         if not (len(node.inputs) == 5 and all(node.inputs)):
             raise ValueError('All 5 BatchNormOnnnx inputs need to be defined')
 
-        attributes = {k: node.attributes.get(k, None) for k in _base_attributes}
+        attributes = {k: node.attributes[k] for k in _base_attributes if k in node.attributes}
 
         gamma_node = node.get_input_node(node.inputs[1])
         if not isinstance(gamma_node, Constant):
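The substantive change in all three optimizer passes is the same: the old comprehension filled in an explicit None for every base attribute the node did not define, while the new one omits missing keys entirely, presumably so that downstream attribute defaults apply instead of being clobbered by None. A minimal sketch of the difference (the attribute names here are hypothetical, not the actual _base_attributes of these passes):

# Hypothetical attribute set and node attributes, for illustration only
_base_attributes = ('epsilon', 'n_filt', 'n_in')
attrs = {'epsilon': 1e-3}  # the node only defines one of them

# Old behaviour: missing keys become explicit None entries
old = {k: attrs.get(k, None) for k in _base_attributes}
assert old == {'epsilon': 1e-3, 'n_filt': None, 'n_in': None}

# New behaviour: missing keys are simply dropped
new = {k: attrs[k] for k in _base_attributes if k in attrs}
assert new == {'epsilon': 1e-3}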

hls4ml/model/optimizer/passes/conv_to_convxd.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ def transform(self, model, node):
         bias_node = node.get_input_node(node.inputs[2])
 
         # creating the attributes
-        attributes = {k: node.attributes.get(k, None) for k in _base_attributes}
+        attributes = {k: node.attributes[k] for k in _base_attributes if k in node.attributes}
 
         # The ConvxD nodes expect the weight data to be in a different format, not (M, k1.., C)
         if node.attributes['n_dim'] == 1:
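For context, the comment in the hunk above refers to the axis reshuffle this pass performs on the weight tensor (unchanged by this commit). A hedged numpy sketch of that kind of move, assuming purely for illustration that 1D weights arrive as (M, k, C) and the target layout puts the M filters last; the actual permutation used by the pass may differ:

import numpy as np

# Hypothetical 1D conv weights: M=4 filters, kernel width k=3, C=2 channels
weight_data = np.random.rand(4, 3, 2)  # (M, k, C)

# Move the filter axis to the end: (M, k, C) -> (k, C, M)
reordered = np.transpose(weight_data, (1, 2, 0))
assert reordered.shape == (3, 2, 4)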

hls4ml/model/optimizer/passes/conv_to_depthwiseconvxd.py

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ def transform(self, model, node):
         bias_node = node.get_input_node(node.inputs[2])
 
         # creating the attributes
-        attributes = {k: node.attributes.get(k, None) for k in _base_attributes}
+        attributes = {k: node.attributes[k] for k in _base_attributes if k in node.attributes}
 
         # The ConvxD nodes expect the weight data to be in a different format, not (M, k1.., C)
         if node.attributes['n_dim'] == 1:

test/pytest/test_qonnx.py

Lines changed: 118 additions & 7 deletions

@@ -10,6 +10,8 @@
 
 # To conveniently run QONNX inference
 from qonnx.core.modelwrapper import ModelWrapper
+from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
+from qonnx.transformation.gemm_to_matmul import GemmToMatMul
 
 import hls4ml
 
@@ -99,10 +101,23 @@ def sep_conv_model():
     return model
 
 
+@pytest.fixture(scope='module')
+def two_layer_keras_model():
+    """
+    Load a simple, two-layer, originally keras, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/two_layer_keras.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
 @pytest.fixture(scope='module')
 def three_layer_keras_model():
     """
-    Load a simple, originally keras unquantized model
+    Load a simple, three-layer, originally keras, unquantized model
     """
     dl_file = str(example_model_path / "onnx/three_layer_keras.onnx")
     assert os.path.isfile(dl_file)
@@ -112,6 +127,84 @@ def three_layer_keras_model():
     return model
 
 
+@pytest.fixture(scope='module')
+def two_layer_pytorch_model():
+    """
+    Load a simple, two-layer, originally pytorch, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/two_layer_keras.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    model = model.transform(GemmToMatMul())
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+@pytest.fixture(scope='module')
+def three_layer_pytorch_model():
+    """
+    Load a simple, three-layer, originally pytorch, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/three_layer_pytorch.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    model = model.transform(GemmToMatMul())
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+@pytest.fixture(scope='module')
+def conv1d_small_keras_model():
+    """
+    Load a simple conv1d, originally keras, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/conv1d_small_keras.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    model = model.transform(ConvertToChannelsLastAndClean())
+    model = model.transform(GemmToMatMul())
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+@pytest.fixture(scope='module')
+def conv2d_small_keras_model():
+    """
+    Load a simple conv2d, originally keras, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/conv2d_small_keras.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    model = model.transform(ConvertToChannelsLastAndClean())
+    model = model.transform(GemmToMatMul())
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
+@pytest.fixture(scope='module')
+def conv2d_small_mp_keras_model():
+    """
+    Load a conv2d model with max pooling, originally keras, unquantized model
+    """
+    dl_file = str(example_model_path / "onnx/conv2d_small_mp_keras.onnx")
+    assert os.path.isfile(dl_file)
+
+    model = ModelWrapper(dl_file)
+    model = qonnx.util.cleanup.cleanup_model(model)
+    model = model.transform(ConvertToChannelsLastAndClean())
+    model = model.transform(GemmToMatMul())
+    model = qonnx.util.cleanup.cleanup_model(model)
+    return model
+
+
 # The actual tests
 
 
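All five new fixtures repeat the same load/cleanup/transform sequence. Purely as an illustration (this helper is not part of the commit), the duplication could be factored out along these lines, reusing the names already defined in the test module:

def _load_qonnx_model(filename, channels_last=False, gemm_to_matmul=False):
    """Hypothetical helper mirroring the preprocessing the new fixtures share."""
    dl_file = str(example_model_path / filename)
    assert os.path.isfile(dl_file)

    model = ModelWrapper(dl_file)
    model = qonnx.util.cleanup.cleanup_model(model)
    if channels_last:
        model = model.transform(ConvertToChannelsLastAndClean())
    if gemm_to_matmul:
        model = model.transform(GemmToMatMul())
        model = qonnx.util.cleanup.cleanup_model(model)
    return model

# e.g. the conv2d fixture body would reduce to:
# return _load_qonnx_model("onnx/conv2d_small_keras.onnx", channels_last=True, gemm_to_matmul=True)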
@@ -216,25 +309,43 @@ def test_sep_conv(sep_conv_model, backend):
     np.testing.assert_allclose(y_qonnx.ravel(), y_hls4ml.ravel(), atol=1e-2, rtol=1)
 
 
+@pytest.mark.parametrize(
+    'model_name',
+    [
+        'two_layer_keras_model',
+        'three_layer_keras_model',
+        'two_layer_pytorch_model',
+        'three_layer_pytorch_model',
+        'conv1d_small_keras_model',
+        'conv2d_small_keras_model',
+        'conv2d_small_mp_keras_model',
+    ],
+)
 @pytest.mark.parametrize('backend', ['Vitis'])
 @pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
-def test_three_layer_keras(three_layer_keras_model, io_type, backend):
-    model = three_layer_keras_model
+def test_simple_model(model_name, io_type, backend, request):
+    if model_name == 'conv2d_small_mp_keras_model' and io_type == 'io_stream':
+        # Not yet supported due to an issue with channels last conversion
+        # There is a qonnx PR.
+        pytest.skip()
+    model = request.getfixturevalue(model_name)
     ishape = tuple(model.get_tensor_shape(model.graph.input[0].name))
     X = np.random.uniform(low=0, high=1, size=np.prod(ishape)).reshape(ishape)
-    X = (np.round(X * 2**16) * 2**-16).astype(np.float32)
+    X = (np.round(X * 2**10) * 2**-10).astype(np.float32)
     idict = {model.graph.input[0].name: X}
     y_qonnx = oxe.execute_onnx(model, idict)[model.graph.output[0].name]
 
     config = hls4ml.utils.config.config_from_onnx_model(
-        model, granularity='name', backend=backend, default_precision='fixed<32,16>'
+        model, granularity='name', backend=backend, default_precision='fixed<16,6>'
     )
 
-    config['LayerName']['Softmax_0']['Implementation'] = 'legacy'
+    for layer in config['LayerName']:
+        if layer.startswith('Softmax'):
+            config['LayerName'][layer]['Implementation'] = 'legacy'
 
     hls_model = hls4ml.converters.convert_from_onnx_model(
         model,
-        output_dir=str(test_root_path / f'hls4mlprj_onnx_three_layer_keras_{io_type}_{backend}'),
+        output_dir=str(test_root_path / f'hls4mlprj_onnx_{model_name}_{io_type}_{backend}'),
         io_type=io_type,
         backend=backend,
         hls_config=config,
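Two details of the refactored test are worth spelling out. First, parametrizing over fixture *names* and resolving them with request.getfixturevalue means each module-scoped model is only built when a parameter actually requests it. A self-contained sketch of the pattern with toy fixtures (not the ones in this file):

import pytest


@pytest.fixture(scope='module')
def small_model():
    return {'layers': 2}


@pytest.fixture(scope='module')
def large_model():
    return {'layers': 16}


@pytest.mark.parametrize('model_name', ['small_model', 'large_model'])
def test_model(model_name, request):
    # Resolve the fixture lazily by name instead of declaring it as an argument
    model = request.getfixturevalue(model_name)
    assert model['layers'] > 0

Second, the input rounding step was tightened from a 2**-16 grid to a 2**-10 grid in tandem with the default precision dropping from fixed<32,16> to fixed<16,6>: a fixed<16,6> type has 10 fractional bits, so the rounded test inputs presumably stay exactly representable in the smaller fixed-point type.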
