Skip to content

Commit 05e61ed

Browse files
author
jax authors
committed
Expose an API to control whether to fuse input computation with the Pallas kernel on a per-input basis.
PiperOrigin-RevId: 617975104
1 parent 4c7351f commit 05e61ed

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

jax/_src/pallas/mosaic/pallas_call_registration.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,7 @@ def _lower_fun(*args):
102102
cost_estimate=mosaic_params.get("cost_estimate", None),
103103
vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes", None),
104104
flags=mosaic_params.get("flags", None),
105+
allow_input_fusion=mosaic_params.get("allow_input_fusion", None),
105106
input_output_aliases=input_output_aliases,
106107
)(
107108
*dynamic_grid_args,

tests/pallas/pallas_call_tpu_test.py

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
"""Test TPU-specific extensions to pallas_call."""
1616

1717
import functools
18+
import re
1819
from absl.testing import absltest
1920
from absl.testing import parameterized
2021
import jax
@@ -1358,6 +1359,32 @@ def kernel(x_ref, y_ref):
13581359
compiler_params=dict(mosaic=dict(vmem_limit_bytes=int(2**18))),
13591360
)(x)
13601361

1362+
def test_allow_input_fusion(self):
1363+
shape = (3, 128, 128)
1364+
1365+
def kernel(x_ref, y_ref):
1366+
y_ref[...] = x_ref[...]
1367+
1368+
def f(x, y):
1369+
z = jax.numpy.add(x, y)
1370+
return pl.pallas_call(
1371+
kernel,
1372+
grid=(3,),
1373+
in_specs=[pl.BlockSpec(lambda i: (i, 0, 0), (1, 128, 128))],
1374+
out_specs=pl.BlockSpec(lambda i: (i, 0, 0), (1, 128, 128)),
1375+
out_shape=x,
1376+
compiler_params=dict(mosaic=dict(allow_input_fusion=[True])),
1377+
)(z)
1378+
1379+
x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
1380+
y = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape)
1381+
1382+
out = f(x, y)
1383+
expected = x + y
1384+
np.testing.assert_array_equal(out, expected)
1385+
compiled = jax.jit(f).lower(x, y).compile().as_text()
1386+
assert re.search(r'fusion.*kind=kCustom.*fused_computation', compiled)
1387+
13611388

13621389
class PallasCallUnblockedIndexingTest(PallasTPUTest):
13631390

0 commit comments

Comments
 (0)