Skip to content

Commit ac70217

Browse files
committed
clean up mult-dimensional dense
1 parent f93aab9 commit ac70217

File tree

6 files changed

+27
-61
lines changed

6 files changed

+27
-61
lines changed

hls4ml/backends/catapult/passes/pointwise.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from copy import copy
22

3-
import numpy as np
4-
53
from hls4ml.backends.catapult.passes.convolution_templates import (
64
Conv1DConfigTemplate,
75
Conv1DFunctionTemplate,
@@ -78,9 +76,6 @@ def match(self, node):
7876
def transform(self, model, node):
7977
dim = node.__class__.__name__[-2:] # '1D' or '2D'
8078
pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy())
81-
if len(node.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D
82-
expand_axis = tuple(range(int(dim[0])))
83-
pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis)
8479
pw_node.weights['bias'].data = node.weights['bias'].data
8580
# Set strategy to ensure lowercase string is passed to the template
8681
if model.config.is_resource_strategy(pw_node):

hls4ml/backends/quartus/passes/pointwise.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from copy import copy
22

3-
import numpy as np
4-
53
from hls4ml.backends.fpga.fpga_layers import PointwiseConv1D, PointwiseConv2D
64
from hls4ml.backends.quartus.passes.convolution_templates import (
75
Conv1DConfigTemplate,
@@ -86,9 +84,6 @@ def transform(self, model, node):
8684
pw_node = model.make_node(
8785
'PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy(), outputs=node.outputs.copy()
8886
)
89-
if len(node.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D
90-
expand_axis = tuple(range(int(dim[0])))
91-
pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis)
9287
pw_node.weights['bias'].data = node.weights['bias'].data
9388
model.replace_node(node, pw_node)
9489

hls4ml/backends/vivado/passes/pointwise.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from copy import copy
22

3-
import numpy as np
4-
53
from hls4ml.backends.fpga.fpga_layers import PointwiseConv1D, PointwiseConv2D
64
from hls4ml.backends.vivado.passes.convolution_templates import (
75
Conv1DConfigTemplate,
@@ -78,9 +76,6 @@ def match(self, node):
7876
def transform(self, model, node):
7977
dim = node.__class__.__name__[-2:] # '1D' or '2D'
8078
pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy())
81-
if len(node.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D
82-
expand_axis = tuple(range(int(dim[0])))
83-
pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis)
8479
pw_node.weights['bias'].data = node.weights['bias'].data
8580
# Set strategy to ensure lowercase string is passed to the template
8681
if model.config.is_resource_strategy(pw_node):

hls4ml/model/optimizer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
'qkeras_factorize_alpha',
4545
'extract_ternary_threshold',
4646
'fuse_consecutive_batch_normalization',
47+
'replace_multidimensional_dense_with_conv',
4748
],
4849
) # TODO Maybe not all QKeras optmizers belong here?
4950

@@ -53,7 +54,6 @@
5354
'eliminate_linear_activation',
5455
'fuse_consecutive_batch_normalization',
5556
'fuse_batch_normalization',
56-
'replace_multidimensional_dense_with_conv',
5757
'infer_precision_types',
5858
'set_precision_concat',
5959
],

hls4ml/model/optimizer/passes/multi_dense.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55

66

77
class ReplaceMultidimensionalDenseWithConv(OptimizerPass):
8+
"""
9+
This matches all multidimensional Dense layers and changes them to a convolution.
10+
Note: the convolution may subsequently be changed to a pointwise convolution for
11+
bakends that implement special pointwise convolutions.
12+
"""
13+
814
def match(self, node):
9-
return (
10-
isinstance(node, Dense)
11-
and len(node.get_input_variable().shape) - sum(d == 1 for d in node.get_input_variable().shape) > 1
12-
)
13-
# The above sum checks for the number of dimensions in the Dense with size 1
14-
# The subtraction allows the check to only count the number of dimensions with non-1 size
15-
# For example, this prevents matching for a Dense layer with shape (1,N)
15+
return isinstance(node, Dense) and len(node.get_input_variable().shape) > 1
1616

