@@ -4367,63 +4367,33 @@ FailureOr<TypedValue<VectorType>> relayout(
4367
4367
});
4368
4368
src = dst;
4369
4369
src_tiles = std::move (src_tiles_retiled);
4370
- } else if ( // TODO(b/265133506): Generalize retiling to general 16-bit types
4371
- // (might need to use a different unpacking op).
4372
- // (8,128) -> (16,128) tiling change for packed 16-bit types.
4370
+ } else if ( // TODO(b/265133506): Generalize retiling.
4371
+ // (8,128) -> (8 * packing,128) tiling change for packed type.
4373
4372
src.implicit_dim () == VectorLayout::ImplicitDim::kNone &&
4374
4373
dst.implicit_dim () == VectorLayout::ImplicitDim::kNone &&
4375
- vty.getElementTypeBitWidth () == 16 && src.offsets () == dst.offsets () &&
4376
- src.tiling () == std::array<int64_t , 2 >{8 , 128 } &&
4377
- dst.tiling () == std::array<int64_t , 2 >{16 , 128 }) {
4378
- const VectorLayout new_src (src.bitwidth (), src.offsets (), dst.tiling ());
4379
- xla::Array<Value> src_tiles_retiled (
4380
- new_src.tileArrayShape (vty.getShape (), target_shape));
4381
- src_tiles_retiled.Each ([&](absl::Span<const int64_t > idx, Value *tile) {
4382
- SmallVector<int64_t > src_idx (idx.begin (), idx.end ());
4383
- src_idx[src_idx.size () - 2 ] *= 2 ;
4384
- src_idx[src_idx.size () - 1 ] /= 2 ;
4385
- Value src_row1 = src_tiles (src_idx);
4386
- if (src_idx[src_idx.size () - 2 ] + 1 <
4387
- src_tiles.dim (src_tiles.num_dimensions () - 2 )) {
4388
- ++src_idx[src_idx.size () - 2 ];
4389
- }
4390
- Value src_row2 = src_tiles (src_idx);
4391
- const int vreg_part = idx[idx.size () - 1 ] % 2 ;
4392
-
4393
- VectorType vreg_x32 =
4394
- vty.getElementType ().isSignlessInteger ()
4395
- ? VectorType::get (target_shape, builder.getI32Type ())
4396
- : VectorType::get (target_shape, builder.getF32Type ());
4397
- auto half_row1 = builder.create <tpu::UnpackSubelementsOp>(
4398
- v.getLoc (), vreg_x32, src_row1, vreg_part);
4399
- auto half_row2 = builder.create <tpu::UnpackSubelementsOp>(
4400
- v.getLoc (), vreg_x32, src_row2, vreg_part);
4401
- *tile = builder.create <tpu::PackSubelementsOp>(
4402
- v.getLoc (), src_row1.getType (), ValueRange{half_row1, half_row2});
4403
- });
4404
- src = new_src;
4405
- src_tiles = std::move (src_tiles_retiled);
4406
- } else if ( // (8,128) -> (32,128) tiling change for packed 8-bit integers.
4407
- src.implicit_dim () == VectorLayout::ImplicitDim::kNone &&
4408
- dst.implicit_dim () == VectorLayout::ImplicitDim::kNone &&
4409
- vty.getElementType () == builder.getI8Type () &&
4374
+ vty.getElementTypeBitWidth () < 32 &&
4375
+ 32 % vty.getElementTypeBitWidth () == 0 &&
4410
4376
src.offsets () == dst.offsets () &&
4411
4377
src.tiling () == std::array<int64_t , 2 >{8 , 128 } &&
4412
- dst.tiling () == std::array<int64_t , 2 >{32 , 128 }) {
4378
+ dst.tiling () == std::array<int64_t , 2 >{8 * dst. packing () , 128 }) {
4413
4379
const VectorLayout new_src (src.bitwidth (), src.offsets (), dst.tiling ());
4414
4380
xla::Array<Value> src_tiles_retiled (
4415
4381
new_src.tileArrayShape (vty.getShape (), target_shape));
4416
- VectorType vreg_i32 =
4417
- getNativeVregType (builder.getI32Type (), target_shape).value ();
4382
+ int vty_packing = dst.packing ();
4383
+ VectorType vreg_x32 =
4384
+ vty.getElementType ().isSignlessInteger ()
4385
+ ? VectorType::get (target_shape, builder.getI32Type ())
4386
+ : VectorType::get (target_shape, builder.getF32Type ());
4418
4387
src_tiles_retiled.Each ([&](absl::Span<const int64_t > idx, Value *tile) {
4419
- const int vreg_part = idx.back () % 4 ;
4420
- std::array<Value, 4 > parts;
4388
+ const int vreg_part = idx.back () % vty_packing;
4389
+ SmallVector<Value, 8 > parts;
4390
+ parts.reserve (vty_packing);
4421
4391
SmallVector<int64_t > src_idx (idx.begin (), idx.end ());
4422
- src_idx[src_idx.size () - 2 ] *= 4 ;
4423
- src_idx[src_idx.size () - 1 ] /= 4 ;
4424
- for (int i = 0 ; i < 4 ; ++i) {
4425
- parts[i] = builder.create <tpu::UnpackSubelementsOp>(
4426
- v.getLoc (), vreg_i32 , src_tiles (src_idx), vreg_part);
4392
+ src_idx[src_idx.size () - 2 ] *= vty_packing ;
4393
+ src_idx[src_idx.size () - 1 ] /= vty_packing ;
4394
+ for (int i = 0 ; i < vty_packing ; ++i) {
4395
+ parts. push_back ( builder.create <tpu::UnpackSubelementsOp>(
4396
+ v.getLoc (), vreg_x32 , src_tiles (src_idx), vreg_part) );
4427
4397
if (src_idx[src_idx.size () - 2 ] <
4428
4398
src_tiles.dim (src_tiles.num_dimensions () - 2 ) - 1 ) {
4429
4399
++src_idx[src_idx.size () - 2 ];
0 commit comments