@@ -79,6 +79,7 @@ class CustomCallBackendConfig:
79
79
needs_layout_passes : bool
80
80
vmem_limit_bytes : int | None
81
81
flags : dict [str , bool | int | float ] | None
82
+ allow_input_fusion : list [bool ] | None
82
83
83
84
# We omit the body while printing, because primitive params get embedded
84
85
# in HLO metadata, and the body blows up its size.
@@ -107,7 +108,15 @@ def to_json(self) -> bytes:
107
108
if self .needs_layout_passes :
108
109
config .write (b', "needs_layout_passes": ' )
109
110
config .write (str (self .needs_layout_passes ).lower ().encode ("ascii" ))
110
- config .write (b"}" )
111
+ if self .allow_input_fusion is not None :
112
+ config .write (b', "allow_input_fusion": [' )
113
+ for i , value in enumerate (self .allow_input_fusion ):
114
+ config .write (b"true" if value else b"false" )
115
+ # config.write(str(value).lower().encode("ascii"))
116
+ if i + 1 != len (self .allow_input_fusion ):
117
+ config .write (b"," )
118
+ config .write (b"]" )
119
+ config .write (b"}" ) # End of custom_call_config.
111
120
if self .device_type is not None :
112
121
config .write (b', "device_type": ' )
113
122
config .write (
@@ -252,6 +261,7 @@ def as_tpu_kernel(
252
261
kernel_regeneration_metadata : bytes | None = None ,
253
262
vmem_limit_bytes : int | None = None ,
254
263
flags : dict [str , bool | int | float ] | None = None ,
264
+ allow_input_fusion : list [bool ] | None = None ,
255
265
input_output_aliases : tuple [tuple [int , int ], ...] = (),
256
266
) -> Callable [..., Any ]:
257
267
"""Turns an MLIR Mosaic kernel into a JAX-compatible function."""
@@ -289,6 +299,7 @@ def as_tpu_kernel(
289
299
cost_estimate = cost_estimate ,
290
300
vmem_limit_bytes = vmem_limit_bytes ,
291
301
flags = flags ,
302
+ allow_input_fusion = allow_input_fusion ,
292
303
input_output_aliases = input_output_aliases ,
293
304
)
294
305
@@ -307,6 +318,7 @@ def _lowered_as_tpu_kernel(
307
318
kernel_regeneration_metadata : bytes | None = None ,
308
319
vmem_limit_bytes : int | None = None ,
309
320
flags : dict [str , bool | int | float ] | None = None ,
321
+ allow_input_fusion : list [bool ] | None = None ,
310
322
input_output_aliases : tuple [tuple [int , int ], ...] = (),
311
323
):
312
324
"""Turns a low-level MLIR Mosaic kernel into a JAX-compatible function."""
@@ -336,6 +348,7 @@ def apply_kernel(*args, collective_id: int | None = None):
336
348
needs_layout_passes ,
337
349
vmem_limit_bytes ,
338
350
flags ,
351
+ allow_input_fusion ,
339
352
)
340
353
result = tpu_custom_call_p .bind (
341
354
* args ,
0 commit comments