Commit 2f04812

Merge branch 'main' into namespaces_and_other_emulator_goodies
2 parents 425e77a + ba08ca1 commit 2f04812

19 files changed: +526 −116 lines

hls4ml/backends/catapult/passes/pointwise.py

Lines changed: 0 additions & 5 deletions
@@ -1,7 +1,5 @@
 from copy import copy
 
-import numpy as np
-
 from hls4ml.backends.catapult.passes.convolution_templates import (
     Conv1DConfigTemplate,
     Conv1DFunctionTemplate,
@@ -78,9 +76,6 @@ def match(self, node):
     def transform(self, model, node):
         dim = node.__class__.__name__[-2:]  # '1D' or '2D'
         pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy())
-        if len(node.weights['weight'].data.shape) == 2:  # This can happen if we assign weights of Dense layer to 1x1 Conv2D
-            expand_axis = tuple(range(int(dim[0])))
-            pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis)
         pw_node.weights['bias'].data = node.weights['bias'].data
         # Set strategy to ensure lowercase string is passed to the template
         if model.config.is_resource_strategy(pw_node):

hls4ml/backends/quartus/passes/pointwise.py

Lines changed: 0 additions & 5 deletions
@@ -1,7 +1,5 @@
 from copy import copy
 
-import numpy as np
-
 from hls4ml.backends.fpga.fpga_layers import PointwiseConv1D, PointwiseConv2D
 from hls4ml.backends.quartus.passes.convolution_templates import (
     Conv1DConfigTemplate,
@@ -86,9 +84,6 @@ def transform(self, model, node):
         pw_node = model.make_node(
             'PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy(), outputs=node.outputs.copy()
         )
-        if len(node.weights['weight'].data.shape) == 2:  # This can happen if we assign weights of Dense layer to 1x1 Conv2D
-            expand_axis = tuple(range(int(dim[0])))
-            pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis)
         pw_node.weights['bias'].data = node.weights['bias'].data
         model.replace_node(node, pw_node)

hls4ml/backends/quartus/passes/recurrent_templates.py

Lines changed: 10 additions & 1 deletion
@@ -66,6 +66,7 @@
     using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;
 
     static const unsigned reuse_factor = {reuse};
+    static const unsigned pytorch_order = {pytorch};
     static const bool store_weights_in_bram = false;
 }};\n'''
 
@@ -92,6 +93,7 @@ def format(self, node):
         params['config_mult_h'] = f'config{node.index}_h_mult'
         params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
         params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
+        params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
         gru_config = self.gru_template.format(**params)
 
         # Activation is on candidate hidden state, dimensionality (1, n_units)
@@ -256,6 +258,9 @@ def format(self, node):
 }};\n"""
 
 simple_rnn_function_template = 'nnet::simple_rnn<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+simple_rnn_pytorch_function_template = (
+    'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+)
 
 
 class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -301,5 +306,9 @@ def __init__(self):
 
     def format(self, node):
         params = self._default_function_params(node)
-        params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
+        if node.get_attr('pytorch', False):
+            self.template = simple_rnn_pytorch_function_template
+            params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
+        else:
+            params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
         return self.template.format(**params)
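For context on the new simple_rnn_pytorch function template: a PyTorch RNN layer stores two bias vectors per layer (bias_ih_l0 and bias_hh_l0), while the Keras-style layers keep a single bias, so the PyTorch variant forwards an extra br{index} weight. A minimal illustrative sketch (not part of the commit):

import torch.nn as nn

# A single-layer, batch-first RNN, the only configuration the new PyTorch parser accepts.
rnn = nn.RNN(input_size=4, hidden_size=8, num_layers=1, batch_first=True)

# PyTorch keeps two separate bias tensors per layer; Keras fuses them into one.
print(rnn.bias_ih_l0.shape)  # torch.Size([8]) -> exported as b{index} in hls4ml
print(rnn.bias_hh_l0.shape)  # torch.Size([8]) -> exported as br{index} in hls4ml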

hls4ml/backends/vivado/passes/pointwise.py

Lines changed: 0 additions & 5 deletions
@@ -1,7 +1,5 @@
 from copy import copy
 
-import numpy as np
-
 from hls4ml.backends.fpga.fpga_layers import PointwiseConv1D, PointwiseConv2D
 from hls4ml.backends.vivado.passes.convolution_templates import (
     Conv1DConfigTemplate,
@@ -78,9 +76,6 @@ def match(self, node):
     def transform(self, model, node):
         dim = node.__class__.__name__[-2:]  # '1D' or '2D'
         pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy())
-        if len(node.weights['weight'].data.shape) == 2:  # This can happen if we assign weights of Dense layer to 1x1 Conv2D
-            expand_axis = tuple(range(int(dim[0])))
-            pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis)
         pw_node.weights['bias'].data = node.weights['bias'].data
         # Set strategy to ensure lowercase string is passed to the template
         if model.config.is_resource_strategy(pw_node):

hls4ml/backends/vivado/passes/recurrent_templates.py

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,7 @@
     static const unsigned reuse_factor = {reuse};
     static const bool store_weights_in_bram = false;
     static const bool use_static = {static};
+    static const bool pytorch_order = {pytorch};
 }};\n"""
 
 recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
@@ -97,6 +98,7 @@ def format(self, node):
         params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
         params['strategy'] = node.get_attr('strategy')
         params['static'] = 'true' if node.attributes['static'] else 'false'
+        params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
         params['recr_type'] = node.class_name.lower()
         params['RECR_TYPE'] = node.class_name
 
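As background for the new pytorch_order parameter (the assumption here is that it lets the HLS kernels account for the framework-specific gate layout): PyTorch and Keras stack the GRU gates in a different order, so the same stacked weight matrix cannot be read identically. A short illustrative sketch:

import torch.nn as nn

gru = nn.GRU(input_size=4, hidden_size=8, batch_first=True)

# PyTorch stacks the input-hidden weights as (W_ir | W_iz | W_in): reset, update, new.
# Keras GRU kernels are stacked as (update, reset, new), so the kernel needs to know
# which convention the exported weights follow.
print(gru.weight_ih_l0.shape)  # torch.Size([24, 4]) == (3 * hidden_size, input_size)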

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+import warnings
+
+import numpy as np
+
+from hls4ml.converters.pytorch_to_hls import pytorch_handler
+
+rnn_layers = ['RNN', 'LSTM', 'GRU']
+
+
+@pytorch_handler(*rnn_layers)
+def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
+    assert operation in rnn_layers
+
+    layer = {}
+
+    layer["name"] = layer_name
+
+    layer['inputs'] = [input_names[0]]
+    if len(input_names) > 1:
+        warnings.warn(
+            'hls4ml disregards the initial value of the hidden state passed to the model, assuming that it is all zeros',
+            stacklevel=2,
+        )
+    layer['class_name'] = operation
+    if operation == "RNN":
+        layer['class_name'] = 'SimpleRNN'
+
+    layer['return_sequences'] = False  # parameter does not exist in pytorch
+    layer['return_state'] = False  # parameter does not exist in pytorch
+
+    if layer['class_name'] == 'SimpleRNN':
+        layer['activation'] = class_object.nonlinearity  # Default is tanh, can also be ReLU in pytorch
+    else:
+        layer['activation'] = "tanh"  # GRU and LSTM are hard-coded to use tanh in pytorch
+
+    if layer['class_name'] == 'GRU' or layer['class_name'] == 'LSTM':
+        layer['recurrent_activation'] = 'sigmoid'  # GRU and LSTM are hard-coded to use sigmoid in pytorch
+
+    layer['time_major'] = not class_object.batch_first
+    # TODO Should we handle time_major?
+    if layer['time_major']:
+        raise Exception('hls4ml only supports "batch-first == True"')
+
+    layer['n_timesteps'] = input_shapes[0][1]
+    layer['n_in'] = input_shapes[0][2]
+
+    layer['n_out'] = class_object.hidden_size
+
+    if class_object.num_layers > 1:
+        raise Exception('hls4ml does not support num_layers > 1')
+
+    if class_object.bidirectional:
+        raise Exception('hls4ml does not support birectional RNNs')
+
+    if class_object.dropout > 0:
+        raise Exception('hls4ml does not support RNNs with dropout')
+
+    layer['weight_data'] = class_object.weight_ih_l0.data.numpy()
+    layer['recurrent_weight_data'] = class_object.weight_hh_l0.data.numpy()
+    layer['bias_data'] = class_object.bias_ih_l0.data.numpy()
+    layer['recurrent_bias_data'] = class_object.bias_hh_l0.data.numpy()
+
+    if class_object.bias is False:
+        layer['bias_data'] = np.zeros(layer['weight_data'].shape[0])
+        layer['recurrent_bias_data'] = np.zeros(layer['recurrent_weight_data'].shape[0])
+
+    if layer['class_name'] == 'GRU':
+        layer['apply_reset_gate'] = 'after'  # Might be true for pytorch? It's not a free parameter
+
+    output_shape = [input_shapes[0][0], layer['n_out']]
+
+    layer['pytorch'] = True  # need to switch some behaviors to match pytorch implementations
+
+    return layer, output_shape
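A minimal end-to-end sketch of what this handler enables, assuming the usual hls4ml PyTorch entry points (hls4ml.utils.config_from_pytorch_model and hls4ml.converters.convert_from_pytorch_model); the model, shapes and output directory are illustrative:

import torch.nn as nn
import hls4ml

class GRUNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Only single-layer, batch-first, non-bidirectional, dropout-free RNNs
        # pass the checks in parse_rnn_layer above.
        self.rnn = nn.GRU(input_size=16, hidden_size=32, num_layers=1, batch_first=True)

    def forward(self, x):
        # The handler ignores any initial hidden state and assumes it is all zeros.
        return self.rnn(x)[0]

model = GRUNet()
config = hls4ml.utils.config_from_pytorch_model(model)
# input shape: (batch, n_timesteps, n_in), with the batch dimension left unspecified
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model, [(None, 8, 16)], hls_config=config, output_dir='hls_gru_prj'
)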

hls4ml/converters/pytorch_to_hls.py

Lines changed: 18 additions & 2 deletions
@@ -95,6 +95,7 @@ def decorator(function):
     'avg_pool1d': 'AvgPool1d',
     'avg_pool2d': 'AvgPool2d',
     'flatten': 'Flatten',
+    'view': 'View',
 }
 
 
@@ -198,8 +199,21 @@ def pytorch_to_hls(config):
 
             # parse info from class object
             input_names = [inputs_map.get(str(i), str(i)) for i in node.args]
-            input_shapes = [output_shapes[str(i)] for i in node.args]
-
+            if pytorch_class in ["RNN", "GRU", "LSTM"]:
+                # we currently don't support the passing of the initial value of the hidden state to RNN models
+                input_names = [inputs_map.get(str(node.args[0]), str(node.args[0]))]
+                input_shapes = [output_shapes[str(node.args[0])]]
+            # if a 'getitem' is the input to a node, step back in the graph to find the real source of the input
+            elif "getitem" in node.args[0].name:
+                for tmp_node in traced_model.graph.nodes:
+                    if tmp_node.name == node.args[0].name:
+                        if "getitem" in tmp_node.args[0].name:
+                            raise Exception('Nested getitem calles not resolved at the moment.')
+                        input_names = [inputs_map.get(str(tmp_node.args[0]), str(tmp_node.args[0]))]
+                        input_shapes = [output_shapes[str(tmp_node.args[0])]]
+                        node.args = [tmp_node.args[0]]
+            else:
+                input_shapes = [output_shapes[str(i)] for i in node.args]
             # for Conv layers
             if 'Conv' in pytorch_class:
                 if not class_object.padding_mode == 'zeros':
@@ -253,6 +267,8 @@ def pytorch_to_hls(config):
                 operation = layer_name_map[operation]
 
             # only a limited number of functions are supported
+            if operation == "getitem":
+                continue
             if operation not in supported_layers:
                 raise Exception(f'Unsupported function {operation}')
             if operation == 'PReLU' or operation == 'batch_norm' or operation == 'conv1d' or operation == 'conv2d':
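To illustrate the graph pattern the new getitem handling resolves: the converter traces models with torch.fx, and indexing a module's tuple output shows up as a separate 'getitem' node in the traced graph. The model below is illustrative only:

import torch.nn as nn
from torch.fx import symbolic_trace

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
        self.dense = nn.Linear(8, 2)

    def forward(self, x):
        h = self.rnn(x)[0]    # indexing the (output, (h_n, c_n)) tuple adds a 'getitem' node
        return self.dense(h)

traced = symbolic_trace(Net())
for node in traced.graph.nodes:
    print(node.op, node.name, [str(a) for a in node.args])
# 'dense' lists 'getitem' as its argument; the converter above skips the getitem
# operation itself and rewires the layer to read directly from the 'rnn' node.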

hls4ml/model/layers.py

Lines changed: 7 additions & 2 deletions
@@ -1042,6 +1042,8 @@ def initialize(self):
 
         # biases
         self.add_weights_variable(name='bias', var_name='b{index}')
+        if "pytorch" in self.attributes.keys():
+            self.add_weights_variable(name='recurrent_bias', var_name='br{index}')
 
 
 class LSTM(Layer):
@@ -1093,8 +1095,11 @@ def initialize(self):
         # biases
         self.add_weights_variable(name='bias', var_name='b{index}')
 
-        recurrent_bias = np.zeros(recurrent_weight.shape[1])
-        self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)
+        if "pytorch" in self.attributes.keys():
+            self.add_weights_variable(name='recurrent_bias', var_name='br{index}')
+        else:
+            recurrent_bias = np.zeros(recurrent_weight.shape[1])
+            self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)
 
 
 class GRU(Layer):

hls4ml/model/optimizer/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -44,6 +44,7 @@
         'qkeras_factorize_alpha',
         'extract_ternary_threshold',
         'fuse_consecutive_batch_normalization',
+        'replace_multidimensional_dense_with_conv',
     ],
 )  # TODO Maybe not all QKeras optmizers belong here?
 
@@ -53,7 +54,6 @@
         'eliminate_linear_activation',
         'fuse_consecutive_batch_normalization',
         'fuse_batch_normalization',
-        'replace_multidimensional_dense_with_conv',
         'infer_precision_types',
         'set_precision_concat',
     ],

hls4ml/model/optimizer/passes/convert_to_channels_last.py

Lines changed: 3 additions & 3 deletions
@@ -17,14 +17,14 @@ def match(self, node):
 
     def transform(self, model, node):
         # If this parameter has not been set, this model does not need to be converted
-        if 'InputsChannelLast' not in model.config.config['HLSConfig']['Model']:
+        if 'ChannelsLastConversion' not in model.config.config['HLSConfig']['Model']:
             node.channels_last_converted = True
             return False
         outshape = node.get_output_variable().shape
 
         if isinstance(node, Input):
             # if inputs are not yet transposed into channels_last, add transpose layer
-            if not model.config.config['HLSConfig']['Model']['InputsChannelLast'] and len(outshape) > 1:
+            if model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "full" and len(outshape) > 1:
                 # Add transpose for input layer
                 input = node.name
                 if len(outshape) == 2:
@@ -39,7 +39,7 @@ def transform(self, model, node):
                 transpose_node.channels_last_converted = True
 
                 model.insert_node(transpose_node)
-            else:
+            elif model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "internal" and len(outshape) > 1:
                 input_shape = node.get_output_variable().shape
                 input_shape.append(input_shape.pop(0))
                 node.get_output_variable().shape = input_shape
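A hedged configuration sketch for the renamed option: the value names 'full' and 'internal' are taken from this diff, while the exact way the option is injected into the HLS config is an assumption:

import hls4ml

# 'full'     -> insert a transpose node so inputs are physically converted to channels_last
# 'internal' -> only reinterpret the input variable's shape, without adding a transpose
# leaving the key out of the model config skips the conversion entirely
config = hls4ml.utils.config_from_pytorch_model(model)  # 'model' is a channels-first PyTorch model
config['Model']['ChannelsLastConversion'] = 'internal'

hls_model = hls4ml.converters.convert_from_pytorch_model(
    model, [(None, 3, 32, 32)], hls_config=config
)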
