@@ -1538,6 +1538,82 @@ def initialize(self):
1538
1538
self .add_weights_variable (name = 'recurrent_bias_b' , var_name = 'br_b{index}' )
1539
1539
1540
1540
1541
+ class Bidirectional (Layer ):
1542
+ _expected_attributes = [
1543
+ Attribute ('n_out' ),
1544
+ Attribute ('return_sequences' , value_type = bool , default = False ),
1545
+ Attribute ('return_state' , value_type = bool , default = False ),
1546
+ Attribute ('pass_initial_states' , value_type = bool , default = False ),
1547
+ Attribute ('time_major' , value_type = bool , default = False ),
1548
+ Attribute ('forward_activation' , value_type = str ),
1549
+ Attribute ('forward_recurrent_activation' , value_type = str ),
1550
+ WeightAttribute ('forward_weight' ),
1551
+ WeightAttribute ('forward_bias' ),
1552
+ WeightAttribute ('forward_recurrent_weight' ),
1553
+ WeightAttribute ('forward_recurrent_bias' ),
1554
+ TypeAttribute ('forward_weight' ),
1555
+ TypeAttribute ('forward_bias' ),
1556
+ TypeAttribute ('forward_recurrent_weight' ),
1557
+ TypeAttribute ('forward_recurrent_bias' ),
1558
+ Attribute ('backward_activation' , value_type = str ),
1559
+ Attribute ('backward_recurrent_activation' , value_type = str ),
1560
+ WeightAttribute ('backward_weight' ),
1561
+ WeightAttribute ('backward_bias' ),
1562
+ WeightAttribute ('backward_recurrent_weight' ),
1563
+ WeightAttribute ('backward_recurrent_bias' ),
1564
+ TypeAttribute ('backward_weight' ),
1565
+ TypeAttribute ('backward_bias' ),
1566
+ TypeAttribute ('backward_recurrent_weight' ),
1567
+ TypeAttribute ('backward_recurrent_bias' ),
1568
+ ]
1569
+
1570
+ def initialize (self ):
1571
+ if self .attributes ['return_sequences' ]:
1572
+ shape = [self .attributes ['n_timesteps' ], self .attributes ['n_out' ]]
1573
+ dims = [f'N_TIME_STEPS_{ self .index } ' , f'N_OUT_{ self .index } ' ]
1574
+ else :
1575
+ shape = [self .attributes ['n_out' ]]
1576
+ dims = [f'N_OUT_{ self .index } ' ]
1577
+
1578
+ self .add_output_variable (shape , dims )
1579
+
1580
+ if self .attributes ['return_state' ]:
1581
+ state_shape = [self .attributes ['n_out' ]]
1582
+ state_dims = [f'N_OUT_{ self .index } ' ]
1583
+ self .add_output_variable (
1584
+ state_shape , state_dims , out_name = self .outputs [1 ], var_name = 'layer{index}_h' , type_name = 'layer{index}_h_t'
1585
+ )
1586
+ self .add_output_variable (
1587
+ state_shape , state_dims , out_name = self .outputs [2 ], var_name = 'layer{index}_c' , type_name = 'layer{index}_c_t'
1588
+ )
1589
+
1590
+ for dir in ['forward' , 'backward' ]:
1591
+ # weights
1592
+ self .add_weights_variable (name = f'{ dir } _weight' , var_name = (f'w_{ dir [0 ]} _' + '{index}' ))
1593
+
1594
+ # recurrent weights
1595
+ recurrent_weight = self .get_attr (f'{ dir } _recurrent_weight_data' )
1596
+ self .add_weights_variable (
1597
+ name = f'{ dir } _recurrent_weight' , var_name = (f'wr_{ dir [0 ]} _' + '{index}' ), data = recurrent_weight
1598
+ )
1599
+
1600
+ # biases
1601
+ self .add_weights_variable (name = f'{ dir } _bias' , var_name = (f'b_{ dir [0 ]} _' + '{index}' ))
1602
+
1603
+ if self .attributes [f'{ dir } _class_name' ] == 'LSTM' :
1604
+ if "pytorch" in self .attributes .keys ():
1605
+ self .add_weights_variable (name = f'{ dir } _recurrent_bias' , var_name = (f'br_{ dir [0 ]} _' + '{index}' ))
1606
+ else :
1607
+ recurrent_bias = np .zeros (recurrent_weight .shape [1 ])
1608
+ self .add_weights_variable (
1609
+ name = f'{ dir } _recurrent_bias' , var_name = (f'br_{ dir [0 ]} _' + '{index}' ), data = recurrent_bias
1610
+ )
1611
+ else :
1612
+ self .add_weights_variable (
1613
+ name = f'{ dir } _recurrent_bias' , var_name = (f'br_{ dir [0 ]} _' + '{index}' ), data = recurrent_bias
1614
+ )
1615
+
1616
+
1541
1617
class GarNet (Layer ):
1542
1618
ref_impl = False
1543
1619
@@ -1828,6 +1904,7 @@ def initialize(self):
1828
1904
'GRU' : GRU ,
1829
1905
'BidirectionalLSTM' : BidirectionalLSTM ,
1830
1906
'BidirectionalGRU' : BidirectionalGRU ,
1907
+ 'Bidirectional' : Bidirectional ,
1831
1908
'QSimpleRNN' : SimpleRNN ,
1832
1909
'QLSTM' : LSTM ,
1833
1910
'QGRU' : GRU ,
0 commit comments