2
2
import torch .nn as nn
3
3
import torch_xla
4
4
import inspect
5
- from torch_xla .experimental .gru import GRU
5
+ from torch_xla .experimental .gru import GRU as ScanGRU
6
6
7
7
from absl .testing import absltest , parameterized
8
8
@@ -23,15 +23,15 @@ def build_models(
23
23
batch_first = False ,
24
24
bidirectional = False ,
25
25
):
26
- gru = nn .GRU (
26
+ gru = nn .GRU . _orig (
27
27
input_size ,
28
28
hidden_size ,
29
29
num_layers = num_layers ,
30
30
bias = bias ,
31
31
batch_first = batch_first ,
32
32
dropout = 0.0 ,
33
33
bidirectional = bidirectional )
34
- scan_gru = GRU (
34
+ scan_gru = nn . GRU (
35
35
input_size ,
36
36
hidden_size ,
37
37
num_layers = num_layers ,
@@ -95,28 +95,45 @@ def check_gradients(self,
95
95
atol = atol ,
96
96
rtol = rtol )
97
97
98
+ def test_patch_happened (self ):
99
+ """
100
+ Ensures that the GRU class is patched correctly. The patch should happen in _patched_functions.py before
101
+ this test is run.
102
+ """
103
+ # Check if the GRU class is patched.
104
+ assert type (nn .GRU ) is type (ScanGRU ), (
105
+ "GRU class should be patched. "
106
+ "Check if the patching code is executed before this test." )
107
+ assert hasattr (
108
+ nn .GRU ,
109
+ '_orig' ), ("GRU class should be patched. "
110
+ "Check if the patching code is executed before this test." )
111
+ assert nn .GRU ._orig is not None , (
112
+ "GRU class should have the original GRU class as _orig. "
113
+ "Check if the patching code is executed before this test." )
114
+
98
115
def test_scan_gru_fallback_to_upstream_gru (self ):
99
116
"""
100
117
Ensures that the scan-based GRU falls back to the upstream GRU when
101
118
unsupported parameters are set.
102
119
"""
103
120
input_size , hidden_size , num_layers = 16 , 32 , 2
104
121
_ , scan_gru = self .build_models (input_size , hidden_size , num_layers , True )
105
- assert type (scan_gru ) is GRU , (
122
+ assert type (scan_gru ) is nn . GRU , (
106
123
"Scan-based GRU should create scan-based GRU when *no* unsupported parameters are set."
107
124
)
108
125
_ , scan_gru = self .build_models (
109
126
input_size , hidden_size , num_layers , True , bidirectional = True )
110
- assert type (scan_gru ) is nn .GRU , (
127
+ assert type (scan_gru ) is nn .GRU . _orig , (
111
128
"Scan-based GRU should fall back to upstream GRU when `bidirectional` is set to True."
112
129
)
113
130
114
131
def test_scan_gru_and_upstream_gru_interchangeability (self ):
115
132
"""
116
133
Ensures that the scan-based GRU and upstream GRU are interchangeable.
117
134
"""
118
- nn_gru = nn .GRU
119
- scan_gru = GRU
135
+ nn_gru = nn .GRU . _orig
136
+ scan_gru = nn . GRU
120
137
nn_gru_members = dict (inspect .getmembers (nn_gru , inspect .isroutine ))
121
138
scan_gru_members = dict (inspect .getmembers (scan_gru , inspect .isroutine ))
122
139
0 commit comments