Skip to content

Commit 7863508

Browse files
superbobryjax authors
authored andcommitted
Include source info as ir.Locations when lowering Pallas kernels on GPU
I decided to leave out the name stacks for now for simplicity, but we might want to add them in the future. PiperOrigin-RevId: 614644216
1 parent 477a5aa commit 7863508

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from jax._src import custom_derivatives
3333
from jax._src import linear_util as lu
3434
from jax._src import pjit
35+
from jax._src import source_info_util
3536
from jax._src import state
3637
from jax._src import util
3738
from jax._src.interpreters import mlir
@@ -44,7 +45,6 @@
4445
from jax._src.pallas import core as pallas_core
4546
from jax._src.pallas import primitives
4647
from jax._src.pallas import utils as pallas_utils
47-
from jax._src.state import AbstractRef
4848
from jax._src.state import discharge
4949
from jax._src.state import indexing
5050
from jax._src.state import primitives as sp
@@ -73,6 +73,7 @@ class ModuleContext:
7373
name: str
7474
grid_mapping: GridMapping
7575
program_ids: Sequence[ir.Value]
76+
traceback_caches: mlir.TracebackCaches = dataclasses.field(repr=False)
7677

7778

7879
@dataclasses.dataclass
@@ -269,7 +270,9 @@ def lower_jaxpr_to_triton_module(
269270
for i, pid in enumerate(program_ids)
270271
if i not in grid_mapping.mapped_dims
271272
]
272-
ctx = ModuleContext(name, grid_mapping, local_program_ids)
273+
ctx = ModuleContext(
274+
name, grid_mapping, local_program_ids, mlir.TracebackCaches()
275+
)
273276
if grid_mapping.num_index_operands:
274277
raise NotImplementedError(
275278
"Scalar prefetch not supported in Triton lowering."
@@ -336,9 +339,13 @@ def write_env(var: jax_core.Var, val):
336339
avals_in = [v.aval for v in eqn.invars]
337340
avals_out = [v.aval for v in eqn.outvars]
338341
eqn_block_infos = map(read_block_info_env, eqn.invars)
342+
loc = mlir._source_info_to_location(
343+
ctx, eqn.primitive, eqn.params, eqn.source_info
344+
)
339345
rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos)
340346
try:
341-
outvals = rule(rule_ctx, *invals, **eqn.params)
347+
with source_info_util.user_context(eqn.source_info.traceback), loc:
348+
outvals = rule(rule_ctx, *invals, **eqn.params)
342349
except LoweringError:
343350
raise # We only add the extra info to the innermost exception.
344351
except Exception as e:
@@ -2039,7 +2046,9 @@ def _for_lowering_rule(
20392046
step = _i32_constant(1)
20402047
init_args = map(_ensure_ir_value, args, ctx.avals_in)
20412048
# Partially discharge state from jaxpr for non-pointers
2042-
should_discharge = [not isinstance(a, AbstractRef) for a in ctx.avals_in]
2049+
should_discharge = [
2050+
not isinstance(a, state.AbstractRef) for a in ctx.avals_in
2051+
]
20432052
discharged_jaxpr, () = discharge.discharge_state(
20442053
jaxpr, (), should_discharge=[True, *should_discharge]
20452054
)

jax/_src/pallas/triton/pallas_call_registration.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ def _pallas_call_ttir_lowering(
205205
lowering_result = lowering.lower_jaxpr_to_triton_module(
206206
jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, cuda_options
207207
)
208+
module_op = lowering_result.module.operation
208209
if debug:
210+
print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))
209211
lowering_result.module.dump()
210212

211213
grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
@@ -214,7 +216,7 @@ def _pallas_call_ttir_lowering(
214216
for shape in out_shapes
215217
]
216218
buf = io.BytesIO()
217-
lowering_result.module.operation.write_bytecode(buf)
219+
module_op.write_bytecode(buf)
218220
backend_config = dict(
219221
name=ir.StringAttr.get(name),
220222
ir=ir.StringAttr.get(buf.getvalue()),

0 commit comments

Comments
 (0)