Skip to content

Commit a205c91

Browse files
superbobry, jax authors
authored and committed
pallas_call now has only one way to pass compiler_params=
Previously, it was possible to do pallas_call(..., foo=42) and also pallas_call(..., compiler_params=dict(foo=42)).
PiperOrigin-RevId: 623277572
1 parent 008f87d commit a205c91

File tree

6 files changed

+19
-18
lines changed

6 files changed

+19
-18
lines changed

docs/pallas/tpu/details.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ grid axes over cores. This is an opt-in procedure. To allow that,
147147
..
148148
pallas_call(
149149
...,
150-
mosaic_params=dict(
151-
dimension_semantics=["parallel", "parallel", "arbitrary"]
150+
compiler_params=dict(
151+
mosaic=dict(
152+
dimension_semantics=["parallel", "parallel", "arbitrary"]
153+
)
152154
),
153155
)
154156

jax/_src/pallas/pallas_call.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -561,14 +561,10 @@ def pallas_call(
561561
interpret: bool = False,
562562
name: str | None = None,
563563
compiler_params: dict[str, Any] | None = None,
564-
**compiler_params_: Any,
565564
):
566565
name = _extract_function_name(f, name)
567566
if compiler_params is None:
568567
compiler_params = {}
569-
assert not (compiler_params and compiler_params_)
570-
if compiler_params_:
571-
compiler_params = compiler_params_
572568
if grid is not None and grid_spec is not None:
573569
raise ValueError("Cannot specify both grid and grid_spec at the same time.")
574570
if grid_spec is None:

jax/experimental/pallas/ops/gpu/decode_attention.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,9 @@ def attn_unbatched(
165165
),
166166
), # m
167167
],
168-
num_warps=num_warps_,
169-
num_stages=num_stages,
168+
compiler_params=dict(
169+
triton=dict(num_warps=num_warps_, num_stages=num_stages)
170+
),
170171
out_shape=[
171172
jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o
172173
jax.ShapeDtypeStruct(

jax/experimental/pallas/ops/tpu/all_gather.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def ag_local(x_shard):
136136
out = pl.pallas_call(
137137
functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh),
138138
out_shape=out_shape,
139-
mosaic_params=dict(collective_id=0),
139+
compiler_params=dict(mosaic=dict(collective_id=0)),
140140
grid_spec=pltpu.PrefetchScalarGridSpec(
141141
num_scalar_prefetch=0,
142142
scratch_shapes=(

jax/experimental/pallas/ops/tpu/megablox/gmm.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -537,9 +537,11 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
537537
scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
538538
),
539539
input_output_aliases=input_output_aliases,
540-
mosaic_params=dict(
541-
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
542-
cost_estimate=cost_estimate,
540+
compiler_params=dict(
541+
mosaic=dict(
542+
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
543+
cost_estimate=cost_estimate,
544+
)
543545
),
544546
interpret=interpret,
545547
)
@@ -777,9 +779,11 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
777779
scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)],
778780
),
779781
input_output_aliases=input_output_aliases,
780-
mosaic_params=dict(
781-
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
782-
cost_estimate=cost_estimate,
782+
compiler_params=dict(
783+
mosaic=dict(
784+
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
785+
cost_estimate=cost_estimate,
786+
)
783787
),
784788
interpret=interpret,
785789
)

jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,7 @@ def paged_attention(
487487
pltpu.SemaphoreType.DMA,
488488
),
489489
),
490-
mosaic_params=dict(
491-
dimension_semantics=dimension_sematics,
492-
),
490+
compiler_params=dict(mosaic=dict(dimension_semantics=dimension_sematics)),
493491
out_shape=[
494492
jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
495493
jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),

0 commit comments

Comments
 (0)