File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -46,7 +46,6 @@ class CompilationResult:
46
46
lowering_result : lowering .LoweringResult
47
47
48
48
49
-
50
49
@util .weakref_lru_cache
51
50
def compile_jaxpr (
52
51
jaxpr : jax_core .Jaxpr ,
@@ -193,9 +192,6 @@ def _pallas_call_ttir_lowering(
193
192
num_warps : int ,
194
193
num_stages : int ,
195
194
):
196
- if triton_params :
197
- raise NotImplementedError ("triton_params are not supported" )
198
-
199
195
# TODO(sharadmv): handle multiple devices, right now we assume device 0
200
196
# which is fine when we have multiple of the same GPU but this won't work in
201
197
# general.
@@ -231,6 +227,11 @@ def _pallas_call_ttir_lowering(
231
227
grid_z = mlir .i32_attr (grid_z ),
232
228
debug = ir .BoolAttr .get (debug ),
233
229
)
230
+ if "serialized_metadata" in (triton_params or {}):
231
+ # This field is unstable and may be removed in the future.
232
+ backend_config ["serialized_metadata" ] = ir .StringAttr .get (
233
+ triton_params ["serialized_metadata" ]
234
+ )
234
235
return mlir .custom_call (
235
236
call_target_name = "__gpu$xla.gpu.triton" ,
236
237
result_types = out_types ,
You can’t perform that action at this time.
0 commit comments