Skip to content

Commit 05bac8d

Browse files
vloncarjmduarte
andauthored
Support keepdims in GlobalPooling layers (#716)
* Support keepdims in GlobalPooling layers * Update ci-template.yml * parametrize fixture for keepdims * pre-commit --------- Co-authored-by: Javier Duarte <jduarte@ucsd.edu>
1 parent f0bcd4f commit 05bac8d

File tree

3 files changed

+104
-59
lines changed

3 files changed

+104
-59
lines changed

hls4ml/converters/keras/pooling.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,27 @@ def parse_global_pooling_layer(keras_layer, input_names, input_shapes, data_read
6767
assert 'Pooling' in keras_layer['class_name']
6868

6969
layer = parse_default_keras_layer(keras_layer, input_names)
70+
layer['keepdims'] = keras_layer['config']['keepdims']
7071

7172
if int(layer['class_name'][-2]) == 1:
7273
(layer['n_in'], layer['n_filt']) = parse_data_format(input_shapes[0], layer['data_format'])
7374

74-
output_shape = [input_shapes[0][0], layer['n_filt']]
75+
if layer['keepdims']:
76+
if layer['data_format'] == 'channels_last':
77+
output_shape = [input_shapes[0][0], 1, layer['n_filt']]
78+
elif layer['data_format'] == 'channels_first':
79+
output_shape = [input_shapes[0][0], layer['n_filt'], 1]
80+
else:
81+
output_shape = [input_shapes[0][0], layer['n_filt']]
7582
elif int(layer['class_name'][-2]) == 2:
7683
(layer['in_height'], layer['in_width'], layer['n_filt']) = parse_data_format(input_shapes[0], layer['data_format'])
7784

78-
output_shape = [input_shapes[0][0], layer['n_filt']]
85+
if layer['keepdims']:
86+
if layer['data_format'] == 'channels_last':
87+
output_shape = [input_shapes[0][0], 1, 1, layer['n_filt']]
88+
elif layer['data_format'] == 'channels_first':
89+
output_shape = [input_shapes[0][0], layer['n_filt'], 1, 1]
90+
else:
91+
output_shape = [input_shapes[0][0], layer['n_filt']]
7992

8093
return layer, output_shape

test/pytest/ci-template.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
.pytest:
22
stage: test
3-
image: gitlab-registry.cern.ch/fastmachinelearning/hls4ml-testing:0.2.base
4-
tags:
3+
image: gitlab-registry.cern.ch/fastmachinelearning/hls4ml-testing:0.4.base
4+
tags:
55
- docker
66
before_script:
77
- source ~/.bashrc
@@ -14,10 +14,10 @@
1414
artifacts:
1515
when: always
1616
reports:
17-
junit:
17+
junit:
1818
- test/pytest/report.xml
1919
coverage_report:
2020
coverage_format: cobertura
2121
path: test/pytest/coverage.xml
2222
paths:
23-
- test/pytest/hls4mlprj*.tar.gz
23+
- test/pytest/hls4mlprj*.tar.gz

test/pytest/test_globalpooling.py

Lines changed: 85 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,126 @@
1+
from pathlib import Path
2+
3+
import numpy as np
14
import pytest
5+
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D
26
from tensorflow.keras.models import Sequential
3-
from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalMaxPooling1D, GlobalAveragePooling2D, GlobalMaxPooling2D
4-
import numpy as np
7+
58
import hls4ml
6-
from pathlib import Path
79

810
test_root_path = Path(__file__).parent
911

1012
in_shape = 18
1113
in_filt = 6
1214
atol = 5e-3
1315

16+
1417
@pytest.fixture(scope='module')
1518
def data_1d():
1619
return np.random.rand(100, in_shape, in_filt)
1720

18-
@pytest.fixture(scope='module')
19-
def keras_model_max_1d():
20-
model = Sequential()
21-
model.add(GlobalMaxPooling1D(input_shape=(in_shape, in_filt)))
22-
model.compile()
23-
return model
2421

2522
@pytest.fixture(scope='module')
26-
def keras_model_avg_1d():
23+
def keras_model_1d(request):
24+
model_type = request.param['model_type']
25+
keepdims = request.param['keepdims']
2726
model = Sequential()
28-
model.add(GlobalAveragePooling1D(input_shape=(in_shape, in_filt)))
27+
if model_type == 'avg':
28+
model.add(GlobalAveragePooling1D(input_shape=(in_shape, in_filt), keepdims=keepdims))
29+
elif model_type == 'max':
30+
model.add(GlobalMaxPooling1D(input_shape=(in_shape, in_filt), keepdims=keepdims))
2931
model.compile()
30-
return model
31-
32+
return model, model_type, keepdims
33+
3234

3335
@pytest.mark.parametrize('backend', ['Quartus', 'Vivado'])
34-
@pytest.mark.parametrize('model_type', ['max', 'avg'])
36+
@pytest.mark.parametrize(
37+
'keras_model_1d',
38+
[
39+
{'model_type': 'max', 'keepdims': True},
40+
{'model_type': 'max', 'keepdims': False},
41+
{'model_type': 'avg', 'keepdims': True},
42+
{'model_type': 'avg', 'keepdims': False},
43+
],
44+
ids=[
45+
'model_type-max-keepdims-True',
46+
'model_type-max-keepdims-False',
47+
'model_type-avg-keepdims-True',
48+
'model_type-avg-keepdims-False',
49+
],
50+
indirect=True,
51+
)
3552
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
36-
def test_global_pool1d(backend, keras_model_max_1d, keras_model_avg_1d, data_1d, model_type, io_type):
37-
if model_type == 'avg':
38-
model = keras_model_avg_1d
39-
else:
40-
model = keras_model_max_1d
41-
53+
def test_global_pool1d(backend, keras_model_1d, data_1d, io_type):
54+
55+
model, model_type, keepdims = keras_model_1d
56+
4257
config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,9>', granularity='name')
4358

44-
hls_model = hls4ml.converters.convert_from_keras_model(model,
45-
hls_config=config,
46-
io_type=io_type,
47-
output_dir=str(test_root_path / f'hls4mlprj_globalplool1d_{backend}_{io_type}_{model_type}'),
48-
backend=backend)
59+
hls_model = hls4ml.converters.convert_from_keras_model(
60+
model,
61+
hls_config=config,
62+
io_type=io_type,
63+
output_dir=str(test_root_path / f'hls4mlprj_globalplool1d_{backend}_{io_type}_{model_type}_keepdims{keepdims}'),
64+
backend=backend,
65+
)
4966
hls_model.compile()
50-
51-
y_keras = np.squeeze(model.predict(data_1d))
52-
y_hls = hls_model.predict(data_1d)
67+
68+
y_keras = model.predict(data_1d)
69+
y_hls = hls_model.predict(data_1d).reshape(y_keras.shape)
5370
np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)
5471

72+
5573
@pytest.fixture(scope='module')
5674
def data_2d():
5775
return np.random.rand(100, in_shape, in_shape, in_filt)
5876

59-
@pytest.fixture(scope='module')
60-
def keras_model_max_2d():
61-
model = Sequential()
62-
model.add(GlobalMaxPooling2D(input_shape=(in_shape, in_shape, in_filt)))
63-
model.compile()
64-
return model
6577

6678
@pytest.fixture(scope='module')
67-
def keras_model_avg_2d():
79+
def keras_model_2d(request):
80+
model_type = request.param['model_type']
81+
keepdims = request.param['keepdims']
6882
model = Sequential()
69-
model.add(GlobalAveragePooling2D(input_shape=(in_shape, in_shape, in_filt)))
83+
if model_type == 'avg':
84+
model.add(GlobalAveragePooling2D(input_shape=(in_shape, in_shape, in_filt), keepdims=keepdims))
85+
elif model_type == 'max':
86+
model.add(GlobalMaxPooling2D(input_shape=(in_shape, in_shape, in_filt), keepdims=keepdims))
7087
model.compile()
71-
return model
88+
return model, model_type, keepdims
89+
7290

7391
@pytest.mark.parametrize('backend', ['Quartus', 'Vivado'])
74-
@pytest.mark.parametrize('model_type', ['max', 'avg'])
92+
@pytest.mark.parametrize(
93+
'keras_model_2d',
94+
[
95+
{'model_type': 'max', 'keepdims': True},
96+
{'model_type': 'max', 'keepdims': False},
97+
{'model_type': 'avg', 'keepdims': True},
98+
{'model_type': 'avg', 'keepdims': False},
99+
],
100+
ids=[
101+
'model_type-max-keepdims-True',
102+
'model_type-max-keepdims-False',
103+
'model_type-avg-keepdims-True',
104+
'model_type-avg-keepdims-False',
105+
],
106+
indirect=True,
107+
)
75108
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
76-
def test_global_pool2d(backend, keras_model_max_2d, keras_model_avg_2d, data_2d, model_type, io_type):
77-
78-
if model_type == 'avg':
79-
model = keras_model_avg_2d
80-
else:
81-
model = keras_model_max_2d
82-
109+
def test_global_pool2d(backend, keras_model_2d, data_2d, io_type):
110+
111+
model, model_type, keepdims = keras_model_2d
112+
83113
config = hls4ml.utils.config_from_keras_model(model, default_precision='ap_fixed<32,9>', granularity='name')
84114

85-
hls_model = hls4ml.converters.convert_from_keras_model(model,
86-
hls_config=config,
87-
io_type=io_type,
88-
output_dir=str(test_root_path / f'hls4mlprj_globalplool2d_{backend}_{io_type}_{model_type}'),
89-
backend=backend)
115+
hls_model = hls4ml.converters.convert_from_keras_model(
116+
model,
117+
hls_config=config,
118+
io_type=io_type,
119+
output_dir=str(test_root_path / f'hls4mlprj_globalplool2d_{backend}_{io_type}_{model_type}_keepdims{keepdims}'),
120+
backend=backend,
121+
)
90122
hls_model.compile()
91-
92-
y_keras = np.squeeze(model.predict(data_2d))
93-
y_hls = hls_model.predict(data_2d)
123+
124+
y_keras = model.predict(data_2d)
125+
y_hls = hls_model.predict(data_2d).reshape(y_keras.shape)
94126
np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)

0 commit comments

Comments
 (0)