Skip to content

Commit 8250419

Browse files
author
jax authors
committed
Merge pull request #20718 from kkiningh:patch-3
PiperOrigin-RevId: 624336436
2 parents 2948a80 + 44a4b04 commit 8250419

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

jax/_src/pallas/primitives.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
program_id_p = jax_core.Primitive("program_id")
5050

51-
def program_id(axis):
51+
def program_id(axis: int) -> jax.Array:
5252
return program_id_p.bind(axis=axis)
5353

5454
def program_id_bind(*, axis: int):
@@ -70,7 +70,7 @@ def _program_id_abstract_eval(**_):
7070

7171
num_programs_p = jax_core.Primitive("num_programs")
7272

73-
def num_programs(axis):
73+
def num_programs(axis: int) -> jax.Array:
7474
return num_programs_p.bind(axis=axis)
7575

7676
@num_programs_p.def_custom_bind
@@ -223,7 +223,7 @@ def _max_contiguous_abstract_eval(aval, **_):
223223
multiple_of_p.def_impl(lambda x, **_: x)
224224
mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x])
225225

226-
def multiple_of(x, values):
226+
def multiple_of(x: jax.Array, values: list[int] | int) -> jax.Array:
227227
if not isinstance(values, list):
228228
values = [values]
229229
return multiple_of_p.bind(x, values=values)

jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ def _apply_mask_and_soft_cap(
582582
*,
583583
attn_logits_soft_cap: float,
584584
k_slice: pl.Slice,
585-
k_offset: int,
585+
k_offset: int | jax.Array,
586586
bq: int,
587587
k_in_lanes=True,
588588
mask_function=None,

0 commit comments

Comments
 (0)