@@ -445,16 +445,36 @@ struct LinearizeVectorExtract final
445
445
}
446
446
};
447
447
448
- // / This pattern converts the InsertOp to a ShuffleOp that works on a
449
- // / linearized vector.
450
- // / Following,
451
- // / vector.insert %source %destination [ position ]
452
- // / is converted to :
453
- // / %source_1d = vector.shape_cast %source
454
- // / %destination_1d = vector.shape_cast %destination
455
- // / %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
456
- // / ] %out_nd = vector.shape_cast %out_1d
457
- // / `shuffle_indices_1d` is computed using the position of the original insert.
448
+ // / This pattern linearizes `vector.insert` operations. It generates a 1-D
449
+ // / version of the `vector.insert` operation when inserting a scalar into a
450
+ // / vector. It generates a 1-D `vector.shuffle` operation when inserting a
451
+ // / vector into another vector.
452
+ // /
453
+ // / Example #1:
454
+ // /
455
+ // / %0 = vector.insert %source, %destination[0] :
456
+ // / vector<2x4xf32> into vector<2x2x4xf32>
457
+ // /
458
+ // / is converted to:
459
+ // /
460
+ // / %0 = vector.shape_cast %source : vector<2x4xf32> to vector<8xf32>
461
+ // / %1 = vector.shape_cast %destination :
462
+ // / vector<2x2x4xf32> to vector<16xf32>
463
+ // / %2 = vector.shuffle %1, %0 [16, 17, 18, 19, 20, 21, 22, 23,
464
+ // / 8, 9, 10, 11, 12, 13, 14, 15] :
465
+ // / vector<16xf32>, vector<8xf32>
466
+ // / %3 = vector.shape_cast %2 : vector<16xf32> to vector<2x2x4xf32>
467
+ // /
468
+ // / Example #2:
469
+ // /
470
+ // / %0 = vector.insert %source, %destination[1, 2]: f32 into vector<2x4xf32>
471
+ // /
472
+ // / is converted to:
473
+ // /
474
+ // / %0 = vector.shape_cast %destination : vector<2x4xf32> to vector<8xf32>
475
+ // / %1 = vector.insert %source, %0[6]: f32 into vector<8xf32>
476
+ // / %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32>
477
+ // /
458
478
struct LinearizeVectorInsert final
459
479
: public OpConversionPattern<vector::InsertOp> {
460
480
using OpConversionPattern::OpConversionPattern;
@@ -468,48 +488,55 @@ struct LinearizeVectorInsert final
468
488
insertOp.getDestVectorType ());
469
489
assert (dstTy && " vector type destination expected." );
470
490
471
- // dynamic position is not supported
491
+ // Dynamic position is not supported.
472
492
if (insertOp.hasDynamicPosition ())
473
493
return rewriter.notifyMatchFailure (insertOp,
474
494
" dynamic position is not supported." );
475
495
auto srcTy = insertOp.getValueToStoreType ();
476
496
auto srcAsVec = dyn_cast<VectorType>(srcTy);
477
- uint64_t srcSize = 0 ;
478
- if (srcAsVec) {
479
- srcSize = srcAsVec.getNumElements ();
480
- } else {
481
- return rewriter.notifyMatchFailure (insertOp,
482
- " scalars are not supported." );
483
- }
497
+ uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements () : 1 ;
484
498
485
499
auto dstShape = insertOp.getDestVectorType ().getShape ();
486
500
const auto dstSize = insertOp.getDestVectorType ().getNumElements ();
487
501
auto dstSizeForOffsets = dstSize;
488
502
489
- // compute linearized offset
503
+ // Compute linearized offset.
490
504
int64_t linearizedOffset = 0 ;
491
505
auto offsetsNd = insertOp.getStaticPosition ();
492
506
for (auto [dim, offset] : llvm::enumerate (offsetsNd)) {
493
507
dstSizeForOffsets /= dstShape[dim];
494
508
linearizedOffset += offset * dstSizeForOffsets;
495
509
}
496
510
511
+ Location loc = insertOp.getLoc ();
512
+ Value valueToStore = adaptor.getValueToStore ();
513
+
514
+ if (!isa<VectorType>(valueToStore.getType ())) {
515
+ // Scalar case: generate a 1-D insert.
516
+ Value result = rewriter.createOrFold <vector::InsertOp>(
517
+ loc, valueToStore, adaptor.getDest (), linearizedOffset);
518
+ rewriter.replaceOp (insertOp, result);
519
+ return success ();
520
+ }
521
+
522
+ // Vector case: generate a shuffle.
497
523
llvm::SmallVector<int64_t , 2 > indices (dstSize);
498
524
auto *origValsUntil = indices.begin ();
499
525
std::advance (origValsUntil, linearizedOffset);
500
- std::iota (indices.begin (), origValsUntil,
501
- 0 ); // original values that remain [0, offset)
526
+
527
+ // Original values that remain [0, offset).
528
+ std::iota (indices.begin (), origValsUntil, 0 );
502
529
auto *newValsUntil = origValsUntil;
503
530
std::advance (newValsUntil, srcSize);
504
- std::iota (origValsUntil, newValsUntil,
505
- dstSize); // new values [offset, offset+srcNumElements)
506
- std::iota (newValsUntil, indices.end (),
507
- linearizedOffset + srcSize); // the rest of original values
508
- // [offset+srcNumElements, end)
531
+ // New values [offset, offset+srcNumElements).
532
+ std::iota (origValsUntil, newValsUntil, dstSize);
533
+ // The rest of original values [offset+srcNumElements, end).
534
+ std::iota (newValsUntil, indices.end (), linearizedOffset + srcSize);
509
535
510
- rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
511
- insertOp , dstTy, adaptor.getDest (), adaptor. getValueToStore () , indices);
536
+ Value result = rewriter.createOrFold <vector::ShuffleOp>(
537
+ loc , dstTy, adaptor.getDest (), valueToStore , indices);
512
538
539
+ rewriter.replaceOp (insertOp, result);
513
540
return success ();
514
541
}
515
542
};
0 commit comments