@@ -395,15 +395,32 @@ struct LinearizeVectorShuffle final
395
395
}
396
396
};
397
397
398
- // / This pattern converts the ExtractOp to a ShuffleOp that works on a
399
- // / linearized vector.
400
- // / Following,
401
- // / vector.extract %source [ position ]
402
- // / is converted to :
403
- // / %source_1d = vector.shape_cast %source
404
- // / %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
405
- // / %out_nd = vector.shape_cast %out_1d
406
- // / `shuffle_indices_1d` is computed using the position of the original extract.
398
+ // / This pattern linearizes `vector.extract` operations. It generates a 1-D
399
+ // / version of the `vector.extract` operation when extracting a scalar from a
400
+ // / vector. It generates a 1-D `vector.shuffle` operation when extracting a
401
+ // / subvector from a larger vector.
402
+ // /
403
+ // / Example #1:
404
+ // /
405
+ // / %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
406
+ // /
407
+ // / is converted to:
408
+ // /
409
+ // / %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
410
+ // / %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
411
+ // / 24, 25, 26, 27, 28, 29, 30, 31] :
412
+ // / vector<32xf32>, vector<32xf32>
413
+ // / %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
414
+ // /
415
+ // / Example #2:
416
+ // /
417
+ // / %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
418
+ // /
419
+ // / is converted to:
420
+ // /
421
+ // / %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
422
+ // / %1 = vector.extract %0[6] : i32 from vector<8xi32>
423
+ // /
407
424
struct LinearizeVectorExtract final
408
425
: public OpConversionPattern<vector::ExtractOp> {
409
426
using OpConversionPattern::OpConversionPattern;
@@ -413,10 +430,6 @@ struct LinearizeVectorExtract final
413
430
LogicalResult
414
431
matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
415
432
ConversionPatternRewriter &rewriter) const override {
416
- // Skip if result is not a vector type
417
- if (!isa<VectorType>(extractOp.getType ()))
418
- return rewriter.notifyMatchFailure (extractOp,
419
- " scalar extract not supported" );
420
433
Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
421
434
assert (dstTy && " expected 1-D vector type" );
422
435
@@ -436,6 +449,16 @@ struct LinearizeVectorExtract final
436
449
linearizedOffset += offsets[i] * size;
437
450
}
438
451
452
+ if (!isa<VectorType>(extractOp.getType ())) {
453
+ // Scalar case: generate a 1-D extract.
454
+ Value result = rewriter.createOrFold <vector::ExtractOp>(
455
+ extractOp.getLoc (), adaptor.getVector (), linearizedOffset);
456
+ rewriter.replaceOp (extractOp, result);
457
+ return success ();
458
+ }
459
+
460
+ // Vector case: generate a shuffle.
461
+
439
462
llvm::SmallVector<int64_t , 2 > indices (size);
440
463
std::iota (indices.begin (), indices.end (), linearizedOffset);
441
464
rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
0 commit comments