Skip to content

Commit 99fadcb

Browse files
tlongeri and jax authors
authored and committed
[Mosaic] Restore Python pipeline and add a CLI flag to run it.
We decided to expose a Python alternative again to make it easier for OSS users to see and customize the pipeline. The default is still to run the pipeline from XLA. The original one was removed in cl/596464480 and cl/597332393. PiperOrigin-RevId: 617291995
1 parent df9cefa commit 99fadcb

File tree

2 files changed

+121
-3
lines changed

2 files changed

+121
-3
lines changed

jax/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,8 @@ pytype_strict_library(
826826
"//jax/_src/lib",
827827
] + if_building_jaxlib([
828828
"//jaxlib/mlir:ir",
829+
"//jaxlib/mlir:mhlo_dialect",
830+
"//jaxlib/mlir:pass_manager",
829831
"//jaxlib/mlir:stablehlo_dialect",
830832
]) + py_deps("numpy") + py_deps("absl/flags"),
831833
)

jax/_src/tpu_custom_call.py

Lines changed: 119 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,30 @@
3030
import jax
3131
from jax import core
3232
from jax._src import config
33+
from jax._src import sharding_impls
34+
from jax._src.interpreters import mlir
3335
from jax._src.lib import tpu
3436
from jax._src.lib import xla_client
3537
from jax._src.lib.mlir.dialects import hlo
36-
from jax._src.interpreters import mlir
37-
from jax._src import sharding_impls
3838
from jax.interpreters import xla
3939
from jaxlib.mlir import ir
40+
from jaxlib.mlir.dialects import mhlo
4041
from jaxlib.mlir.dialects import stablehlo
42+
from jaxlib.mlir.passmanager import PassManager
4143
import numpy as np
4244

4345
FLAGS = flags.FLAGS
4446

