23
23
from jax import lax
24
24
from jax ._src import test_util as jtu
25
25
from jax ._src import xla_bridge as xb
26
+ from jax ._src .lib import xla_extension_version
26
27
from jax ._src import config
27
28
from jax .ad_checkpoint import checkpoint_name , checkpoint as new_checkpoint
28
29
import jax .numpy as jnp
@@ -1111,12 +1112,19 @@ def f(x):
1111
1112
f = jax .jit (jax .grad (f ))
1112
1113
f (inp ) # doesn't crash
1113
1114
1114
- compiled_text = f .lower (inp ).compile ().as_text ()
1115
+ compiled_f = f .lower (inp ).compile ()
1116
+
1117
+ compiled_text = compiled_f .as_text ()
1115
1118
if compiled_text is not None :
1116
1119
self .assertIn ('S(5)' , compiled_text )
1117
1120
self .assertRegex (compiled_text , r"copy-start.*S\(5\)" )
1118
1121
self .assertRegex (compiled_text , r"copy-done.*S\(5\)" )
1119
1122
1123
+ compiled_stats = compiled_f .memory_analysis ()
1124
+ if compiled_stats is not None :
1125
+ if xla_extension_version >= 240 and jtu .pjrt_c_api_version_at_least (0 , 43 ):
1126
+ self .assertGreater (compiled_stats .host_temp_size_in_bytes , 0 )
1127
+
1120
1128
def test_remat_scan_jaxpr_offloadable (self ):
1121
1129
mesh = jtu .create_global_mesh ((2 ,), ("x" ,))
1122
1130
shape = (256 , 128 )
@@ -1161,12 +1169,19 @@ def g(ys, _):
1161
1169
f = jax .jit (jax .grad (f ))
1162
1170
f (inp ) # doesn't crash
1163
1171
1164
- compiled_text = f .lower (inp ).compile ().as_text ()
1172
+ compiled_f = f .lower (inp ).compile ()
1173
+
1174
+ compiled_text = compiled_f .as_text ()
1165
1175
if compiled_text is not None :
1166
1176
self .assertIn ('S(5)' , compiled_text )
1167
1177
self .assertNotRegex (compiled_text , r"copy-start.*S\(5\)" )
1168
1178
self .assertNotRegex (compiled_text , r"copy-done.*S\(5\)" )
1169
1179
1180
+ compiled_stats = compiled_f .memory_analysis ()
1181
+ if compiled_stats is not None :
1182
+ if xla_extension_version >= 240 and jtu .pjrt_c_api_version_at_least (0 , 43 ):
1183
+ self .assertGreater (compiled_stats .host_temp_size_in_bytes , 0 )
1184
+
1170
1185
def test_remat_scan_layout_change_offloadable (self ):
1171
1186
mesh = jtu .create_global_mesh ((2 ,), ("x" ,))
1172
1187
shape = (256 , 128 )
@@ -1194,12 +1209,19 @@ def g(ys, _):
1194
1209
f = jax .jit (jax .grad (f ))
1195
1210
f (inp ) # doesn't crash
1196
1211
1197
- compiled_text = f .lower (inp ).compile ().as_text ()
1212
+ compiled_f = f .lower (inp ).compile ()
1213
+
1214
+ compiled_text = compiled_f .as_text ()
1198
1215
if compiled_text is not None :
1199
1216
self .assertIn ('S(5)' , compiled_text )
1200
1217
self .assertNotRegex (compiled_text , r"copy-start.*S\(5\)" )
1201
1218
self .assertNotRegex (compiled_text , r"copy-done.*S\(5\)" )
1202
1219
1220
+ compiled_stats = compiled_f .memory_analysis ()
1221
+ if compiled_stats is not None :
1222
+ if xla_extension_version >= 240 and jtu .pjrt_c_api_version_at_least (0 , 43 ):
1223
+ self .assertGreater (compiled_stats .host_temp_size_in_bytes , 0 )
1224
+
1203
1225
def test_remat_checkpoint_dots_with_no_batch_dims (self ):
1204
1226
policy = jax .checkpoint_policies .offload_dot_with_no_batch_dims (
1205
1227
"device" , "pinned_host" )
@@ -1219,12 +1241,18 @@ def f(x):
1219
1241
f = jax .jit (jax .grad (f ))
1220
1242
f (inp ) # doesn't crash
1221
1243
1222
- compiled_text = f .lower (inp ).compile ().as_text ()
1244
+ compiled_f = f .lower (inp ).compile ()
1245
+
1246
+ compiled_text = compiled_f .as_text ()
1223
1247
if compiled_text is not None :
1224
1248
self .assertIn ('S(5)' , compiled_text )
1225
1249
self .assertRegex (compiled_text , r"copy-start.*S\(5\)" )
1226
1250
self .assertRegex (compiled_text , r"copy-done.*S\(5\)" )
1227
1251
1252
+ compiled_stats = compiled_f .memory_analysis ()
1253
+ if compiled_stats is not None :
1254
+ if xla_extension_version >= 240 and jtu .pjrt_c_api_version_at_least (0 , 43 ):
1255
+ self .assertGreater (compiled_stats .host_temp_size_in_bytes , 0 )
1228
1256
1229
1257
if __name__ == "__main__" :
1230
1258
absltest .main (testLoader = jtu .JaxTestLoader ())
0 commit comments