Skip to content

Commit 82ee60b

Browse files
author
Rahul Batra
committed
[ROCm]: Add rocm as platform in export_test
1 parent 604f5ec commit 82ee60b

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/export_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,8 +1031,8 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10]
10311031
def test_multi_platform(self):
10321032
x = np.arange(8, dtype=np.float32)
10331033
exp = get_exported(_testing_multi_platform_func,
1034-
lowering_platforms=("tpu", "cpu", "cuda"))(x)
1035-
self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda"))
1034+
lowering_platforms=("tpu", "cpu", "cuda","rocm"))(x)
1035+
self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda", "rocm"))
10361036
module_str = str(exp.mlir_module())
10371037
expected_main_re = (
10381038
r"@main\("
@@ -1054,14 +1054,14 @@ def test_multi_platform(self):
10541054
def test_multi_platform_nested(self):
10551055
x = np.arange(5, dtype=np.float32)
10561056
exp = get_exported(lambda x: _testing_multi_platform_func(jnp.sin(x)),
1057-
lowering_platforms=("cpu", "tpu", "cuda"))(x)
1058-
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
1057+
lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
1058+
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda","rocm"))
10591059

10601060
# Now serialize the call to the exported using a different sequence of
10611061
# lowering platforms, but included in the lowering platforms for the
10621062
# nested exported.
10631063
exp2 = get_exported(export.call_exported(exp),
1064-
lowering_platforms=("cpu", "cuda"))(x)
1064+
lowering_platforms=("cpu", "cuda","rocm"))(x)
10651065

10661066
# Ensure that we do not have multiple lowerings of the exported function
10671067
exp2_module_str = str(exp2.mlir_module())
@@ -1080,8 +1080,8 @@ def test_multi_platform_nested(self):
10801080
def test_multi_platform_nested_inside_single_platform_export(self):
10811081
x = np.arange(5, dtype=np.float32)
10821082
exp = get_exported(_testing_multi_platform_func,
1083-
lowering_platforms=("cpu", "tpu", "cuda"))(x)
1084-
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
1083+
lowering_platforms=("cpu", "tpu", "cuda","rocm"))(x)
1084+
self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm"))
10851085

10861086
# Now serialize the call for the current platform.
10871087
exp2 = get_exported(export.call_exported(exp))(x)
@@ -1120,7 +1120,7 @@ def f_jax(b): # b: f32[16 // DEVICES, 4]
11201120

11211121
res_native = f_jax(a)
11221122
exp = get_exported(f_jax,
1123-
lowering_platforms=("cpu", "tpu", "cuda"))(a)
1123+
lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(a)
11241124

11251125
# Call with argument placed on different plaforms
11261126
for platform in self.__class__.platforms:

0 commit comments

Comments
 (0)