Skip to content

Commit 93e5bbe

Browse files
author
jax authors
committed
A fusion flag for each operand is set to false by default. A custom call writer is expected to turn the flags on if they expect those fusions to be profitable. An operand may still not be fused even when its flag is set to true, because of other constraints such as the estimated memory required after fusion.
PiperOrigin-RevId: 614772584
1 parent 53364b4 commit 93e5bbe

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

jax/_src/tpu_custom_call.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class CustomCallBackendConfig:
7979
needs_layout_passes: bool
8080
vmem_limit_bytes: int | None
8181
flags: dict[str, bool | int | float] | None
82+
allow_input_fusion: list[bool] | None
8283

8384
# We omit the body while printing, because primitive params get embedded
8485
# in HLO metadata, and the body blows up its size.
@@ -107,7 +108,15 @@ def to_json(self) -> bytes:
107108
if self.needs_layout_passes:
108109
config.write(b', "needs_layout_passes": ')
109110
config.write(str(self.needs_layout_passes).lower().encode("ascii"))
110-
config.write(b"}")
111+
if self.allow_input_fusion is not None:
112+
config.write(b', "allow_input_fusion": [')
113+
for i, value in enumerate(self.allow_input_fusion):
114+
config.write(b"true" if value else b"false")
115+
# config.write(str(value).lower().encode("ascii"))
116+
if i + 1 != len(self.allow_input_fusion):
117+
config.write(b",")
118+
config.write(b"]")
119+
config.write(b"}") # End of custom_call_config.
111120
if self.device_type is not None:
112121
config.write(b', "device_type": ')
113122
config.write(
@@ -252,6 +261,7 @@ def as_tpu_kernel(
252261
kernel_regeneration_metadata: bytes | None = None,
253262
vmem_limit_bytes: int | None = None,
254263
flags: dict[str, bool | int | float] | None = None,
264+
allow_input_fusion: list[bool] | None = None,
255265
input_output_aliases: tuple[tuple[int, int], ...] = (),
256266
) -> Callable[..., Any]:
257267
"""Turns an MLIR Mosaic kernel into a JAX-compatible function."""
@@ -289,6 +299,7 @@ def as_tpu_kernel(
289299
cost_estimate=cost_estimate,
290300
vmem_limit_bytes=vmem_limit_bytes,
291301
flags=flags,
302+
allow_input_fusion=allow_input_fusion,
292303
input_output_aliases=input_output_aliases,
293304
)
294305

@@ -307,6 +318,7 @@ def _lowered_as_tpu_kernel(
307318
kernel_regeneration_metadata: bytes | None = None,
308319
vmem_limit_bytes: int | None = None,
309320
flags: dict[str, bool | int | float] | None = None,
321+
allow_input_fusion: list[bool] | None = None,
310322
input_output_aliases: tuple[tuple[int, int], ...] = (),
311323
):
312324
"""Turns a low-level MLIR Mosaic kernel into a JAX-compatible function."""
@@ -336,6 +348,7 @@ def apply_kernel(*args, collective_id: int | None = None):
336348
needs_layout_passes,
337349
vmem_limit_bytes,
338350
flags,
351+
allow_input_fusion,
339352
)
340353
result = tpu_custom_call_p.bind(
341354
*args,

0 commit comments

Comments
 (0)