|
32 | 32 | from jax._src import custom_derivatives
|
33 | 33 | from jax._src import linear_util as lu
|
34 | 34 | from jax._src import pjit
|
| 35 | +from jax._src import source_info_util |
35 | 36 | from jax._src import state
|
36 | 37 | from jax._src import util
|
37 | 38 | from jax._src.interpreters import mlir
|
|
44 | 45 | from jax._src.pallas import core as pallas_core
|
45 | 46 | from jax._src.pallas import primitives
|
46 | 47 | from jax._src.pallas import utils as pallas_utils
|
47 |
| -from jax._src.state import AbstractRef |
48 | 48 | from jax._src.state import discharge
|
49 | 49 | from jax._src.state import indexing
|
50 | 50 | from jax._src.state import primitives as sp
|
@@ -73,6 +73,7 @@ class ModuleContext:
|
73 | 73 | name: str
|
74 | 74 | grid_mapping: GridMapping
|
75 | 75 | program_ids: Sequence[ir.Value]
|
| 76 | + traceback_caches: mlir.TracebackCaches = dataclasses.field(repr=False) |
76 | 77 |
|
77 | 78 |
|
78 | 79 | @dataclasses.dataclass
|
@@ -269,7 +270,9 @@ def lower_jaxpr_to_triton_module(
|
269 | 270 | for i, pid in enumerate(program_ids)
|
270 | 271 | if i not in grid_mapping.mapped_dims
|
271 | 272 | ]
|
272 |
| - ctx = ModuleContext(name, grid_mapping, local_program_ids) |
| 273 | + ctx = ModuleContext( |
| 274 | + name, grid_mapping, local_program_ids, mlir.TracebackCaches() |
| 275 | + ) |
273 | 276 | if grid_mapping.num_index_operands:
|
274 | 277 | raise NotImplementedError(
|
275 | 278 | "Scalar prefetch not supported in Triton lowering."
|
@@ -336,9 +339,13 @@ def write_env(var: jax_core.Var, val):
|
336 | 339 | avals_in = [v.aval for v in eqn.invars]
|
337 | 340 | avals_out = [v.aval for v in eqn.outvars]
|
338 | 341 | eqn_block_infos = map(read_block_info_env, eqn.invars)
|
| 342 | + loc = mlir._source_info_to_location( |
| 343 | + ctx, eqn.primitive, eqn.params, eqn.source_info |
| 344 | + ) |
339 | 345 | rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos)
|
340 | 346 | try:
|
341 |
| - outvals = rule(rule_ctx, *invals, **eqn.params) |
| 347 | + with source_info_util.user_context(eqn.source_info.traceback), loc: |
| 348 | + outvals = rule(rule_ctx, *invals, **eqn.params) |
342 | 349 | except LoweringError:
|
343 | 350 | raise # We only add the extra info to the innermost exception.
|
344 | 351 | except Exception as e:
|
@@ -2039,7 +2046,9 @@ def _for_lowering_rule(
|
2039 | 2046 | step = _i32_constant(1)
|
2040 | 2047 | init_args = map(_ensure_ir_value, args, ctx.avals_in)
|
2041 | 2048 | # Partially discharge state from jaxpr for non-pointers
|
2042 |
| - should_discharge = [not isinstance(a, AbstractRef) for a in ctx.avals_in] |
| 2049 | + should_discharge = [ |
| 2050 | + not isinstance(a, state.AbstractRef) for a in ctx.avals_in |
| 2051 | + ] |
2043 | 2052 | discharged_jaxpr, () = discharge.discharge_state(
|
2044 | 2053 | jaxpr, (), should_discharge=[True, *should_discharge]
|
2045 | 2054 | )
|
|
0 commit comments