47+
# Opt-in switch (off by default): when set, as_tpu_kernel runs the initial
# Mosaic MLIR passes itself via _lower_tpu_kernel instead of asking XLA to
# run them later.
_MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state(
    name="mosaic_use_python_pipeline",
    help=(
        "Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel is"
        " called (for Pallas, this happens at JAX lowering time), instead of"
        " later within XLA."
    ),
    default=False,
)
56+
4557
_MOSAIC_ALLOW_HLO = config.define_bool_state(
4658
name="jax_mosaic_allow_hlo",
4759
default=False,
@@ -250,6 +262,105 @@ def _tpu_custom_call_lowering(
250262
platform="tpu")
251263

252264

265+
def _run_pass_pipeline(
    module: ir.Module, passes: list[str], dump_name: str
) -> None:
  """Runs `passes` over `module` in place, then hands it to `dump_mlir`.

  The pass names are joined with commas and wrapped in `builtin.module(...)`
  to form the textual pipeline given to `PassManager.parse`. No context is
  passed explicitly, so this must be called while an MLIR context is active.

  Args:
    module: The MLIR module to transform. It is mutated.
    passes: Textual pass specifications, applied in order.
    dump_name: Label forwarded to `dump_mlir` for the transformed module.
  """
  pass_manager = PassManager.parse(f"builtin.module({','.join(passes)})")
  pass_manager.run(module.operation)
  dump_mlir(module, dump_name)


def _lower_tpu_kernel(
    module: ir.Module,
    hardware_generation: int,
) -> ir.Module:
  """Runs the initial Mosaic MLIR passes on a copy of the given module.

  The passes are looked up by name and driven from Python through
  `PassManager`. In order: optional HLO legalization (only when
  `jax_mosaic_allow_hlo` is set), memref layout inference, canonicalize + CSE,
  optional debug assert insertion (only when `xla_mosaic_on_device_checks`
  is exactly "bounds"), vector layout inference, vector layout application
  and a final canonicalize. The module is passed to `dump_mlir` after each
  stage.

  Args:
    module: The MLIR module to lower. It is verified and then re-parsed from
      its serialized form, so the passes mutate the copy rather than this
      object. Dialects are still registered on its context.
    hardware_generation: The TPU hardware generation to target.

  Returns:
    A new MLIR module implementing the kernel.

  Raises:
    ValueError: If `module` fails MLIR verification, or if
      `xla_mosaic_on_device_checks` names anything other than "bounds".
  """
  try:
    module.operation.verify()
  except ir.MLIRError as e:
    raise ValueError("The compiled module fails MLIR verification") from e

  with module.context as ctx, module.operation.location as _:

    ctx.append_dialect_registry(mlir.upstream_dialects)
    ctx.load_all_available_dialects()
    tpu.register_dialect(ctx)
    mhlo.register_mhlo_dialect(ctx)
    mhlo.register_mhlo_passes()

    # We'll mutate the module, so clone it
    module = ir.Module.parse(
        module.operation.get_asm(binary=True, enable_debug_info=True)
    )
    dump_mlir(module, "original")

    if _MOSAIC_ALLOW_HLO.value:
      # Run hlo dialect conversion: hlo -> linalg -> vector.
      _run_pass_pipeline(
          module,
          [
              "hlo-legalize-to-arithmetic",
              "func.func(hlo-legalize-to-linalg)",
              "func.func(linalg-vectorization)",
          ],
          "post-hlo-conversion",
      )

    _run_pass_pipeline(
        module,
        [
            f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
        ],
        "post-infer-memref-layout",
    )

    _run_pass_pipeline(module, ["canonicalize", "cse"], "post-simplify")

    if checks := FLAGS["xla_mosaic_on_device_checks"].value:
      checks = set(checks.split(","))
      if checks == {"bounds"}:  # We only support one kind of checks now.
        _run_pass_pipeline(
            module,
            ["func.func(debug-assert-insertion)"],
            "post-assert-insertion",
        )
      elif checks:
        checks.discard("bounds")
        # Sorted so that the message does not depend on set iteration order.
        raise ValueError(
            "Unrecognized on-device check categories:"
            f" {', '.join(sorted(checks))}"
        )

    _run_pass_pipeline(
        module,
        ["func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})"],
        "post-infer-vector-layout",
    )

    # Generations before 6 get an MXU size of 128, later ones 256; the same
    # value is used for both the contracting and non-contracting options.
    mxu_size = 128 if hardware_generation < 6 else 256
    _run_pass_pipeline(
        module,
        [
            "func.func(tpu-apply-vector-layout{sublane-count=8 lane-count=128"
            f" hardware-generation={hardware_generation}"
            f" mxu-contracting-size={mxu_size} mxu-noncontracting-size={mxu_size}"
            "})"
        ],
        "post-apply-vector-layout",
    )

    _run_pass_pipeline(module, ["canonicalize"], "pre-lower-to-llo")

    return module
362+
363+
253364
def as_tpu_kernel(
254365
module: ir.Module,
255366
out_type: Any,
@@ -279,6 +390,11 @@ def as_tpu_kernel(
279390
has_communication, has_custom_barrier = tpu.private_has_communication(
280391
module.operation
281392
)
393+
needs_layout_passes = not device_type
394+
if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value:
395+
module = _lower_tpu_kernel(module, hardware_generation)
396+
needs_layout_passes = False
397+
282398
bytecode_buffer = io.BytesIO()
283399
module.operation.write_bytecode(bytecode_buffer, desired_version=0)
284400
asm = bytecode_buffer.getvalue()
@@ -290,7 +406,7 @@ def as_tpu_kernel(
290406
asm,
291407
out_type,
292408
needs_hlo_passes=_MOSAIC_ALLOW_HLO.value,
293-
needs_layout_passes=not device_type,
409+
needs_layout_passes=needs_layout_passes,
294410
device_type=device_type,
295411
has_communication=has_communication,
296412
has_custom_barrier=has_custom_barrier,

0 commit comments

Comments
 (0)