Skip to content

Commit 86bf8fa

Browse files
committed
Update env name
1 parent 0b70244 commit 86bf8fa

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

jax/_src/interpreters/mlir.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@
7575
"Supports the special value 'sponge' to pick the path from the "
7676
"environment variable TEST_UNDECLARED_OUTPUTS_DIR.")
7777

78+
_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS = config.DEFINE_string(
79+
'jax_include_debug_info_in_dumps',
80+
os.getenv('JAX_INCLUDE_DEBUG_INFO_IN_DUMPS', "True"),
81+
help="Determine whether or not to keep debug symbols and location information "
82+
"when dumping IR code. By default, debug information will be preserved in "
83+
"the IR dump. To avoid exposing source code and potentially sensitive "
84+
"information, set to false")
7885
lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects
7986

8087

@@ -474,9 +481,12 @@ def dump_module_message(module: ir.Module, stage_name: str) -> str:
474481
def _make_string_safe_for_filename(s: str) -> str:
475482
return re.sub(r'[^\w.)( -]', '', s)
476483

477-
def module_to_string(module: ir.Module) -> str:
484+
def module_to_string(module: ir.Module, enable_debug_info=None) -> str:
478485
output = io.StringIO()
479-
module.operation.print(file=output, enable_debug_info=True)
486+
if enable_debug_info is None:
487+
enable_debug_flag = str.lower(_JAX_INCLUDE_DEBUG_INFO_IN_DUMPS.value)
488+
enable_debug_info = enable_debug_flag not in ('false', '0')
489+
module.operation.print(file=output, enable_debug_info=enable_debug_info)
480490
return output.getvalue()
481491

482492
def module_to_bytecode(module: ir.Module) -> bytes:

tests/api_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,14 @@ def f(x, y, *args, **kwargs):
12221222
self.assertIn("kwargs['z']", hlo_str)
12231223
self.assertIn("kwargs['w']", hlo_str)
12241224

1225+
hlo_str = mlir.module_to_string(
1226+
lowered.compiler_ir('stablehlo'),
1227+
enable_debug_info=False,
1228+
)
1229+
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
1230+
self.assertNotIn(s, hlo_str)
1231+
1232+
12251233
@parameterized.parameters([0, 2, [(0, 2)]])
12261234
def test_jit_lower_arg_info_static_argnums(self, static_argnums):
12271235
def f(x, y, *args, **kwargs):
@@ -1237,6 +1245,10 @@ def f(x, y, *args, **kwargs):
12371245
self.assertIn("kwargs['z']", hlo_str)
12381246
self.assertIn("kwargs['w']", hlo_str)
12391247

1248+
hlo_str = mlir.module_to_string(ir, enable_debug_info=False)
1249+
for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
1250+
self.assertNotIn(s, hlo_str)
1251+
12401252
@parameterized.parameters(['a', 'b', [('a', 'b')]])
12411253
def test_jit_lower_arg_info_static_argnames(self, static_argnames):
12421254
def f(x, y, *args, **kwargs):
@@ -1254,6 +1266,13 @@ def f(x, y, *args, **kwargs):
12541266
self.assertNotIn("kwargs['a']", hlo_str)
12551267
self.assertNotIn("kwargs['b']", hlo_str)
12561268

1269+
hlo_str = mlir.module_to_string(ir, enable_debug_info=False)
1270+
for s in (
1271+
"\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']",
1272+
"kwargs['w']", "kwargs['a']", "kwargs['b']"
1273+
):
1274+
self.assertNotIn(s, hlo_str)
1275+
12571276
def test_jit_lower_result_info(self):
12581277
def f(x, y, z):
12591278
return {'a': x, 'b': [y]}

0 commit comments

Comments
 (0)