30
30
import jax
31
31
from jax import core
32
32
from jax ._src import config
33
+ from jax ._src import sharding_impls
34
+ from jax ._src .interpreters import mlir
33
35
from jax ._src .lib import tpu
34
36
from jax ._src .lib import xla_client
35
37
from jax ._src .lib .mlir .dialects import hlo
36
- from jax ._src .interpreters import mlir
37
- from jax ._src import sharding_impls
38
38
from jax .interpreters import xla
39
39
from jaxlib .mlir import ir
40
+ from jaxlib .mlir .dialects import mhlo
40
41
from jaxlib .mlir .dialects import stablehlo
42
+ from jaxlib .mlir .passmanager import PassManager
41
43
import numpy as np
42
44
43
45
FLAGS = flags .FLAGS
44
46
47
# Opt-in switch (off by default). When it is set and the layout passes are
# still needed, as_tpu_kernel runs _lower_tpu_kernel itself and then passes
# needs_layout_passes=False to the custom call, so those passes are not
# requested again downstream.
_MOSAIC_USE_PYTHON_PIPELINE = config.define_bool_state(
    name="mosaic_use_python_pipeline",
    default=False,
    help=(
        "Run the initial Mosaic MLIR passes from Python, when as_tpu_kernel"
        " is called (for Pallas, this happens at JAX lowering time), instead of"
        " later within XLA."
    ),
)
56
+
45
57
_MOSAIC_ALLOW_HLO = config .define_bool_state (
46
58
name = "jax_mosaic_allow_hlo" ,
47
59
default = False ,
@@ -250,6 +262,105 @@ def _tpu_custom_call_lowering(
250
262
platform = "tpu" )
251
263
252
264
265
def _run_pass_pipeline(module: ir.Module, dump_name: str, *passes: str) -> None:
  """Runs `passes` on `module` in place, then dumps it under `dump_name`.

  Must be called with the module's MLIR context active, since the pipeline is
  parsed from its textual form.

  Args:
    module: The MLIR module to transform in place.
    dump_name: Stage label handed to `dump_mlir` after the passes have run.
    *passes: Textual pass specifications, nested under `builtin.module(...)`.
  """
  pass_manager = PassManager.parse(f"builtin.module({','.join(passes)})")
  pass_manager.run(module.operation)
  dump_mlir(module, dump_name)


def _lower_tpu_kernel(
    module: ir.Module,
    hardware_generation: int,
) -> ir.Module:
  """Runs MLIR passes lowering the given module to an MLIR module.

  Uses Python versions of infer-memref-layout and apply-vector-layout.

  The passes run on a copy of `module` that is re-parsed from its serialized
  form; the copy is what gets returned. The input module's context is still
  modified (extra dialects and passes are registered on it).

  Args:
    module: The MLIR module to lower.
    hardware_generation: The TPU hardware generation to target.

  Returns:
    An MLIR module implementing the kernel.

  Raises:
    ValueError: If `module` fails MLIR verification, or if the
      `xla_mosaic_on_device_checks` flag names a check category other than
      "bounds".
  """
  try:
    module.operation.verify()
  except ir.MLIRError as e:
    raise ValueError("The compiled module fails MLIR verification") from e

  with module.context as ctx, module.operation.location as _:
    # Make every dialect and pass referenced by the pipelines below available.
    ctx.append_dialect_registry(mlir.upstream_dialects)
    ctx.load_all_available_dialects()
    tpu.register_dialect(ctx)
    mhlo.register_mhlo_dialect(ctx)
    mhlo.register_mhlo_passes()

    # We'll mutate the module, so clone it
    module = ir.Module.parse(
        module.operation.get_asm(binary=True, enable_debug_info=True)
    )
    dump_mlir(module, "original")

    if _MOSAIC_ALLOW_HLO.value:
      # Run hlo dialect conversion: hlo -> linalg -> vector.
      _run_pass_pipeline(
          module,
          "post-hlo-conversion",
          "hlo-legalize-to-arithmetic",
          "func.func(hlo-legalize-to-linalg)",
          "func.func(linalg-vectorization)",
      )

    _run_pass_pipeline(
        module,
        "post-infer-memref-layout",
        "func.func(tpu-infer-memref-layout"
        f"{{hardware-generation={hardware_generation}}})",
    )

    _run_pass_pipeline(module, "post-simplify", "canonicalize", "cse")

    if checks := FLAGS["xla_mosaic_on_device_checks"].value:
      checks = set(checks.split(","))
      if checks == {"bounds"}:  # We only support one kind of checks now.
        _run_pass_pipeline(
            module,
            "post-assert-insertion",
            "func.func(debug-assert-insertion)",
        )
      else:
        # Sorted so the message is stable when several categories are unknown.
        unrecognized = sorted(checks - {"bounds"})
        raise ValueError(
            f"Unrecognized on-device check categories: {', '.join(unrecognized)}"
        )

    _run_pass_pipeline(
        module,
        "post-infer-vector-layout",
        "func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})",
    )

    # Generations before 6 get an MXU size of 128; later ones get 256.
    mxu_size = 128 if hardware_generation < 6 else 256
    _run_pass_pipeline(
        module,
        "post-apply-vector-layout",
        "func.func(tpu-apply-vector-layout{sublane-count=8 lane-count=128"
        f" hardware-generation={hardware_generation}"
        f" mxu-contracting-size={mxu_size} mxu-noncontracting-size={mxu_size}"
        "})",
    )

    _run_pass_pipeline(module, "pre-lower-to-llo", "canonicalize")

    return module
362
+
363
+
253
364
def as_tpu_kernel (
254
365
module : ir .Module ,
255
366
out_type : Any ,
@@ -279,6 +390,11 @@ def as_tpu_kernel(
279
390
has_communication , has_custom_barrier = tpu .private_has_communication (
280
391
module .operation
281
392
)
393
+ needs_layout_passes = not device_type
394
+ if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE .value :
395
+ module = _lower_tpu_kernel (module , hardware_generation )
396
+ needs_layout_passes = False
397
+
282
398
bytecode_buffer = io .BytesIO ()
283
399
module .operation .write_bytecode (bytecode_buffer , desired_version = 0 )
284
400
asm = bytecode_buffer .getvalue ()
@@ -290,7 +406,7 @@ def as_tpu_kernel(
290
406
asm ,
291
407
out_type ,
292
408
needs_hlo_passes = _MOSAIC_ALLOW_HLO .value ,
293
- needs_layout_passes = not device_type ,
409
+ needs_layout_passes = needs_layout_passes ,
294
410
device_type = device_type ,
295
411
has_communication = has_communication ,
296
412
has_custom_barrier = has_custom_barrier ,
0 commit comments