
Commit 2feb0ac

[scan] Make sure inputs into fn are not device_data IR nodes (#8769)
1 parent 2675e68 commit 2feb0ac

File tree: 4 files changed (+163, -99 lines)


test/scan/test_scan_pallas.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+import logging
+import sys
+import unittest
+from absl.testing import parameterized
+
+import torch
+from torch import nn as nn
+
+import torch_xla
+import torch_xla.core.xla_model as xm
+from torch_xla import runtime as xr
+from torch_xla.experimental.scan_layers import scan_layers
+import torch_xla.distributed.spmd as xs
+from torch_xla.experimental.custom_kernel import flash_attention
+
+
+class AttentionModule(torch.nn.Module):
+
+  def __init__(self, has_model_weight=True, num_head=4, hidden_dim=256):
+    super(AttentionModule, self).__init__()
+    self.has_model_weight = has_model_weight
+    if has_model_weight:
+      self.num_head = num_head
+      self.hidden_dim = hidden_dim
+      self.fc = nn.Linear(hidden_dim, hidden_dim)
+
+  def forward(self, input):
+    # query_states: [B, NUM_HEAD, SEQ_LEN, d_k]
+    # attn_output: [B, SEQ_LEN, d_m], dm = dk * NUM_HEAD
+    query_states = input.clone()
+    key_states = input.clone()
+    value_states = input.clone()
+    attn_output = flash_attention(
+        query_states,
+        key_states,
+        value_states,
+        causal=True,
+        partition_spec=("fsdp", None, None, None),
+    )
+    if self.has_model_weight:
+      attn_output = self.fc(attn_output)
+    return attn_output
+
+
+class AttentionLayers(torch.nn.Module):
+
+  def __init__(self, has_model_weight=True, num_layer=3, use_scan=False):
+    super(AttentionLayers, self).__init__()
+    self.num_layer = num_layer
+    self.use_scan = use_scan
+    self.has_model_weight = has_model_weight
+    self.layers = nn.ModuleList([
+        AttentionModule(has_model_weight=has_model_weight)
+        for i in range(self.num_layer)
+    ])
+
+  def forward(self, input):
+    hidden_states = input
+    xs.mark_sharding(hidden_states, xs.get_global_mesh(),
+                     ("fsdp", None, None, None))
+    if not self.use_scan:
+      for layer in self.layers:
+        hidden_states = layer(hidden_states)
+    else:
+      hidden_states = scan_layers(self.layers, input_data=hidden_states)
+    return hidden_states
+
+
+class ScanFlashAttentionTest(parameterized.TestCase):
+
+  def fake_fa_wrapper(self, has_model_weight, use_scan):
+    torch.manual_seed(12)
+    torch_xla.manual_seed(12)
+    hidden_states = torch.randn((2, 4, 256, 256)).requires_grad_().to('xla')
+    with xm.xla_device():
+      attention_layers = AttentionLayers(
+          has_model_weight, num_layer=3, use_scan=use_scan)
+    hidden_states.retain_grad()
+    output = attention_layers(hidden_states)
+    return output
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU")
+  def test_scan_flash_attention_against_for_loop(self):
+    for_loop_output = self.fake_fa_wrapper(
+        has_model_weight=True, use_scan=False)
+    torch_xla.sync()
+    scan_output = self.fake_fa_wrapper(has_model_weight=True, use_scan=True)
+    torch_xla.sync()
+    torch.testing.assert_close(
+        for_loop_output.cpu(), scan_output.cpu(), atol=1e-3, rtol=1e-3)
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU")
+  @parameterized.named_parameters(("has_model_weight_True", True),
+                                  ("has_model_weight_False", False))
+  def test_scan_weight_layer_aot(self, has_model_weight_scan):
+    output = self.fake_fa_wrapper(
+        has_model_weight=has_model_weight_scan, use_scan=False)
+    torch_xla.sync()
+    # TODO(https://github.com/pytorch/xla/issues/8753): Fix assertion
+    # torch.manual_seed(12)
+    # torch_xla.manual_seed(12)
+    # scan_output = self.fake_fa_wrapper(
+    #     has_model_weight=has_model_weight_scan, use_scan=True)
+    # torch_xla.sync()
+    # torch.testing.assert_close(output.cpu(), scan_output.cpu())
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+
+  xr.use_spmd()
+  n_devices = xr.global_runtime_device_count()
+  xs.set_global_mesh(xs.get_1d_mesh("fsdp"))
+
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)

test/test_as_stride_use_slice.py

Lines changed: 0 additions & 98 deletions
@@ -9,67 +9,12 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 from torch_xla import runtime as xr
-from torch_xla._internal import tpu
-from torch_xla.experimental.scan_layers import scan_layers
 import torch_xla.distributed.spmd as xs
-from torch_xla.experimental.custom_kernel import flash_attention
 
 from functorch.compile import aot_function, make_boxed_func
 from torch.library import custom_op
 
 
-class AttentionModule(torch.nn.Module):
-
-  def __init__(self, has_model_weight=True, num_head=4, hidden_dim=256):
-    super(AttentionModule, self).__init__()
-    self.has_model_weight = has_model_weight
-    if has_model_weight:
-      self.num_head = num_head
-      self.hidden_dim = hidden_dim
-      self.fc = nn.Linear(hidden_dim, hidden_dim)
-
-  def forward(self, input):
-    # query_states: [B, NUM_HEAD, SEQ_LEN, d_k]
-    # attn_output: [B, SEQ_LEN, d_m], dm = dk * NUM_HEAD
-    query_states = input.clone()
-    key_states = input.clone()
-    value_states = input.clone()
-    attn_output = flash_attention(
-        query_states,
-        key_states,
-        value_states,
-        causal=True,
-        partition_spec=("fsdp", None, None, None),
-    )
-    if self.has_model_weight:
-      attn_output = self.fc(attn_output)
-    return attn_output
-
-
-class AttentionLayers(torch.nn.Module):
-
-  def __init__(self, has_model_weight=True, num_layer=3, use_scan=False):
-    super(AttentionLayers, self).__init__()
-    self.num_layer = num_layer
-    self.use_scan = use_scan
-    self.has_model_weight = has_model_weight
-    self.layers = nn.ModuleList([
-        AttentionModule(has_model_weight=has_model_weight)
-        for i in range(self.num_layer)
-    ])
-
-  def forward(self, input):
-    hidden_states = input
-    xs.mark_sharding(hidden_states, xs.get_global_mesh(),
-                     ("fsdp", None, None, None))
-    if not self.use_scan:
-      for layer in self.layers:
-        hidden_states = layer(hidden_states)
-    else:
-      hidden_states = scan_layers(self.layers, input_data=hidden_states)
-    return hidden_states
-
-
 class StridedAndSlice(torch.nn.Module):
 
   def __init__(self):
@@ -198,50 +143,7 @@ def compiler(gm, _):
     torch.testing.assert_close(cpu_output, xla_output.cpu())
 
 
-class ScanFlashAttentionTest(parameterized.TestCase):
-
-  def fake_fa_wrapper(self, has_model_weight, use_scan):
-    with xm.xla_device():
-      dm = AttentionLayers(has_model_weight, 3, use_scan)
-      hidden_states = torch.randn((2, 4, 256, 256)).requires_grad_()
-      hidden_states.retain_grad()
-      output = dm(hidden_states)
-    return output
-
-  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU")
-  @parameterized.named_parameters(("use_scan_True", True),
-                                  ("use_scan_False", False))
-  def test_scan_layer_aot(self, use_scan):
-    output = self.fake_fa_wrapper(has_model_weight=True, use_scan=use_scan)
-    torch_xla.sync()
-    # TODO(https://github.com/pytorch/xla/issues/8742): Fix NaN
-    # self.assertFalse(torch.isnan(output).any())
-
-  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU")
-  @parameterized.named_parameters(("has_model_weight_True", True),
-                                  ("has_model_weight_False", False))
-  def test_scan_weight_layer_aot(self, has_model_weight_scan):
-    torch.manual_seed(12)
-    torch_xla.manual_seed(12)
-    output = self.fake_fa_wrapper(
-        has_model_weight=has_model_weight_scan, use_scan=False)
-    torch_xla.sync()
-    # TODO(https://github.com/pytorch/xla/issues/8742): Fix NaN
-    # TODO(https://github.com/pytorch/xla/issues/8753): Fix assertion
-    # torch.manual_seed(12)
-    # torch_xla.manual_seed(12)
-    # scan_output = self.fake_fa_wrapper(
-    #     has_model_weight=has_model_weight_scan, use_scan=True)
-    # torch_xla.sync()
-    # torch.testing.assert_close(output.cpu(), scan_output.cpu())
-
-
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
-
-  xr.use_spmd()
-  n_devices = xr.global_runtime_device_count()
-  xs.set_global_mesh(xs.get_1d_mesh("fsdp"))
-
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ python3 "$TEST_CDIR/pjrt/test_dynamic_plugin_tpu.py"
 python3 "$TEST_CDIR/test_while_loop.py"
 python3 "$TEST_CDIR/scan/test_scan.py"
 python3 "$TEST_CDIR/scan/test_scan_spmd.py"
+python3 "$TEST_CDIR/scan/test_scan_pallas.py"
 python3 "$TEST_CDIR/scan/test_scan_layers.py"
 python3 "$TEST_CDIR/test_as_stride_use_slice.py"
 run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"

torch_xla/experimental/scan.py

Lines changed: 46 additions & 1 deletion
@@ -564,7 +564,52 @@ def make_fake_tensor(v: torch.Tensor) -> torch.Tensor:
   device = torch_xla.device()
   fake_carry = tree_map(make_fake_tensor, init)
   fake_x = tree_map(lambda v: make_fake_tensor(v[0]), xs)
-  fake_output_carry, fake_output_y = fn(fake_carry, fake_x)
+
+  def defeat_device_data(v: torch.Tensor) -> torch.Tensor:
+    """
+    Make sure inputs into `fn` are not `device_data` IR nodes.
+
+    This is to work around a limitation of `mark_sharding`, which replaces
+    the innards of the tensors it operates on. In other words, `mark_sharding`
+    is an in-place operation as opposed to a transform like those found in JAX.
+
+    When `fn` contains a `mark_sharding` and the `mark_sharding` operates on one
+    of the carry or xs fake tensors, the original device data will be discarded
+    and a new one will be created in its place. That's because `mark_sharding`
+    has different code paths depending on whether the IR holds device data.
+    If the IR is an intermediate operation like add or matmul, `mark_sharding`
+    will update the sharding annotation. If the IR holds data, `mark_sharding`
+    will transfer the data to the TPU in a sharded manner, and update the data
+    object in the IR to point to a sharded data object, as can be seen in [2].
+
+    When lowering a graph to HLO, tensors that hold the same data object will
+    map to the same HLO parameter. Changing the data object in the tensor will
+    cause it to map to a different HLO parameter. As a result, `fn` will appear
+    to create a few empty tensors internally that are unrelated to the carry and
+    xs fake tensors, and the carry and xs will appear completely unused.
+
+    See https://github.com/pytorch/xla/issues/8742 for the bug. In short,
+    if an input into the layer to be scanned is a device data, and that layer
+    does a `mark_sharding` on said input, then the graph capturing in `scan`
+    will fail.
+
+    The workaround here is simple and cursed: multiply any `device_data` by 1.
+    This makes sure these tensors don't hold device data IR nodes, and so
+    defeats the device data replacement of `mark_sharding`.
+
+    Fortunately, XLA simplifies away the multiplication (see [1]) so this should
+    become a no-op by the time it hits the TPU.
+
+    [1]: https://github.com/openxla/xla/blob/869f57d0082d7adbb9efc10cc18f51a562fc7bf3/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc#L4755-L4770
+    [2]: https://github.com/pytorch/xla/blob/2675e6892c6f955fc2baf88d85dfdfa72062273c/torch_xla/csrc/xla_sharding_util.cpp#L799-L846
+
+    """
+    return v * 1
+
+  # Trace `fn` in order to stage out its HLO.
+  fake_output_carry, fake_output_y = fn(
+      tree_map(defeat_device_data, fake_carry),
+      tree_map(defeat_device_data, fake_x))
 
   y_len = len(fake_output_y)
   fn_outputs = fake_output_carry + fake_output_y
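For readers who want to try the workaround described in the docstring above outside of `scan`, here is a minimal, hypothetical sketch (not part of this commit) of the same pattern: apply a no-op multiply to every tensor leaf of the inputs before tracing a function, so none of the traced inputs is a bare device-data IR node that an in-place `mark_sharding` could swap out. It assumes only `torch` and `torch.utils._pytree.tree_map`; the helper name `trace_with_protected_inputs` is illustrative.

import torch
from torch.utils._pytree import tree_map


def defeat_device_data(v: torch.Tensor) -> torch.Tensor:
  # Multiplying by 1 turns a bare device-data node into an intermediate IR
  # node; XLA's algebraic simplifier later folds the multiply away.
  return v * 1


def trace_with_protected_inputs(fn, carry, x):
  # Hypothetical helper: shield every tensor leaf of the carry/x pytrees
  # before handing them to the function being traced.
  return fn(tree_map(defeat_device_data, carry),
            tree_map(defeat_device_data, x))

The design choice is the same as in the diff: the extra multiply changes how the lazy IR classifies the input, not what the compiled program computes.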
