Skip to content

Commit 015edd5

Browse files
author
jax authors
committed
Merge pull request #20062 from Micky774:strip_ir
PiperOrigin-RevId: 621930992
2 parents 685d97f + 86bf8fa commit 015edd5

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

jax/_src/interpreters/mlir.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@
7575
"Supports the special value 'sponge' to pick the path from the "
7676
"environment variable TEST_UNDECLARED_OUTPUTS_DIR.")
7777

78+
_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.DEFINE_string(
79+
'jax_include_debug_info_in_dumps',
80+
os.getenv('JAX_INCLUDE_DEBUG_INFO_IN_DUMPS', "True"),
81+
help="Determine whether or not to keep debug symbols and location information "
82+
"when dumping IR code. By default, debug information will be preserved in "
83+
"the IR dump. To avoid exposing source code and potentially sensitive "
84+
"information, set to false")
7885
lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects
7986

8087

@@ -474,9 +481,12 @@ def dump_module_message(module: ir.Module, stage_name: str) -> str:
474481
def _make_string_safe_for_filename(s: str) -> str:
475482
return re.sub(r'[^\w.)( -]', '', s)
476483

477-
def module_to_string(module: ir.Module) -> str:
484+
def module_to_string(module: ir.Module, enable_debug_info=None) -> str:
478485
output = io.StringIO()
479-
module.operation.print(file=output, enable_debug_info=True)
486+
if enable_debug_info is None:
487+
enable_debug_flag = str.lower(_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS.value)
488+
enable_debug_info = enable_debug_flag not in ('false', '0')
489+
module.operation.print(file=output, enable_debug_info=enable_debug_info)
480490
return output.getvalue()
481491

482492
def module_to_bytecode(module: ir.Module) -> bytes:

tests/api_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,14 @@ def f(x, y, *args, **kwargs):
12271227
self.assertIn("kwargs['z']", hlo_str)
12281228
self.assertIn("kwargs['w']", hlo_str)
12291229

1230+
hlo_str = mlir.module_to_string(
1231+
lowered.compiler_ir('stablehlo'),
1232+
enable_debug_info=False,
1233+
)
1234+
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
1235+
self.assertNotIn(s, hlo_str)
1236+
1237+
12301238
@parameterized.parameters([0, 2, [(0, 2)]])
12311239
def test_jit_lower_arg_info_static_argnums(self, static_argnums):
12321240
def f(x, y, *args, **kwargs):
@@ -1242,6 +1250,10 @@ def f(x, y, *args, **kwargs):
12421250
self.assertIn("kwargs['z']", hlo_str)
12431251
self.assertIn("kwargs['w']", hlo_str)
12441252

1253+
hlo_str = mlir.module_to_string(ir, enable_debug_info=False)
1254+
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
1255+
self.assertNotIn(s, hlo_str)
1256+
12451257
@parameterized.parameters(['a', 'b', [('a', 'b')]])
12461258
def test_jit_lower_arg_info_static_argnames(self, static_argnames):
12471259
def f(x, y, *args, **kwargs):
@@ -1259,6 +1271,13 @@ def f(x, y, *args, **kwargs):
12591271
self.assertNotIn("kwargs['a']", hlo_str)
12601272
self.assertNotIn("kwargs['b']", hlo_str)
12611273

1274+
hlo_str = mlir.module_to_string(ir, enable_debug_info=False)
1275+
for s in (
1276+
"\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']",
1277+
"kwargs['w']", "kwargs['a']", "kwargs['b']"
1278+
):
1279+
self.assertNotIn(s, hlo_str)
1280+
12621281
def test_jit_lower_result_info(self):
12631282
def f(x, y, z):
12641283
return {'a': x, 'b': [y]}

0 commit comments

Comments
 (0)