File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change 48
48
49
49
program_id_p = jax_core .Primitive ("program_id" )
50
50
51
- def program_id (axis ) :
51
+ def program_id (axis : int ) -> jax . Array :
52
52
return program_id_p .bind (axis = axis )
53
53
54
54
def program_id_bind (* , axis : int ):
@@ -70,7 +70,7 @@ def _program_id_abstract_eval(**_):
70
70
71
71
num_programs_p = jax_core .Primitive ("num_programs" )
72
72
73
- def num_programs (axis ) :
73
+ def num_programs (axis : int ) -> jax . Array :
74
74
return num_programs_p .bind (axis = axis )
75
75
76
76
@num_programs_p .def_custom_bind
@@ -223,7 +223,7 @@ def _max_contiguous_abstract_eval(aval, **_):
223
223
multiple_of_p .def_impl (lambda x , ** _ : x )
224
224
mlir .register_lowering (multiple_of_p , lambda _ , x , ** __ : [x ])
225
225
226
- def multiple_of (x , values ) :
226
+ def multiple_of (x : jax . Array , values : list [ int ] | int ) -> jax . Array :
227
227
if not isinstance (values , list ):
228
228
values = [values ]
229
229
return multiple_of_p .bind (x , values = values )
You can’t perform that action at this time.
0 commit comments