@@ -428,6 +428,37 @@ def format(self, node):
         return template.format(**params)


+class BidirectionalFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__((Bidirectional), include_header=recr_include_list)
+
+    def format(self, node):
+        params = self._default_function_params(node)
+
+        # TO DO: Add initial states functions
+        '''
+        if params['pass_initial_states'] == 'true':
+            params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['input2'] = node.get_input_variable(node.inputs[1]).name
+            if node.class_name == 'BLSTM':
+                params['input3'] = node.get_input_variable(node.inputs[2]).name
+                params['input3_t'] = node.get_input_variable(node.inputs[2]).type.name
+        '''
+
+        params['w'] = node.get_weights('forward_weight').name
+        params['b'] = node.get_weights('forward_bias').name
+        params['wr'] = node.get_weights('forward_recurrent_weight').name
+        params['br'] = node.get_weights('forward_recurrent_bias').name
+        params['w_b'] = node.get_weights('backward_weight').name
+        params['b_b'] = node.get_weights('backward_bias').name
+        params['wr_b'] = node.get_weights('backward_recurrent_weight').name
+        params['br_b'] = node.get_weights('backward_recurrent_bias').name
+
+        template = bidirectional_function_template
+
+        return template.format(**params)
+
+
 time_distributed_config_template = """struct config{index} : nnet::time_distributed_config {{
     static const unsigned dim = {dim};
@@ -492,33 +523,3 @@ def format(self, node):
             return self.template_start.format(**params)
         else:
             return self.template_end.format(**params)
-
-class BidirectionalFunctionTemplate(FunctionCallTemplate):
-    def __init__(self):
-        super().__init__((Bidirectional), include_header=recr_include_list)
-
-    def format(self, node):
-        params = self._default_function_params(node)
-
-        # TO DO: Add initial states functions
-        '''
-        if params['pass_initial_states'] == 'true':
-            params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
-            params['input2'] = node.get_input_variable(node.inputs[1]).name
-            if node.class_name == 'BLSTM':
-                params['input3'] = node.get_input_variable(node.inputs[2]).name
-                params['input3_t'] = node.get_input_variable(node.inputs[2]).type.name
-        '''
-
-        params['w'] = node.get_weights('forward_weight').name
-        params['b'] = node.get_weights('forward_bias').name
-        params['wr'] = node.get_weights('forward_recurrent_weight').name
-        params['br'] = node.get_weights('forward_recurrent_bias').name
-        params['w_b'] = node.get_weights('backward_weight').name
-        params['b_b'] = node.get_weights('backward_bias').name
-        params['wr_b'] = node.get_weights('backward_recurrent_weight').name
-        params['br_b'] = node.get_weights('backward_recurrent_bias').name
-
-        template = bidirectional_function_template
-
-        return template.format(**params)
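For reference, `BidirectionalFunctionTemplate.format` builds a flat `params` dict (the generic entries returned by `_default_function_params(node)` plus the eight forward/backward weight names) and splices it into `bidirectional_function_template` with plain `str.format`. The standalone sketch below illustrates only that substitution mechanism: the template text, the `nnet::bidirectional` call shape, and every key other than the eight weight parameters are illustrative assumptions, not the actual hls4ml definitions.

# Minimal, self-contained sketch of the str.format substitution performed by
# BidirectionalFunctionTemplate.format in the diff above. The template text and
# all keys except w, b, wr, br, w_b, b_b, wr_b, br_b are illustrative stand-ins,
# not the real bidirectional_function_template from the hls4ml sources.
example_template = (
    'nnet::bidirectional<{input_t}, {output_t}, {config}>('
    '{input}, {output}, {w}, {wr}, {b}, {br}, {w_b}, {wr_b}, {b_b}, {br_b});'
)

params = {
    # Stand-ins for what _default_function_params(node) would normally provide.
    'input_t': 'input_t',
    'output_t': 'result_t',
    'config': 'config4',
    'input': 'layer3_out',
    'output': 'layer4_out',
    # Forward-direction weights, mirroring the assignments in format().
    'w': 'w4',
    'b': 'b4',
    'wr': 'wr4',
    'br': 'br4',
    # Backward-direction weights.
    'w_b': 'w4_b',
    'b_b': 'b4_b',
    'wr_b': 'wr4_b',
    'br_b': 'br4_b',
}

print(example_template.format(**params))
# nnet::bidirectional<input_t, result_t, config4>(layer3_out, layer4_out,
#     w4, wr4, b4, br4, w4_b, wr4_b, b4_b, br4_b);

Keeping the generated call as a module-level format string keeps the code generation transparent: the template class only has to supply placeholder values, and the emitted C++ can be inspected directly in the template string.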