@@ -4349,38 +4349,56 @@ FailureOr<TypedValue<VectorType>> relayout(
4349
4349
return emitError (v.getLoc (),
4350
4350
" Not implemented: Both columns and rows are shifted" );
4351
4351
}
4352
- if (col_diff < 0 ) {
4353
- return emitError (v.getLoc (), " Not implemented: Shifts to the left" );
4354
- }
4355
4352
if (bitwidth != 32 || tiling != target_shape) {
4356
4353
return emitError (v.getLoc (),
4357
4354
" Not implemented: Only 32-bit column shifts for "
4358
4355
" native layouts supported" );
4359
4356
}
4360
- const int64_t sublane_diff = col_diff;
4361
4357
TPU_ASSERT_GE_LOC (v.getLoc (), src_tiles.num_dimensions (), 1 );
4362
4358
std::optional<tpu::CreateMaskOp> maybe_create_mask;
4363
- if (src_tiles.dimensions ()[src_tiles.num_dimensions () - 1 ] > 1 ) {
4359
+ if (*(src_tiles.dimensions ().end () - 1 ) > 1 ) {
4360
+ int64_t lane_start, lane_end;
4361
+ if (col_diff > 0 ) {
4362
+ lane_start = 0 ;
4363
+ lane_end = col_diff;
4364
+ } else { // col_diff < 0
4365
+ lane_start = target_shape[1 ] + col_diff;
4366
+ lane_end = target_shape[1 ];
4367
+ }
4364
4368
auto boundIdxConst =
4365
4369
std::bind (IdxConst, std::placeholders::_1, builder, v.getLoc ());
4366
4370
maybe_create_mask = builder.create <tpu::CreateMaskOp>(
4367
4371
v.getLoc (), VectorType::get (target_shape, builder.getI1Type ()),
4368
- ValueRange{boundIdxConst (0 ), boundIdxConst (0 )},
4372
+ ValueRange{boundIdxConst (0 ), boundIdxConst (lane_start )},
4369
4373
ValueRange{boundIdxConst (target_shape[0 ]),
4370
- boundIdxConst (col_diff )});
4374
+ boundIdxConst (lane_end )});
4371
4375
}
4372
- src_tiles.Each ([&](absl::Span<const int64_t > idx, Value tile) {
4373
- Value rot_tile =
4374
- builder
4375
- .create <tpu::RotateOp>(v.getLoc (), tile,
4376
- /* amount=*/ sublane_diff,
4377
- /* dimension=*/ 1 , /* stride=*/ nullptr ,
4378
- /* stride_dimension=*/ nullptr )
4379
- .getResult ();
4380
- if (idx[idx.size () - 1 ] != 0 ) {
4381
- SmallVector<int64_t > prev_idx (idx.begin (), idx.end ());
4382
- --prev_idx[idx.size () - 1 ];
4383
- Value prev_rot_tile = dst_tiles (prev_idx);
4376
+ src_tiles.Each ([&](absl::Span<const int64_t > idx, Value *tile) {
4377
+ *tile = builder
4378
+ .create <tpu::RotateOp>(v.getLoc (), *tile,
4379
+ /* amount=*/ col_diff < 0
4380
+ ? target_shape[1 ] + col_diff
4381
+ : col_diff,
4382
+ /* dimension=*/ 1 , /* stride=*/ nullptr ,
4383
+ /* stride_dimension=*/ nullptr )
4384
+ .getResult ();
4385
+ });
4386
+ src_tiles.Each ([&](absl::Span<const int64_t > idx, Value rot_tile) {
4387
+ Value prev_rot_tile;
4388
+ if (col_diff > 0 ) {
4389
+ if (*(idx.end () - 1 ) != 0 ) {
4390
+ SmallVector<int64_t > prev_idx (idx.begin (), idx.end ());
4391
+ --*(prev_idx.end () - 1 );
4392
+ prev_rot_tile = src_tiles (prev_idx);
4393
+ }
4394
+ } else { // col_diff < 0
4395
+ if (*(idx.end () - 1 ) != *(src_tiles.dimensions ().end () - 1 ) - 1 ) {
4396
+ SmallVector<int64_t > prev_idx (idx.begin (), idx.end ());
4397
+ ++*(prev_idx.end () - 1 );
4398
+ prev_rot_tile = src_tiles (prev_idx);
4399
+ }
4400
+ }
4401
+ if (prev_rot_tile != nullptr ) {
4384
4402
rot_tile = builder.create <arith::SelectOp>(
4385
4403
v.getLoc (), maybe_create_mask->getResult (), prev_rot_tile,
4386
4404
rot_tile);
0 commit comments