Skip to content

Commit abfbb0a

Browse files
sharadmvjax authors
authored andcommitted
Add dynamic grid support to emit_pipeline
PiperOrigin-RevId: 623393190
1 parent 0d8eb45 commit abfbb0a

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

jax/_src/pallas/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@
3737
# mypy: ignore-errors
3838

3939
partial = functools.partial
40-
Grid = tuple[Union[int, None], ...] # None indicates that the bound is dynamic.
41-
StaticGrid = tuple[int, ...] # None indicates that the bound is dynamic.
40+
Grid = tuple[Union[int, jax_core.Array, None], ...] # None indicates that the bound is dynamic.
41+
DynamicGrid = tuple[Union[int, jax_core.Array], ...]
42+
StaticGrid = tuple[int, ...]
4243
split_list = util.split_list
4344

4445
map, unsafe_map = util.safe_map, map

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ def _tree_map_with_kwargs(f, *args, **kwargs):
120120
)
121121

122122

123-
def _get_next_indices(grid: core.StaticGrid, indices: GridIndices) -> GridIndices:
123+
def _get_next_indices(
124+
grid: core.DynamicGrid, indices: GridIndices
125+
) -> GridIndices:
124126
"""Takes a grid and current indices and returns the next indices.
125127
126128
grid: (3, 4, 5)
@@ -135,7 +137,7 @@ def _get_next_indices(grid: core.StaticGrid, indices: GridIndices) -> GridIndice
135137
Incremented indices.
136138
"""
137139
next_indices = []
138-
carry = True
140+
carry: Union[bool, jax.Array] = True
139141
for dim_size, index in reversed(list(zip(grid, indices))):
140142
i = jnp.where(carry, index + 1, index)
141143
carry = dim_size == i
@@ -494,7 +496,7 @@ def __call__(
494496
def emit_pipeline_with_allocations(
495497
body: PipelineBody,
496498
*,
497-
grid: core.StaticGrid,
499+
grid: core.DynamicGrid,
498500
in_specs: PipelineBlockSpecs,
499501
out_specs: PipelineBlockSpecs,
500502
should_accumulate_out: Union[Sequence[bool], Any] = False,
@@ -1236,7 +1238,7 @@ def set_buffer_ref(buffer_ref, buffer):
12361238
def emit_pipeline(
12371239
body: PipelineBody,
12381240
*,
1239-
grid: core.StaticGrid,
1241+
grid: core.DynamicGrid,
12401242
in_specs: PipelineBlockSpecs,
12411243
out_specs: PipelineBlockSpecs,
12421244
should_accumulate_out: Union[Sequence[bool], Any] = False,

0 commit comments

Comments
 (0)