Commit f6ed624

Author: jax authors (committed)
Merge pull request #20342 from ROCm:rocm-export_test-add-rocm-platform
PiperOrigin-RevId: 618179680
2 parents 44be575 + 11c35cd commit f6ed624

2 files changed: +13 −10

tests/export_harnesses_multi_platform_test.py

Lines changed: 5 additions & 2 deletions
@@ -136,9 +136,12 @@ def export_and_compare_to_native(
         if d.platform not in unimplemented_platforms
     ]
     logging.info("Using devices %s", [str(d) for d in devices])
-    # lowering_platforms uses "cuda" instead of "gpu"
+    # lowering_platforms uses "cuda" or "rocm" instead of "gpu"
+    gpu_platform = "cuda"
+    if jtu.is_device_rocm():
+      gpu_platform = "rocm"
     lowering_platforms: list[str] = [
-        p if p != "gpu" else "cuda"
+        p if p != "gpu" else gpu_platform
         for p in ("cpu", "gpu", "tpu")
         if p not in unimplemented_platforms
     ]
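The selection logic added above is small enough to sanity-check in isolation. Below is a minimal sketch, assuming a boolean is_rocm stands in for jtu.is_device_rocm() (which the test imports from jax._src.test_util); everything else is plain Python mirroring the diff.

# Sketch of the selection logic: map the generic "gpu" entry to a concrete
# lowering platform name. `is_rocm` is a stand-in for jtu.is_device_rocm()
# and is an assumption of this sketch, not part of the commit.
def select_lowering_platforms(unimplemented_platforms, is_rocm):
  gpu_platform = "rocm" if is_rocm else "cuda"
  return [
      p if p != "gpu" else gpu_platform
      for p in ("cpu", "gpu", "tpu")
      if p not in unimplemented_platforms
  ]

# On a ROCm machine with no excluded platforms:
assert select_lowering_platforms(set(), is_rocm=True) == ["cpu", "rocm", "tpu"]
# On a CUDA machine where the harness marks "tpu" as unimplemented:
assert select_lowering_platforms({"tpu"}, is_rocm=False) == ["cpu", "cuda"]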

tests/export_test.py

Lines changed: 8 additions & 8 deletions
@@ -1031,8 +1031,8 @@ def f_jax(x):  # x: f32[10,20] -> f32[20,10]
   def test_multi_platform(self):
     x = np.arange(8, dtype=np.float32)
     exp = get_exported(_testing_multi_platform_func,
-                       lowering_platforms=("tpu", "cpu", "cuda"))(x)
-    self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda"))
+                       lowering_platforms=("tpu", "cpu", "cuda","rocm"))(x)
+    self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda", "rocm"))
     module_str = str(exp.mlir_module())
     expected_main_re = (
         r"@main\("
@@ -1054,14 +1054,14 @@ def test_multi_platform(self):
   def test_multi_platform_nested(self):
     x = np.arange(5, dtype=np.float32)
     exp = get_exported(lambda x: _testing_multi_platform_func(jnp.sin(x)),
-                       lowering_platforms=("cpu", "tpu", "cuda"))(x)
-    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
+                       lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
+    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda","rocm"))

     # Now serialize the call to the exported using a different sequence of
     # lowering platforms, but included in the lowering platforms for the
     # nested exported.
     exp2 = get_exported(export.call_exported(exp),
-                        lowering_platforms=("cpu", "cuda"))(x)
+                        lowering_platforms=("cpu", "cuda","rocm"))(x)

     # Ensure that we do not have multiple lowerings of the exported function
     exp2_module_str = str(exp2.mlir_module())
@@ -1080,8 +1080,8 @@ def test_multi_platform_nested(self):
   def test_multi_platform_nested_inside_single_platform_export(self):
     x = np.arange(5, dtype=np.float32)
     exp = get_exported(_testing_multi_platform_func,
-                       lowering_platforms=("cpu", "tpu", "cuda"))(x)
-    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
+                       lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
+    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm"))

     # Now serialize the call for the current platform.
     exp2 = get_exported(export.call_exported(exp))(x)
@@ -1120,7 +1120,7 @@ def f_jax(b):  # b: f32[16 // DEVICES, 4]

     res_native = f_jax(a)
     exp = get_exported(f_jax,
-                       lowering_platforms=("cpu", "tpu", "cuda"))(a)
+                       lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(a)

     # Call with argument placed on different plaforms
     for platform in self.__class__.platforms:
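For reference, the pattern these tests exercise can be reproduced outside the test harness. The sketch below is an assumption-laden example, not part of the commit: it uses the jax.experimental.export API as it existed around this revision (get_exported in the tests is a thin wrapper over export.export) and a trivial function in place of _testing_multi_platform_func.

# Hedged sketch: multi-platform export including the "rocm" lowering platform.
# Assumes the jax.experimental.export API of this era; later JAX releases moved
# export to jax.export and renamed the keyword, so adjust for your version.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import export

def f(x):
  return jnp.cos(jnp.sin(x))  # stand-in for _testing_multi_platform_func

x = np.arange(8, dtype=np.float32)
exp = export.export(jax.jit(f),
                    lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x)
assert exp.lowering_platforms == ("cpu", "tpu", "cuda", "rocm")
# The serialized module carries lowerings for all requested platforms.
print(str(exp.mlir_module())[:200])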
