Skip to content

Commit 3a89557

Browse files
yueshengysjax authors
authored andcommitted
Update CompiledMemoryStats in xla/python to include host memory stats and add a few tests to memories_test.py
PiperOrigin-RevId: 611834035
1 parent 1ae2022 commit 3a89557

File tree

1 file changed

+32
-4
lines changed

1 file changed

+32
-4
lines changed

tests/memories_test.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from jax import lax
2424
from jax._src import test_util as jtu
2525
from jax._src import xla_bridge as xb
26+
from jax._src.lib import xla_extension_version
2627
from jax._src import config
2728
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
2829
import jax.numpy as jnp
@@ -1111,12 +1112,19 @@ def f(x):
11111112
f = jax.jit(jax.grad(f))
11121113
f(inp) # doesn't crash
11131114

1114-
compiled_text = f.lower(inp).compile().as_text()
1115+
compiled_f = f.lower(inp).compile()
1116+
1117+
compiled_text = compiled_f.as_text()
11151118
if compiled_text is not None:
11161119
self.assertIn('S(5)', compiled_text)
11171120
self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
11181121
self.assertRegex(compiled_text, r"copy-done.*S\(5\)")
11191122

1123+
compiled_stats = compiled_f.memory_analysis()
1124+
if compiled_stats is not None:
1125+
if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
1126+
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
1127+
11201128
def test_remat_scan_jaxpr_offloadable(self):
11211129
mesh = jtu.create_global_mesh((2,), ("x",))
11221130
shape = (256, 128)
@@ -1161,12 +1169,19 @@ def g(ys, _):
11611169
f = jax.jit(jax.grad(f))
11621170
f(inp) # doesn't crash
11631171

1164-
compiled_text = f.lower(inp).compile().as_text()
1172+
compiled_f = f.lower(inp).compile()
1173+
1174+
compiled_text = compiled_f.as_text()
11651175
if compiled_text is not None:
11661176
self.assertIn('S(5)', compiled_text)
11671177
self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
11681178
self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
11691179

1180+
compiled_stats = compiled_f.memory_analysis()
1181+
if compiled_stats is not None:
1182+
if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
1183+
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
1184+
11701185
def test_remat_scan_layout_change_offloadable(self):
11711186
mesh = jtu.create_global_mesh((2,), ("x",))
11721187
shape = (256, 128)
@@ -1194,12 +1209,19 @@ def g(ys, _):
11941209
f = jax.jit(jax.grad(f))
11951210
f(inp) # doesn't crash
11961211

1197-
compiled_text = f.lower(inp).compile().as_text()
1212+
compiled_f = f.lower(inp).compile()
1213+
1214+
compiled_text = compiled_f.as_text()
11981215
if compiled_text is not None:
11991216
self.assertIn('S(5)', compiled_text)
12001217
self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)")
12011218
self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)")
12021219

1220+
compiled_stats = compiled_f.memory_analysis()
1221+
if compiled_stats is not None:
1222+
if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
1223+
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
1224+
12031225
def test_remat_checkpoint_dots_with_no_batch_dims(self):
12041226
policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(
12051227
"device", "pinned_host")
@@ -1219,12 +1241,18 @@ def f(x):
12191241
f = jax.jit(jax.grad(f))
12201242
f(inp) # doesn't crash
12211243

1222-
compiled_text = f.lower(inp).compile().as_text()
1244+
compiled_f = f.lower(inp).compile()
1245+
1246+
compiled_text = compiled_f.as_text()
12231247
if compiled_text is not None:
12241248
self.assertIn('S(5)', compiled_text)
12251249
self.assertRegex(compiled_text, r"copy-start.*S\(5\)")
12261250
self.assertRegex(compiled_text, r"copy-done.*S\(5\)")
12271251

1252+
compiled_stats = compiled_f.memory_analysis()
1253+
if compiled_stats is not None:
1254+
if xla_extension_version >= 240 and jtu.pjrt_c_api_version_at_least(0, 43):
1255+
self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0)
12281256

12291257
if __name__ == "__main__":
12301258
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)