Skip to content

Commit 7f7e0c0

Browse files
tlongerijax authors
authored andcommitted
[Mosaic] Support left shifting relayouts
PiperOrigin-RevId: 618008857
1 parent 8a2ba76 commit 7f7e0c0

File tree

1 file changed

+37
-19
lines changed

1 file changed

+37
-19
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4349,38 +4349,56 @@ FailureOr<TypedValue<VectorType>> relayout(
43494349
return emitError(v.getLoc(),
43504350
"Not implemented: Both columns and rows are shifted");
43514351
}
4352-
if (col_diff < 0) {
4353-
return emitError(v.getLoc(), "Not implemented: Shifts to the left");
4354-
}
43554352
if (bitwidth != 32 || tiling != target_shape) {
43564353
return emitError(v.getLoc(),
43574354
"Not implemented: Only 32-bit column shifts for "
43584355
"native layouts supported");
43594356
}
4360-
const int64_t sublane_diff = col_diff;
43614357
TPU_ASSERT_GE_LOC(v.getLoc(), src_tiles.num_dimensions(), 1);
43624358
std::optional<tpu::CreateMaskOp> maybe_create_mask;
4363-
if (src_tiles.dimensions()[src_tiles.num_dimensions() - 1] > 1) {
4359+
if (*(src_tiles.dimensions().end() - 1) > 1) {
4360+
int64_t lane_start, lane_end;
4361+
if (col_diff > 0) {
4362+
lane_start = 0;
4363+
lane_end = col_diff;
4364+
} else { // col_diff < 0
4365+
lane_start = target_shape[1] + col_diff;
4366+
lane_end = target_shape[1];
4367+
}
43644368
auto boundIdxConst =
43654369
std::bind(IdxConst, std::placeholders::_1, builder, v.getLoc());
43664370
maybe_create_mask = builder.create<tpu::CreateMaskOp>(
43674371
v.getLoc(), VectorType::get(target_shape, builder.getI1Type()),
4368-
ValueRange{boundIdxConst(0), boundIdxConst(0)},
4372+
ValueRange{boundIdxConst(0), boundIdxConst(lane_start)},
43694373
ValueRange{boundIdxConst(target_shape[0]),
4370-
boundIdxConst(col_diff)});
4374+
boundIdxConst(lane_end)});
43714375
}
4372-
src_tiles.Each([&](absl::Span<const int64_t> idx, Value tile) {
4373-
Value rot_tile =
4374-
builder
4375-
.create<tpu::RotateOp>(v.getLoc(), tile,
4376-
/*amount=*/sublane_diff,
4377-
/*dimension=*/1, /*stride=*/nullptr,
4378-
/*stride_dimension=*/nullptr)
4379-
.getResult();
4380-
if (idx[idx.size() - 1] != 0) {
4381-
SmallVector<int64_t> prev_idx(idx.begin(), idx.end());
4382-
--prev_idx[idx.size() - 1];
4383-
Value prev_rot_tile = dst_tiles(prev_idx);
4376+
src_tiles.Each([&](absl::Span<const int64_t> idx, Value *tile) {
4377+
*tile = builder
4378+
.create<tpu::RotateOp>(v.getLoc(), *tile,
4379+
/*amount=*/col_diff < 0
4380+
? target_shape[1] + col_diff
4381+
: col_diff,
4382+
/*dimension=*/1, /*stride=*/nullptr,
4383+
/*stride_dimension=*/nullptr)
4384+
.getResult();
4385+
});
4386+
src_tiles.Each([&](absl::Span<const int64_t> idx, Value rot_tile) {
4387+
Value prev_rot_tile;
4388+
if (col_diff > 0) {
4389+
if (*(idx.end() - 1) != 0) {
4390+
SmallVector<int64_t> prev_idx(idx.begin(), idx.end());
4391+
--*(prev_idx.end() - 1);
4392+
prev_rot_tile = src_tiles(prev_idx);
4393+
}
4394+
} else { // col_diff < 0
4395+
if (*(idx.end() - 1) != *(src_tiles.dimensions().end() - 1) - 1) {
4396+
SmallVector<int64_t> prev_idx(idx.begin(), idx.end());
4397+
++*(prev_idx.end() - 1);
4398+
prev_rot_tile = src_tiles(prev_idx);
4399+
}
4400+
}
4401+
if (prev_rot_tile != nullptr) {
43844402
rot_tile = builder.create<arith::SelectOp>(
43854403
v.getLoc(), maybe_create_mask->getResult(), prev_rot_tile,
43864404
rot_tile);

0 commit comments

Comments
 (0)