Skip to content

Commit fd10078

Browse files
jvstokesjax authors
authored andcommitted
[XLA:TPU] Support scan loops for parameter input and output streaming in host offloading.
Currently, parameter input and output streaming use HloAliasAnalysis to correlate MoveToDevice calls with their corresponding input buffers. However, this can break down in scan loops, in which a dynamic slice creates the buffer which is offloaded. This prevents the AliasAnalysis from operating on the right buffer and finding the entry parameter. This change adds a function to TryParameterStreaming which traces up the call graph to potentially find a dynamic slice, and if so performs alias analysis on the input to that dynamic slice. For TryOutputStreaming, we trace down the call graph to find a dynamic update slice and perform alias analysis on that buffer instead. PiperOrigin-RevId: 626894899
1 parent e498bca commit fd10078

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tests/memories_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,6 +1152,26 @@ def test_identity_jit_host_to_device_and_vice_versa(self):
11521152
self.assertArraysEqual(out_host, np_inp)
11531153
self.assertEqual(out_host.sharding, s_host)
11541154

1155+
def test_parameter_streaming_inside_scan(self):
1156+
mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z"))
1157+
np_inp = np.arange(4096.0).reshape(16, 16, 16)
1158+
s_host = NamedSharding(mesh, P("x", "y", "z"), memory_kind="pinned_host")
1159+
arr_host = jax.device_put(np_inp, s_host)
1160+
1161+
@jax.jit
1162+
def f(xs):
1163+
def body(carry, x):
1164+
x_tpu = jax.device_put(x, TransferToMemoryKind("device"))
1165+
return carry, x_tpu + carry
1166+
1167+
return jax.lax.scan(body, 1.0, xs)
1168+
1169+
_, out_hbm = f(arr_host)
1170+
self.assertArraysEqual(out_hbm, np_inp + 1.0)
1171+
# Only expect the last dimension to have a named sharding.
1172+
out_s = NamedSharding(mesh, P(None, None, "z"), memory_kind="device")
1173+
self.assertEqual(out_hbm.sharding, out_s)
1174+
11551175

11561176
class ActivationOffloadingTest(jtu.JaxTestCase):
11571177

0 commit comments

Comments
 (0)