Skip to content

Commit 44a4b04

Browse files
authored
Add type information to Pallas primatives.
1 parent 301c351 commit 44a4b04

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

jax/_src/pallas/primitives.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
program_id_p = jax_core.Primitive("program_id")
5050

51-
def program_id(axis):
51+
def program_id(axis: int) -> jax.Array:
5252
return program_id_p.bind(axis=axis)
5353

5454
def program_id_bind(*, axis: int):
@@ -70,7 +70,7 @@ def _program_id_abstract_eval(**_):
7070

7171
num_programs_p = jax_core.Primitive("num_programs")
7272

73-
def num_programs(axis):
73+
def num_programs(axis: int) -> jax.Array:
7474
return num_programs_p.bind(axis=axis)
7575

7676
@num_programs_p.def_custom_bind
@@ -223,7 +223,7 @@ def _max_contiguous_abstract_eval(aval, **_):
223223
multiple_of_p.def_impl(lambda x, **_: x)
224224
mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x])
225225

226-
def multiple_of(x, values):
226+
def multiple_of(x: jax.Array, values: list[int] | int) -> jax.Array:
227227
if not isinstance(values, list):
228228
values = [values]
229229
return multiple_of_p.bind(x, values=values)

0 commit comments

Comments
 (0)