@@ -14,27 +14,31 @@ def setUp(self):
14
14
torch .manual_seed (0 )
15
15
torch_xla .manual_seed (0 )
16
16
17
- def build_models (self ,
18
- input_size ,
19
- hidden_size ,
20
- num_layers ,
21
- bias ,
22
- batch_first = False ):
17
+ def build_models (
18
+ self ,
19
+ input_size ,
20
+ hidden_size ,
21
+ num_layers ,
22
+ bias ,
23
+ batch_first = False ,
24
+ bidirectional = False ,
25
+ ):
23
26
gru = nn .GRU (
24
27
input_size ,
25
28
hidden_size ,
26
29
num_layers = num_layers ,
27
30
bias = bias ,
28
31
batch_first = batch_first ,
29
32
dropout = 0.0 ,
30
- bidirectional = False )
33
+ bidirectional = bidirectional )
31
34
scan_gru = GRU (
32
35
input_size ,
33
36
hidden_size ,
34
37
num_layers = num_layers ,
35
38
bias = bias ,
36
39
batch_first = batch_first ,
37
- dropout = 0.0 )
40
+ dropout = 0.0 ,
41
+ bidirectional = bidirectional )
38
42
39
43
# Copy parameters from the upstream GRU to our scan-based GRU.
40
44
# This ensures that the scan-based GRU has the same parameters as the
@@ -91,6 +95,22 @@ def check_gradients(self,
91
95
atol = atol ,
92
96
rtol = rtol )
93
97
98
+ def test_scan_gru_fallback_to_upstream_gru (self ):
99
+ """
100
+ Ensures that the scan-based GRU falls back to the upstream GRU when
101
+ unsupported parameters are set.
102
+ """
103
+ input_size , hidden_size , num_layers = 16 , 32 , 2
104
+ _ , scan_gru = self .build_models (input_size , hidden_size , num_layers , True )
105
+ assert type (scan_gru ) is GRU , (
106
+ "Scan-based GRU should create scan-based GRU when *no* unsupported parameters are set."
107
+ )
108
+ _ , scan_gru = self .build_models (
109
+ input_size , hidden_size , num_layers , True , bidirectional = True )
110
+ assert type (scan_gru ) is nn .GRU , (
111
+ "Scan-based GRU should fall back to upstream GRU when `bidirectional` is set to True."
112
+ )
113
+
94
114
def test_scan_gru_and_upstream_gru_interchangeability (self ):
95
115
"""
96
116
Ensures that the scan-based GRU and upstream GRU are interchangeable.
@@ -114,7 +134,10 @@ def test_scan_gru_and_upstream_gru_interchangeability(self):
114
134
115
135
# Check that the methods of the GRU and scan-based GRU have the same signature.
116
136
common_methods = nn_gru_names & scan_gru_names
137
+ exempt_methods = ['__new__' ]
117
138
for method_name in common_methods :
139
+ if method_name in exempt_methods :
140
+ continue
118
141
try :
119
142
nn_gru_method = nn_gru_members [method_name ]
120
143
scan_gru_method = scan_gru_members [method_name ]
0 commit comments