@@ -378,13 +378,20 @@ def lower_jaxpr_to_module(
378
378
raise NotImplementedError ("Index map jaxpr with consts not supported." )
379
379
# ANY operands don't support windowing and require empty window_params.
380
380
if aval .memory_space == tpu_core .TPUMemorySpace .ANY :
381
- requires_windowing = bm .block_shape != full_ty .shape
382
- for atom in bm .index_map_jaxpr .jaxpr .outvars :
383
- if requires_windowing :
384
- break
385
- requires_windowing = not (
386
- isinstance (atom , jax_core .Literal ) and atom .val == 0
387
- )
381
+ # We may not require windowing if our block_shape matches the original
382
+ # shape or the dimensions are mapped.
383
+ requires_windowing = any (
384
+ b != s
385
+ for b , s in zip (bm .block_shape , full_ty .shape )
386
+ if not (b is pl_core .mapped and s == 1 )
387
+ )
388
+ if np .prod (grid ) != 1 :
389
+ for atom in bm .index_map_jaxpr .jaxpr .outvars :
390
+ if requires_windowing :
391
+ break
392
+ requires_windowing = not (
393
+ isinstance (atom , jax_core .Literal ) and atom .val == 0
394
+ )
388
395
if requires_windowing :
389
396
raise NotImplementedError (
390
397
"Operands in placed in the TPUMemorySpace.ANY memory space don't"
@@ -823,8 +830,10 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef,
823
830
out = tpu .MemRefSliceOp (target_ref_ty , ref , starts , dynamic_sizes ).result
824
831
if any (squeeze_dims ):
825
832
# We need to squeeze out some dimensions
833
+ static_sizes = tuple (s if not isinstance (s , ir .Value )
834
+ else ir_dynamic_size for s in target_shape )
826
835
squeezed_ref_ty = ir .MemRefType .get (
827
- tuple ( target_shape ) , _dtype_to_ir_type (ref_aval .dtype ),
836
+ static_sizes , _dtype_to_ir_type (ref_aval .dtype ),
828
837
memory_space = ref .type .memory_space )
829
838
out = tpu .MemRefSqueezeOp (squeezed_ref_ty , out ).result
830
839
return out , ref_block_shape
0 commit comments