@@ -395,15 +395,32 @@ struct LinearizeVectorShuffle final
395
395
}
396
396
};
397
397
398
- // / This pattern converts the ExtractOp to a ShuffleOp that works on a
399
- // / linearized vector.
400
- // / Following,
401
- // / vector.extract %source [ position ]
402
- // / is converted to :
403
- // / %source_1d = vector.shape_cast %source
404
- // / %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
405
- // / %out_nd = vector.shape_cast %out_1d
406
- // / `shuffle_indices_1d` is computed using the position of the original extract.
398
+ // / This pattern linearizes `vector.extract` operations. It generates a 1-D
399
+ // / version of the `vector.extract` operation when extracting a scalar from a
400
+ // / vector. It generates a 1-D `vector.shuffle` operation when extracting a
401
+ // / subvector from a larger vector.
402
+ // /
403
+ // / Example #1:
404
+ // /
405
+ // / %0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
406
+ // /
407
+ // / is converted to:
408
+ // /
409
+ // / %0 = vector.shape_cast %arg0 : vector<2x8x2xf32> to vector<32xf32>
410
+ // / %1 = vector.shuffle %0, %0 [16, 17, 18, 19, 20, 21, 22, 23,
411
+ // / 24, 25, 26, 27, 28, 29, 30, 31] :
412
+ // / vector<32xf32>, vector<32xf32>
413
+ // / %2 = vector.shape_cast %1 : vector<16xf32> to vector<8x2xf32>
414
+ // /
415
+ // / Example #2:
416
+ // /
417
+ // / %0 = vector.extract %arg0[1, 2] : i32 from vector<2x4xi32>
418
+ // /
419
+ // / is converted to:
420
+ // /
421
+ // / %0 = vector.shape_cast %arg0 : vector<2x4xi32> to vector<8xi32>
422
+ // / %1 = vector.extract %0[6] : i32 from vector<8xi32>
423
+ // /
407
424
struct LinearizeVectorExtract final
408
425
: public OpConversionPattern<vector::ExtractOp> {
409
426
using OpConversionPattern::OpConversionPattern;
@@ -413,10 +430,6 @@ struct LinearizeVectorExtract final
413
430
LogicalResult
414
431
matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
415
432
ConversionPatternRewriter &rewriter) const override {
416
- // Skip if result is not a vector type
417
- if (!isa<VectorType>(extractOp.getType ()))
418
- return rewriter.notifyMatchFailure (extractOp,
419
- " scalar extract not supported" );
420
433
Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
421
434
assert (dstTy && " expected 1-D vector type" );
422
435
@@ -436,10 +449,21 @@ struct LinearizeVectorExtract final
436
449
linearizedOffset += offsets[i] * size;
437
450
}
438
451
452
+ Value srcVector = adaptor.getVector ();
453
+ if (!isa<VectorType>(extractOp.getType ())) {
454
+ // Scalar case: generate a 1-D extract.
455
+ Value result = rewriter.createOrFold <vector::ExtractOp>(
456
+ extractOp.getLoc (), srcVector, linearizedOffset);
457
+ rewriter.replaceOp (extractOp, result);
458
+ return success ();
459
+ }
460
+
461
+ // Vector case: generate a shuffle.
462
+
439
463
llvm::SmallVector<int64_t , 2 > indices (size);
440
464
std::iota (indices.begin (), indices.end (), linearizedOffset);
441
- rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
442
- extractOp, dstTy, adaptor. getVector (), adaptor. getVector () , indices);
465
+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(extractOp, dstTy, srcVector,
466
+ srcVector , indices);
443
467
444
468
return success ();
445
469
}
0 commit comments