Skip to content

Commit 9bf1148

Browse files
tlongerijax authors
authored andcommitted
[Mosaic] Always define tiling as (1, 128) for 1D loaded or stored vectors (not for the memref), instead of sometimes using (1, 128 * n).
They are equivalent - the way values are laid out is the same - but relayouts check specifically for (1, 128). We define (1, 128) to be canonical. PiperOrigin-RevId: 629748121
1 parent 26049b1 commit 9bf1148

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,10 +1025,12 @@ class VectorLayoutInferer {
10251025
TPU_CHECK_OP(tile % target_shape_[1] == 0,
10261026
"Unsupported tiling for 1D load");
10271027
CHECK_EQ(tile_offsets.size(), 1);
1028+
// TODO(tlongeri): Also pick a unique (canonical) tiling for packed types
1029+
const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile;
10281030
// TODO(apaszke): We could generate replicated loads for short values.
10291031
setLayout(op, in_layout,
1030-
VectorLayout(bitwidth, {0, tile_offsets[0]}, {1, tile},
1031-
ImplicitDim::kSecondMinor));
1032+
VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling},
1033+
{1, lane_tiling}, ImplicitDim::kSecondMinor));
10321034
} else { // rank >= 2
10331035
TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ loads");
10341036
CHECK_EQ(tile_offsets.size(), 2);
@@ -1366,9 +1368,12 @@ class VectorLayoutInferer {
13661368
auto tile = tiling.front();
13671369
TPU_CHECK_OP(tile % target_shape_[1] == 0,
13681370
"Unsupported 1D tiling for 1D store");
1371+
// TODO(tlongeri): Also pick a unique (canonical) tiling for packed types
1372+
const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile;
13691373
CHECK_EQ(tile_offsets.size(), 1);
1370-
store_layout = VectorLayout(bitwidth, {0, tile_offsets[0]}, {1, tile},
1371-
ImplicitDim::kSecondMinor);
1374+
store_layout =
1375+
VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling},
1376+
{1, lane_tiling}, ImplicitDim::kSecondMinor);
13721377
} else { // rank >= 2 // NOLINT(readability-else-after-return)
13731378
TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ store");
13741379
CHECK_EQ(tile_offsets.size(), 2);

0 commit comments

Comments
 (0)