Skip to content

Commit 008f87d

Browse files
superbobryjax authors
authored andcommitted
The compiler_params= argument of pl.pallas_call on GPU now uses "triton" to refer to Triton-specific parameters, instead of the repetitive "triton_params"
PiperOrigin-RevId: 623275152
1 parent b865c5b commit 008f87d

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

jax/_src/pallas/triton/pallas_call_registration.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,14 @@ def pallas_call_lowering(
145145
raise NotImplementedError(
146146
"dynamic grid bounds not supported in the Triton backend"
147147
)
148-
triton_compiler_params = compiler_params.get("triton", compiler_params)
149-
triton_params = compiler_params.get("triton_params", {})
150-
num_warps = triton_compiler_params.pop("num_warps", 4)
148+
triton_params = compiler_params.get("triton", compiler_params)
149+
num_warps = triton_params.pop("num_warps", 4)
151150
if len(ctx.module_context.platforms) > 1:
152151
raise NotImplementedError("multi-platform lowering for Pallas kernels")
153152
if ctx.module_context.platforms[0] == "rocm":
154-
num_stages = triton_compiler_params.pop("num_stages", 1)
153+
num_stages = triton_params.pop("num_stages", 1)
155154
else:
156-
num_stages = triton_compiler_params.pop("num_stages", 3)
155+
num_stages = triton_params.pop("num_stages", 3)
157156

158157
if debug:
159158
print(jaxpr)

0 commit comments

Comments
 (0)