Skip to content

Commit 318ae89

Browse files
sharadmvjax authors
authored andcommitted
[Pallas TPU] Relax windowing restriction when lowering mapped grids
PiperOrigin-RevId: 621330022
1 parent f74f4ed commit 318ae89

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -378,13 +378,20 @@ def lower_jaxpr_to_module(
378378
raise NotImplementedError("Index map jaxpr with consts not supported.")
379379
# ANY operands don't support windowing and require empty window_params.
380380
if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
381-
requires_windowing = bm.block_shape != full_ty.shape
382-
for atom in bm.index_map_jaxpr.jaxpr.outvars:
383-
if requires_windowing:
384-
break
385-
requires_windowing = not (
386-
isinstance(atom, jax_core.Literal) and atom.val == 0
387-
)
381+
# We may not require windowing if our block_shape matches the original
382+
# shape or the dimensions are mapped.
383+
requires_windowing = any(
384+
b != s
385+
for b, s in zip(bm.block_shape, full_ty.shape)
386+
if not (b is pl_core.mapped and s == 1)
387+
)
388+
if np.prod(grid) != 1:
389+
for atom in bm.index_map_jaxpr.jaxpr.outvars:
390+
if requires_windowing:
391+
break
392+
requires_windowing = not (
393+
isinstance(atom, jax_core.Literal) and atom.val == 0
394+
)
388395
if requires_windowing:
389396
raise NotImplementedError(
390397
"Operands in placed in the TPUMemorySpace.ANY memory space don't"
@@ -823,8 +830,10 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
823830
out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, dynamic_sizes).result
824831
if any(squeeze_dims):
825832
# We need to squeeze out some dimensions
833+
static_sizes = tuple(s if not isinstance(s, ir.Value)
834+
else ir_dynamic_size for s in target_shape)
826835
squeezed_ref_ty = ir.MemRefType.get(
827-
tuple(target_shape), _dtype_to_ir_type(ref_aval.dtype),
836+
static_sizes, _dtype_to_ir_type(ref_aval.dtype),
828837
memory_space=ref.type.memory_space)
829838
out = tpu.MemRefSqueezeOp(squeezed_ref_ty, out).result
830839
return out, ref_block_shape

0 commit comments

Comments
 (0)