@@ -120,7 +120,9 @@ def _tree_map_with_kwargs(f, *args, **kwargs):
120
120
)
121
121
122
122
123
- def _get_next_indices (grid : core .StaticGrid , indices : GridIndices ) -> GridIndices :
123
+ def _get_next_indices (
124
+ grid : core .DynamicGrid , indices : GridIndices
125
+ ) -> GridIndices :
124
126
"""Takes a grid and current indices and returns the next indices.
125
127
126
128
grid: (3, 4, 5)
@@ -135,7 +137,7 @@ def _get_next_indices(grid: core.StaticGrid, indices: GridIndices) -> GridIndice
135
137
Incremented indices.
136
138
"""
137
139
next_indices = []
138
- carry = True
140
+ carry : Union [ bool , jax . Array ] = True
139
141
for dim_size , index in reversed (list (zip (grid , indices ))):
140
142
i = jnp .where (carry , index + 1 , index )
141
143
carry = dim_size == i
@@ -494,7 +496,7 @@ def __call__(
494
496
def emit_pipeline_with_allocations (
495
497
body : PipelineBody ,
496
498
* ,
497
- grid : core .StaticGrid ,
499
+ grid : core .DynamicGrid ,
498
500
in_specs : PipelineBlockSpecs ,
499
501
out_specs : PipelineBlockSpecs ,
500
502
should_accumulate_out : Union [Sequence [bool ], Any ] = False ,
@@ -1236,7 +1238,7 @@ def set_buffer_ref(buffer_ref, buffer):
1236
1238
def emit_pipeline (
1237
1239
body : PipelineBody ,
1238
1240
* ,
1239
- grid : core .StaticGrid ,
1241
+ grid : core .DynamicGrid ,
1240
1242
in_specs : PipelineBlockSpecs ,
1241
1243
out_specs : PipelineBlockSpecs ,
1242
1244
should_accumulate_out : Union [Sequence [bool ], Any ] = False ,
0 commit comments