
Commit faf777a

Replace upstream GRU implementation with scan-based GRU (#9010)
1 parent: 329be9e

File tree

  test/test_gru.py
  torch_xla/_patched_functions.py
  torch_xla/experimental/gru.py

3 files changed: 43 additions, 10 deletions
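Taken together, the change patches nn.GRU so that constructing it yields the scan-based GRU while the upstream class stays reachable as nn.GRU._orig. Below is a minimal usage sketch of that behavior, assuming the patch is applied when torch_xla is imported (which is what the new test_patch_happened test expects); the layer sizes are illustrative only.

import torch.nn as nn
import torch_xla  # assumption: importing torch_xla runs _apply_patches()

# With the patch in place, nn.GRU now builds the scan-based GRU...
gru = nn.GRU(16, 32, num_layers=2)
assert type(gru) is nn.GRU

# ...and unsupported configurations (e.g. bidirectional=True) fall back to
# the upstream implementation saved on nn.GRU._orig.
bidir_gru = nn.GRU(16, 32, bidirectional=True)
assert type(bidir_gru) is nn.GRU._orig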

test/test_gru.py

Lines changed: 24 additions & 7 deletions
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torch_xla
 import inspect
-from torch_xla.experimental.gru import GRU
+from torch_xla.experimental.gru import GRU as ScanGRU
 
 from absl.testing import absltest, parameterized
 
@@ -23,15 +23,15 @@ def build_models(
       batch_first=False,
       bidirectional=False,
   ):
-    gru = nn.GRU(
+    gru = nn.GRU._orig(
         input_size,
         hidden_size,
         num_layers=num_layers,
         bias=bias,
         batch_first=batch_first,
         dropout=0.0,
         bidirectional=bidirectional)
-    scan_gru = GRU(
+    scan_gru = nn.GRU(
         input_size,
         hidden_size,
         num_layers=num_layers,
@@ -95,28 +95,45 @@ def check_gradients(self,
         atol=atol,
         rtol=rtol)
 
+  def test_patch_happened(self):
+    """
+    Ensures that the GRU class is patched correctly. The patch should happen in _patched_functions.py before
+    this test is run.
+    """
+    # Check if the GRU class is patched.
+    assert type(nn.GRU) is type(ScanGRU), (
+        "GRU class should be patched. "
+        "Check if the patching code is executed before this test.")
+    assert hasattr(
+        nn.GRU,
+        '_orig'), ("GRU class should be patched. "
+                   "Check if the patching code is executed before this test.")
+    assert nn.GRU._orig is not None, (
+        "GRU class should have the original GRU class as _orig. "
+        "Check if the patching code is executed before this test.")
+
   def test_scan_gru_fallback_to_upstream_gru(self):
     """
     Ensures that the scan-based GRU falls back to the upstream GRU when
     unsupported parameters are set.
     """
     input_size, hidden_size, num_layers = 16, 32, 2
     _, scan_gru = self.build_models(input_size, hidden_size, num_layers, True)
-    assert type(scan_gru) is GRU, (
+    assert type(scan_gru) is nn.GRU, (
         "Scan-based GRU should create scan-based GRU when *no* unsupported parameters are set."
     )
     _, scan_gru = self.build_models(
         input_size, hidden_size, num_layers, True, bidirectional=True)
-    assert type(scan_gru) is nn.GRU, (
+    assert type(scan_gru) is nn.GRU._orig, (
         "Scan-based GRU should fall back to upstream GRU when `bidirectional` is set to True."
     )
 
   def test_scan_gru_and_upstream_gru_interchangeability(self):
     """
     Ensures that the scan-based GRU and upstream GRU are interchangeable.
     """
-    nn_gru = nn.GRU
-    scan_gru = GRU
+    nn_gru = nn.GRU._orig
+    scan_gru = nn.GRU
     nn_gru_members = dict(inspect.getmembers(nn_gru, inspect.isroutine))
     scan_gru_members = dict(inspect.getmembers(scan_gru, inspect.isroutine))
 

torch_xla/_patched_functions.py

Lines changed: 10 additions & 2 deletions
@@ -3,11 +3,17 @@
 import torch.nn as nn
 from torch import inf
 from typing import Iterable, Union, Optional
+from torch_xla.experimental.gru import GRU as ScanGRU
 
 _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
 
 
-def _patch(fn, newfn):
+def _pathch_module(m, new_m):
+  new_m._orig = m
+  return new_m
+
+
+def _patch_fn(fn, newfn):
   xfingerprint = inspect.signature(fn)
   fingerprint = inspect.signature(newfn)
   if xfingerprint != fingerprint:
@@ -58,4 +64,6 @@ def clip_grad_norm_(parameters: _tensor_or_tensors,
 
 
 def _apply_patches():
-  nn.utils.clip_grad_norm_ = _patch(nn.utils.clip_grad_norm_, clip_grad_norm_)
+  nn.utils.clip_grad_norm_ = _patch_fn(nn.utils.clip_grad_norm_,
+                                       clip_grad_norm_)
+  nn.GRU = _pathch_module(nn.GRU, ScanGRU)
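
The module-level patch is just an attribute swap that remembers the class it replaces. A stripped-down sketch of that pattern follows, using hypothetical Original/Replacement stand-ins rather than the real nn.GRU/ScanGRU pair; patch_module mirrors _pathch_module above.

class Original:
  """Stands in for the upstream class (e.g. torch.nn.GRU)."""


class Replacement:
  """Stands in for the replacement class (e.g. the scan-based GRU)."""
  _orig = None  # filled in by the patcher


def patch_module(m, new_m):
  # Same idea as _pathch_module: stash the original class on the replacement
  # so fallback paths can still reach it via _orig.
  new_m._orig = m
  return new_m


patched = patch_module(Original, Replacement)
assert patched is Replacement
assert patched._orig is Original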

torch_xla/experimental/gru.py

Lines changed: 9 additions & 1 deletion
@@ -66,7 +66,15 @@ def __new__(cls, *args, **kwargs):
           "Scan-based GRU only supports unidirectional GRU. (bidirectional = False) "
           "Scan-based GRU falls back to the default nn.GRU implementation instead."
       )
-      return nn.GRU(*args, **kwargs)
+      if nn.GRU._orig is None:
+        # nn.GRU._orig being None means nn.GRU has not been patched yet for
+        # some reason (the patching should happen in _patched_functions.py),
+        # so nn.GRU is still the upstream constructor and can be called directly.
+        return nn.GRU(*args, **kwargs)
+      else:
+        # nn.GRU has already been patched with the scan-based class, so call
+        # the original upstream GRU constructor saved in nn.GRU._orig.
+        return nn.GRU._orig(*args, **kwargs)
     return super().__new__(cls)
 
   @overload
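
The fallback works because __new__ may return an instance of a different class, in which case Python skips the subclass's __init__ entirely. A self-contained sketch of that dispatch pattern follows, with illustrative Upstream/ScanBased classes that are not the real implementation.

class Upstream:
  """Plays the role of the original nn.GRU implementation."""

  def __init__(self, *args, **kwargs):
    self.config = kwargs


class ScanBased:
  """Plays the role of the scan-based GRU."""
  _orig = Upstream  # in the real code this is set by the module patcher

  def __new__(cls, *args, **kwargs):
    if kwargs.get('bidirectional', False):
      # Unsupported configuration: hand back an instance of the original
      # class instead of constructing a ScanBased object.
      return cls._orig(*args, **kwargs)
    return super().__new__(cls)

  def __init__(self, *args, **kwargs):
    self.config = kwargs


assert type(ScanBased(16, 32)) is ScanBased
assert type(ScanBased(16, 32, bidirectional=True)) is Upstream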
