@@ -1118,8 +1118,8 @@ def test_device_put_python_int(self):
1118
1118
class ActivationOffloadingTest (jtu .JaxTestCase ):
1119
1119
1120
1120
def setUp (self ):
1121
- if not jtu .test_device_matches (["tpu" ]):
1122
- self .skipTest ("Memories do not work on CPU and GPU backends yet ." )
1121
+ if not jtu .test_device_matches (["tpu" , "gpu" ]):
1122
+ self .skipTest ("Memories do not work on CPU backend ." )
1123
1123
super ().setUp ()
1124
1124
self .orig_memories_flag = config .enable_memories .value
1125
1125
jax .config .update ('jax_enable_memories' , True )
@@ -1167,11 +1167,13 @@ def f(x):
1167
1167
self .assertRegex (compiled_text , r"copy-done.*S\(5\)" )
1168
1168
1169
1169
compiled_stats = compiled_f .memory_analysis ()
1170
- if compiled_stats is not None :
1170
+ if compiled_stats is not None and jtu . test_device_matches ([ "tpu" ]) :
1171
1171
if xla_extension_version >= 240 and jtu .pjrt_c_api_version_at_least (0 , 43 ):
1172
1172
self .assertGreater (compiled_stats .host_temp_size_in_bytes , 0 )
1173
1173
1174
1174
def test_remat_scan_jaxpr_offloadable (self ):
1175
+ if not jtu .test_device_matches (["tpu" ]):
1176
+ self .skipTest ("Remat scan does not work on GPU backend." )
1175
1177
mesh = jtu .create_global_mesh ((2 ,), ("x" ,))
1176
1178
shape = (256 , 128 )
1177
1179
np_inp = np .arange (math .prod (shape ), dtype = np .float32 ).reshape (shape )
@@ -1229,6 +1231,8 @@ def g(ys, _):
1229
1231
self .assertGreater (compiled_stats .host_temp_size_in_bytes , 0 )
1230
1232
1231
1233
def test_remat_scan_layout_change_offloadable (self ):
1234
+ if not jtu .test_device_matches (["tpu" ]):
1235
+ self .skipTest ("Remat scan does not work on GPU backend." )
1232
1236
mesh = jtu .create_global_mesh ((2 ,), ("x" ,))
1233
1237
shape = (256 , 128 )
1234
1238
np_inp = np .arange (math .prod (shape ), dtype = np .float32 ).reshape (shape )
@@ -1296,7 +1300,7 @@ def f(x):
1296
1300
self .assertRegex (compiled_text , r"copy-done.*S\(5\)" )
1297
1301
1298
1302
compiled_stats = compiled_f .memory_analysis ()
1299
- if compiled_stats is not None :
1303
+ if compiled_stats is not None and jtu . test_device_matches ([ "tpu" ]) :
1300
1304
if xla_extension_version >= 240 and jtu .pjrt_c_api_version_at_least (0 , 43 ):
1301
1305
self .assertGreater (compiled_stats .host_temp_size_in_bytes , 0 )
1302
1306
0 commit comments