Skip to content

Commit 54d7a34

Browse files
committed
addressing Jovan's comments
1 parent 58b7913 commit 54d7a34

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

hls4ml/backends/quartus/passes/recurrent_templates.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def format(self, node):
9393
params['config_mult_h'] = f'config{node.index}_h_mult'
9494
params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
9595
params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
96-
params['pytorch'] = 'true' if "pytorch" in node.attributes.keys() else 'false'
96+
params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
9797
gru_config = self.gru_template.format(**params)
9898

9999
# Activation is on candidate hidden state, dimensionality (1, n_units)
@@ -306,7 +306,7 @@ def __init__(self):
306306

307307
def format(self, node):
308308
params = self._default_function_params(node)
309-
if "pytorch" in node.attributes.keys():
309+
if node.get_attr('pytorch', False):
310310
self.template = simple_rnn_pytorch_function_template
311311
params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
312312
else:

hls4ml/backends/vivado/passes/recurrent_templates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def format(self, node):
9898
params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), node.index)
9999
params['strategy'] = node.get_attr('strategy')
100100
params['static'] = 'true' if node.attributes['static'] else 'false'
101-
params['pytorch'] = 'true' if "pytorch" in node.attributes.keys() else 'false'
101+
params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
102102
params['recr_type'] = node.class_name.lower()
103103
params['RECR_TYPE'] = node.class_name
104104

test/pytest/test_pytorch_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,7 @@ def forward(self, x):
846846

847847
# X_input is channels last
848848
X_input = np.ascontiguousarray(X_input.transpose(0, 2, 1))
849-
config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False)
849+
config = config_from_pytorch_model(model, channels_last_conversion="internal", transpose_outputs=False)
850850

851851
output_dir = str(test_root_path / f'hls4mlprj_pytorch_view_{backend}_{io_type}')
852852
hls_model = convert_from_pytorch_model(

0 commit comments

Comments
 (0)