@@ -1031,8 +1031,8 @@ def f_jax(x):  # x: f32[10,20] -> f32[20,10]
   def test_multi_platform(self):
     x = np.arange(8, dtype=np.float32)
     exp = get_exported(_testing_multi_platform_func,
-                       lowering_platforms=("tpu", "cpu", "cuda"))(x)
-    self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda"))
+                       lowering_platforms=("tpu", "cpu", "cuda", "rocm"))(x)
+    self.assertEqual(exp.lowering_platforms, ("tpu", "cpu", "cuda", "rocm"))
     module_str = str(exp.mlir_module())
     expected_main_re = (
       r"@main\(
@@ -1054,14 +1054,14 @@ def test_multi_platform(self):
   def test_multi_platform_nested(self):
     x = np.arange(5, dtype=np.float32)
     exp = get_exported(lambda x: _testing_multi_platform_func(jnp.sin(x)),
-                       lowering_platforms=("cpu", "tpu", "cuda"))(x)
-    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
+                       lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x)
+    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm"))

     # Now serialize the call to the exported using a different sequence of
     # lowering platforms, but included in the lowering platforms for the
     # nested exported.
     exp2 = get_exported(export.call_exported(exp),
-                        lowering_platforms=("cpu", "cuda"))(x)
+                        lowering_platforms=("cpu", "cuda", "rocm"))(x)

     # Ensure that we do not have multiple lowerings of the exported function
     exp2_module_str = str(exp2.mlir_module())
@@ -1080,8 +1080,8 @@ def test_multi_platform_nested(self):
   def test_multi_platform_nested_inside_single_platform_export(self):
     x = np.arange(5, dtype=np.float32)
     exp = get_exported(_testing_multi_platform_func,
-                       lowering_platforms=("cpu", "tpu", "cuda"))(x)
-    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda"))
+                       lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x)
+    self.assertEqual(exp.lowering_platforms, ("cpu", "tpu", "cuda", "rocm"))

     # Now serialize the call for the current platform.
     exp2 = get_exported(export.call_exported(exp))(x)
@@ -1120,7 +1120,7 @@ def f_jax(b):  # b: f32[16 // DEVICES, 4]

     res_native = f_jax(a)
     exp = get_exported(f_jax,
-                       lowering_platforms=("cpu", "tpu", "cuda"))(a)
+                       lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(a)

     # Call with argument placed on different plaforms
     for platform in self.__class__.platforms:
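For reference, the pattern these tests exercise is exporting one function for several lowering platforms at once and inspecting the resulting module. A minimal sketch follows, assuming the `jax.experimental.export` API that this test file targets; the import path and whether the callable must be wrapped in `jax.jit` vary across JAX versions, and `f`/`x` are illustrative placeholders rather than names from the diff.

```python
# Minimal sketch (not part of the diff) of multi-platform export, assuming the
# jax.experimental.export API used by this test file; `f` and `x` are made up.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import export

def f(x):
  return jnp.cos(jnp.sin(x))

x = np.arange(8, dtype=np.float32)

# Lower once for several platforms, including the newly added "rocm".
exp = export.export(jax.jit(f),
                    lowering_platforms=("cpu", "tpu", "cuda", "rocm"))(x)

print(exp.lowering_platforms)        # -> ("cpu", "tpu", "cuda", "rocm")
print(str(exp.mlir_module())[:200])  # one StableHLO module covering all platforms
```

The nested tests above additionally wrap an already-exported function with `export.call_exported(exp)` and re-export it for a subset of those platforms, then check the resulting module text to ensure the nested exported function is not lowered once per platform.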