Skip to content

Commit 1c16616

Browse files
author
Enrico Lupi
committed
FIX order
1 parent a1500e4 commit 1c16616

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed

hls4ml/backends/vivado/passes/recurrent_templates.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,37 @@ def format(self, node):
428428
return template.format(**params)
429429

430430

431+
class BidirectionalFunctionTemplate(FunctionCallTemplate):
432+
def __init__(self):
433+
super().__init__((Bidirectional), include_header=recr_include_list)
434+
435+
def format(self, node):
436+
params = self._default_function_params(node)
437+
438+
# TO DO: Add initial states functions
439+
'''
440+
if params['pass_initial_states'] == 'true':
441+
params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
442+
params['input2'] = node.get_input_variable(node.inputs[1]).name
443+
if node.class_name == 'BLSTM':
444+
params['input3'] = node.get_input_variable(node.inputs[2]).name
445+
params['input3_t'] = node.get_input_variable(node.inputs[2]).type.name
446+
'''
447+
448+
params['w'] = node.get_weights('forward_weight').name
449+
params['b'] = node.get_weights('forward_bias').name
450+
params['wr'] = node.get_weights('forward_recurrent_weight').name
451+
params['br'] = node.get_weights('forward_recurrent_bias').name
452+
params['w_b'] = node.get_weights('backward_weight').name
453+
params['b_b'] = node.get_weights('backward_bias').name
454+
params['wr_b'] = node.get_weights('backward_recurrent_weight').name
455+
params['br_b'] = node.get_weights('backward_recurrent_bias').name
456+
457+
template = bidirectional_function_template
458+
459+
return template.format(**params)
460+
461+
431462
time_distributed_config_template = """struct config{index} : nnet::time_distributed_config {{
432463
static const unsigned dim = {dim};
433464
@@ -492,33 +523,3 @@ def format(self, node):
492523
return self.template_start.format(**params)
493524
else:
494525
return self.template_end.format(**params)
495-
496-
class BidirectionalFunctionTemplate(FunctionCallTemplate):
497-
def __init__(self):
498-
super().__init__((Bidirectional), include_header=recr_include_list)
499-
500-
def format(self, node):
501-
params = self._default_function_params(node)
502-
503-
# TO DO: Add initial states functions
504-
'''
505-
if params['pass_initial_states'] == 'true':
506-
params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
507-
params['input2'] = node.get_input_variable(node.inputs[1]).name
508-
if node.class_name == 'BLSTM':
509-
params['input3'] = node.get_input_variable(node.inputs[2]).name
510-
params['input3_t'] = node.get_input_variable(node.inputs[2]).type.name
511-
'''
512-
513-
params['w'] = node.get_weights('forward_weight').name
514-
params['b'] = node.get_weights('forward_bias').name
515-
params['wr'] = node.get_weights('forward_recurrent_weight').name
516-
params['br'] = node.get_weights('forward_recurrent_bias').name
517-
params['w_b'] = node.get_weights('backward_weight').name
518-
params['b_b'] = node.get_weights('backward_bias').name
519-
params['wr_b'] = node.get_weights('backward_recurrent_weight').name
520-
params['br_b'] = node.get_weights('backward_recurrent_bias').name
521-
522-
template = bidirectional_function_template
523-
524-
return template.format(**params)

0 commit comments

Comments
 (0)