
Commit 991fef7 (parent: 69795eb)

Enable remat lowering and tests on GPU

3 files changed: 15 additions, 8 deletions

jax/_src/dispatch.py (5 additions, 2 deletions)

@@ -445,7 +445,7 @@ def device_put_transpose_rule(ct, _, device, src):
 ad.deflinear2(device_put_p, device_put_transpose_rule)
 batching.defvectorized(device_put_p)
 
-def _tpu_device_put_lowering(ctx, x, *, device, src):
+def _tpu_gpu_device_put_lowering(ctx, x, *, device, src):
   if (isinstance(device, (XLACompatibleSharding, TransferToMemoryKind)) and
       device.memory_kind is not None):
     aval, = ctx.avals_in
@@ -456,7 +456,10 @@ def _tpu_device_put_lowering(ctx, x, *, device, src):
         ctx, x, out_aval, device._to_xla_hlo_sharding(aval.ndim).to_proto())
     return [x]
   return [x]
-mlir.register_lowering(device_put_p, _tpu_device_put_lowering, platform='tpu')
+mlir.register_lowering(
+    device_put_p, _tpu_gpu_device_put_lowering, platform='tpu')
+mlir.register_lowering(
+    device_put_p, _tpu_gpu_device_put_lowering, platform='gpu')
 
 
 def _common_device_put_lowering(ctx, x, *, device, src):
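A hedged sketch of what registering this lowering for the 'gpu' platform enables: staging jax.device_put with an explicit memory kind inside jit. The memory-kind names ("pinned_host", "device") and their availability depend on the backend and jaxlib version, so treat this as an illustrative assumption rather than code from this commit.

import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

jax.config.update("jax_enable_memories", True)

dev = jax.devices()[0]
host_sharding = SingleDeviceSharding(dev, memory_kind="pinned_host")  # assumed kind name
dev_sharding = SingleDeviceSharding(dev, memory_kind="device")        # assumed kind name

@jax.jit
def f(x):
  # device_put with a memory kind now has a lowering rule on TPU and GPU,
  # so this offload/reload round trip can be staged out under jit.
  y = x * 2
  y_host = jax.device_put(y, host_sharding)
  return jax.device_put(y_host, dev_sharding) + 1

print(f(jnp.arange(8.0)))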

tests/memories_test.py (8 additions, 4 deletions)

@@ -1118,8 +1118,8 @@ def test_device_put_python_int(self):
 class ActivationOffloadingTest(jtu.JaxTestCase):
 
   def setUp(self):
-    if not jtu.test_device_matches(["tpu"]):
-      self.skipTest("Memories do not work on CPU and GPU backends yet.")
+    if not jtu.test_device_matches(["tpu", "gpu"]):
+      self.skipTest("Memories do not work on CPU backend.")
     super().setUp()
     self.orig_memories_flag = config.enable_memories.value
     jax.config.update('jax_enable_memories', True)
@@ -1167,11 +1167,13 @@ def f(x):
     self.assertRegex(compiled_text, r"copy-done.*S\(5\)")
 
     compiled_stats = compiled_f.memory_analysis()
-    if compiled_stats is not None:
+    if compiled_stats is not None and jtu.test_device_matches(["tpu"]):
       if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
         self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
 
   def test_remat_scan_jaxpr_offloadable(self):
+    if not jtu.test_device_matches(["tpu"]):
+      self.skipTest("Remat scan does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2,), ("x",))
     shape = (256, 128)
     np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
@@ -1229,6 +1231,8 @@ def g(ys, _):
         self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
 
   def test_remat_scan_layout_change_offloadable(self):
+    if not jtu.test_device_matches(["tpu"]):
+      self.skipTest("Remat scan does not work on GPU backend.")
     mesh = jtu.create_global_mesh((2,), ("x",))
     shape = (256, 128)
     np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
@@ -1296,7 +1300,7 @@ def f(x):
     self.assertRegex(compiled_text, r"copy-done.*S\(5\)")
 
     compiled_stats = compiled_f.memory_analysis()
-    if compiled_stats is not None:
+    if compiled_stats is not None and jtu.test_device_matches(["tpu"]):
       if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
         self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
 
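The compiled-stats checks above follow a pattern that can be reproduced outside the test suite. The sketch below (an assumption-laden example, not code from this file) lowers and compiles a jitted function, then inspects the same memory_analysis() object the tests use; the host_temp_size_in_bytes assertion itself remains TPU-only in this commit.

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return jnp.sin(x) * 2

compiled_f = f.lower(jnp.arange(16.0)).compile()
compiled_text = compiled_f.as_text()           # HLO text; the tests regex it for the S(5) host memory space
compiled_stats = compiled_f.memory_analysis()  # may be None depending on backend/plugin version
if compiled_stats is not None:
  print(compiled_stats.host_temp_size_in_bytes)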

tests/pjit_test.py (2 additions, 2 deletions)

@@ -3748,8 +3748,8 @@ def f(inp):
         ' manager.*SingleDeviceSharding'):
       jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr)
 
-  @jtu.skip_on_devices("tpu")
-  def test_device_put_memory_kind_not_tpu(self):
+  @jtu.skip_on_devices("tpu", "gpu")
+  def test_device_put_memory_kind_not_tpu_gpu(self):
     @jax.jit
     def f(x):
       y = x * 2
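For completeness, a minimal, hypothetical test module showing the same jtu.skip_on_devices gating used above. The class and test names are illustrative only, and the real test's body (which exercises device_put with a memory kind inside jit on backends that lack memory support) is elided.

from absl.testing import absltest
from jax._src import test_util as jtu

class MemoryKindSkipExample(jtu.JaxTestCase):

  @jtu.skip_on_devices("tpu", "gpu")  # run only where memory-kind lowering is absent
  def test_device_put_memory_kind_unsupported(self):
    # Body elided; the real test checks how device_put with a memory kind
    # behaves inside jit on backends without memory support.
    pass

if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())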