1717
def transform(self, model, node):
1818
dim = len(node.get_input_variable().shape) - 1
@@ -23,7 +23,7 @@ def transform(self, model, node):
2323
'padding': 'valid',
2424
'n_chan': input_shape[-1],
2525
'n_filt': node.get_attr('n_out'),
26-
'weight_data': node.get_attr('weight_data'),
26+
'weight_data': np.expand_dims(node.get_attr('weight_data'), axis=tuple(range(dim))),
2727
'bias_data': node.get_attr('bias_data'),
2828
}
2929

@@ -58,11 +58,8 @@ def transform(self, model, node):
5858
else:
5959
raise Exception('Cannot replace Dense over {dim}D tensor with Conv{dim}D.'.format(dim=dim))
6060

61-
class_name = 'PointwiseConv' + str(dim) + 'D'
61+
class_name = 'Conv' + str(dim) + 'D'
6262
pw_node = model.make_node(class_name, node.name, pointwise_attrs, node.inputs.copy())
63-
if len(node.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D
64-
pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=tuple(range(dim)))
65-
pw_node.weights['bias'].data = node.weights['bias'].data
6663
model.replace_node(node, pw_node)
6764

6865
return True

test/pytest/test_multi_dense.py

Lines changed: 17 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,46 +11,32 @@
1111

1212

1313
@pytest.mark.parametrize(
14-
'backend, io_type',
14+
'backend, strategy',
1515
[
16-
('Quartus', 'io_parallel'),
17-
('Vivado', 'io_parallel'),
18-
('Vitis', 'io_parallel'),
19-
('Vivado', 'io_stream'),
20-
('Vivado', 'io_stream'),
21-
('Vitis', 'io_stream'),
16+
('Vitis', 'Latency'),
17+
('Vitis', 'Resource'),
18+
('Quartus', 'Resource'),
19+
('Catapult', 'Latency'),
20+
('Catapult', 'Resource'),
2221
],
2322
)
24-
def test_multi_dense(backend, io_type):
23+
@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
24+
@pytest.mark.parametrize('shape', [(4, 3), (4, 1), (2, 3, 2), (1, 3, 1)])
25+
def test_multi_dense(backend, strategy, io_type, shape):
2526
model = tf.keras.models.Sequential()
26-
model.add(
27-
Dense(
28-
4,
29-
input_shape=(
30-
8,
31-
8,
32-
),
33-
name='Dense',
34-
use_bias=True,
35-
kernel_initializer=tf.keras.initializers.RandomUniform(minval=1, maxval=10),
36-
bias_initializer='zeros',
37-
kernel_regularizer=None,
38-
bias_regularizer=None,
39-
activity_regularizer=None,
40-
kernel_constraint=None,
41-
bias_constraint=None,
42-
activation='relu',
43-
)
44-
)
27+
model.add(Dense(7, input_shape=shape, activation='relu'))
28+
model.add(Dense(2, activation='relu'))
4529
model.compile(optimizer='adam', loss='mse')
4630

47-
X_input = np.random.rand(100, 8, 8)
31+
X_input = np.random.rand(100, *shape)
32+
X_input = np.round(X_input * 2**10) * 2**-10 # make it an exact ap_fixed<16,6>
4833

4934
keras_prediction = model.predict(X_input)
5035

51-
default_precision = 'ap_fixed<32, 16>' if backend in ['Vivado', 'Vitis'] else 'ac_fixed<32, 16, true>'
52-
config = hls4ml.utils.config_from_keras_model(model, default_precision=default_precision)
53-
output_dir = str(test_root_path / f'hls4mlprj_multi_dense_{backend}_{io_type}')
36+
config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend)
37+
config['Model']['Strategy'] = strategy
38+
shapestr = '_'.join(str(x) for x in shape)
39+
output_dir = str(test_root_path / f'hls4mlprj_multi_dense_{backend}_{strategy}_{io_type}_{shapestr}')
5440

5541
hls_model = hls4ml.converters.convert_from_keras_model(
5642
model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type
@@ -61,5 +47,3 @@ def test_multi_dense(backend, io_type):
6147
hls_prediction = hls_model.predict(X_input).reshape(keras_prediction.shape)
6248

6349
np.testing.assert_allclose(hls_prediction, keras_prediction, rtol=1e-2, atol=0.01)
64-
65-
assert list(hls_model.get_layers())[1].class_name == 'PointwiseConv1D'

0 commit comments

Comments
 (0)