Skip to content

Commit 0239d8c

Browse files
authored
scan-based GRU falls back to nn.GRU when bidirectional is true. (#8984)
1 parent db49edd commit 0239d8c

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

test/test_gru.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,31 @@ def setUp(self):
1414
torch.manual_seed(0)
1515
torch_xla.manual_seed(0)
1616

17-
def build_models(self,
18-
input_size,
19-
hidden_size,
20-
num_layers,
21-
bias,
22-
batch_first=False):
17+
def build_models(
18+
self,
19+
input_size,
20+
hidden_size,
21+
num_layers,
22+
bias,
23+
batch_first=False,
24+
bidirectional=False,
25+
):
2326
gru = nn.GRU(
2427
input_size,
2528
hidden_size,
2629
num_layers=num_layers,
2730
bias=bias,
2831
batch_first=batch_first,
2932
dropout=0.0,
30-
bidirectional=False)
33+
bidirectional=bidirectional)
3134
scan_gru = GRU(
3235
input_size,
3336
hidden_size,
3437
num_layers=num_layers,
3538
bias=bias,
3639
batch_first=batch_first,
37-
dropout=0.0)
40+
dropout=0.0,
41+
bidirectional=bidirectional)
3842

3943
# Copy parameters from the upstream GRU to our scan-based GRU.
4044
# This ensures that the scan-based GRU has the same parameters as the
@@ -91,6 +95,22 @@ def check_gradients(self,
9195
atol=atol,
9296
rtol=rtol)
9397

98+
def test_scan_gru_fallback_to_upstream_gru(self):
  """
  Ensures that the scan-based GRU falls back to the upstream GRU when
  unsupported parameters are set.
  """
  input_size, hidden_size, num_layers = 16, 32, 2
  # No unsupported options: construction should yield the scan-based class
  # itself (strict type check — GRU subclasses nn.GRU, so isinstance would
  # not distinguish the two).
  _, scan_gru = self.build_models(input_size, hidden_size, num_layers, True)
  # Use unittest assertions rather than bare `assert`: plain asserts are
  # stripped when Python runs with -O, which would silently disable this test.
  self.assertIs(
      type(scan_gru), GRU,
      "Scan-based GRU should create scan-based GRU when *no* unsupported parameters are set."
  )
  # `bidirectional=True` is unsupported by the scan-based implementation,
  # so __new__ must hand back a plain nn.GRU instead.
  _, scan_gru = self.build_models(
      input_size, hidden_size, num_layers, True, bidirectional=True)
  self.assertIs(
      type(scan_gru), nn.GRU,
      "Scan-based GRU should fall back to upstream GRU when `bidirectional` is set to True."
  )
113+
94114
def test_scan_gru_and_upstream_gru_interchangeability(self):
95115
"""
96116
Ensures that the scan-based GRU and upstream GRU are interchangeable.
@@ -114,7 +134,10 @@ def test_scan_gru_and_upstream_gru_interchangeability(self):
114134

115135
# Check that the methods of the GRU and scan-based GRU have the same signature.
116136
common_methods = nn_gru_names & scan_gru_names
137+
exempt_methods = ['__new__']
117138
for method_name in common_methods:
139+
if method_name in exempt_methods:
140+
continue
118141
try:
119142
nn_gru_method = nn_gru_members[method_name]
120143
scan_gru_method = scan_gru_members[method_name]

torch_xla/experimental/gru.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import torch.nn as nn
33
import torch.nn.functional as F
4+
import logging
45
from torch.nn.utils.rnn import PackedSequence
56
from typing import overload
67

@@ -59,6 +60,15 @@ class GRU(nn.GRU):
5960
6061
"""
6162

63+
def __new__(cls, *args, **kwargs):
  """Create a scan-based GRU, or fall back to ``nn.GRU`` when an
  unsupported configuration is requested.

  The scan-based implementation only supports unidirectional GRUs, so a
  request for ``bidirectional=True`` returns a plain ``nn.GRU`` instance
  (with a warning) instead of an instance of this class.
  """
  # `bidirectional` is the 7th positional parameter of nn.GRU
  # (input_size, hidden_size, num_layers, bias, batch_first, dropout,
  # bidirectional). The original check only inspected kwargs, so a
  # positional `bidirectional=True` would NOT trigger the fallback and an
  # unsupported scan-based GRU would be built; check both call styles.
  if len(args) > 6:
    bidirectional = bool(args[6])
  else:
    bidirectional = bool(kwargs.get('bidirectional', False))
  if bidirectional:
    logging.warning(
        "Scan-based GRU only supports unidirectional GRU. (bidirectional = False) "
        "Scan-based GRU falls back to the default nn.GRU implementation instead."
    )
    return nn.GRU(*args, **kwargs)
  return super().__new__(cls)
71+
6272
@overload
6373
def __init__(
6474
self,

0 commit comments

Comments
 (0)