Skip to content

Commit d44b16c

Browse files
bythew3i authored and jax authors committed
[XLA:Mosaic] Generalize (8,128) -> (8 * packing,128) retiling for packed type.
PiperOrigin-RevId: 625816937
1 parent 7cb0e60 commit d44b16c

File tree

1 file changed

+18
-48
lines changed

1 file changed

+18
-48
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 18 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -4367,63 +4367,33 @@ FailureOr<TypedValue<VectorType>> relayout(
43674367
});
43684368
src = dst;
43694369
src_tiles = std::move(src_tiles_retiled);
4370-
} else if ( // TODO(b/265133506): Generalize retiling to general 16-bit types
4371-
// (might need to use a different unpacking op).
4372-
// (8,128) -> (16,128) tiling change for packed 16-bit types.
4370+
} else if ( // TODO(b/265133506): Generalize retiling.
4371+
// (8,128) -> (8 * packing,128) tiling change for packed type.
43734372
src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
43744373
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
4375-
vty.getElementTypeBitWidth() == 16 && src.offsets() == dst.offsets() &&
4376-
src.tiling() == std::array<int64_t, 2>{8, 128} &&
4377-
dst.tiling() == std::array<int64_t, 2>{16, 128}) {
4378-
const VectorLayout new_src(src.bitwidth(), src.offsets(), dst.tiling());
4379-
xla::Array<Value> src_tiles_retiled(
4380-
new_src.tileArrayShape(vty.getShape(), target_shape));
4381-
src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
4382-
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
4383-
src_idx[src_idx.size() - 2] *= 2;
4384-
src_idx[src_idx.size() - 1] /= 2;
4385-
Value src_row1 = src_tiles(src_idx);
4386-
if (src_idx[src_idx.size() - 2] + 1 <
4387-
src_tiles.dim(src_tiles.num_dimensions() - 2)) {
4388-
++src_idx[src_idx.size() - 2];
4389-
}
4390-
Value src_row2 = src_tiles(src_idx);
4391-
const int vreg_part = idx[idx.size() - 1] % 2;
4392-
4393-
VectorType vreg_x32 =
4394-
vty.getElementType().isSignlessInteger()
4395-
? VectorType::get(target_shape, builder.getI32Type())
4396-
: VectorType::get(target_shape, builder.getF32Type());
4397-
auto half_row1 = builder.create<tpu::UnpackSubelementsOp>(
4398-
v.getLoc(), vreg_x32, src_row1, vreg_part);
4399-
auto half_row2 = builder.create<tpu::UnpackSubelementsOp>(
4400-
v.getLoc(), vreg_x32, src_row2, vreg_part);
4401-
*tile = builder.create<tpu::PackSubelementsOp>(
4402-
v.getLoc(), src_row1.getType(), ValueRange{half_row1, half_row2});
4403-
});
4404-
src = new_src;
4405-
src_tiles = std::move(src_tiles_retiled);
4406-
} else if ( // (8,128) -> (32,128) tiling change for packed 8-bit integers.
4407-
src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
4408-
dst.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
4409-
vty.getElementType() == builder.getI8Type() &&
4374+
vty.getElementTypeBitWidth() < 32 &&
4375+
32 % vty.getElementTypeBitWidth() == 0 &&
44104376
src.offsets() == dst.offsets() &&
44114377
src.tiling() == std::array<int64_t, 2>{8, 128} &&
4412-
dst.tiling() == std::array<int64_t, 2>{32, 128}) {
4378+
dst.tiling() == std::array<int64_t, 2>{8 * dst.packing(), 128}) {
44134379
const VectorLayout new_src(src.bitwidth(), src.offsets(), dst.tiling());
44144380
xla::Array<Value> src_tiles_retiled(
44154381
new_src.tileArrayShape(vty.getShape(), target_shape));
4416-
VectorType vreg_i32 =
4417-
getNativeVregType(builder.getI32Type(), target_shape).value();
4382+
int vty_packing = dst.packing();
4383+
VectorType vreg_x32 =
4384+
vty.getElementType().isSignlessInteger()
4385+
? VectorType::get(target_shape, builder.getI32Type())
4386+
: VectorType::get(target_shape, builder.getF32Type());
44184387
src_tiles_retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
4419-
const int vreg_part = idx.back() % 4;
4420-
std::array<Value, 4> parts;
4388+
const int vreg_part = idx.back() % vty_packing;
4389+
SmallVector<Value, 8> parts;
4390+
parts.reserve(vty_packing);
44214391
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
4422-
src_idx[src_idx.size() - 2] *= 4;
4423-
src_idx[src_idx.size() - 1] /= 4;
4424-
for (int i = 0; i < 4; ++i) {
4425-
parts[i] = builder.create<tpu::UnpackSubelementsOp>(
4426-
v.getLoc(), vreg_i32, src_tiles(src_idx), vreg_part);
4392+
src_idx[src_idx.size() - 2] *= vty_packing;
4393+
src_idx[src_idx.size() - 1] /= vty_packing;
4394+
for (int i = 0; i < vty_packing; ++i) {
4395+
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
4396+
v.getLoc(), vreg_x32, src_tiles(src_idx), vreg_part));
44274397
if (src_idx[src_idx.size() - 2] <
44284398
src_tiles.dim(src_tiles.num_dimensions() - 2) - 1) {
44294399
++src_idx[src_idx.size() - 2];

0 commit comments

Comments
 (0)