-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[vector][mlir] Canonicalize to shape_cast where possible #140583
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
cce1483
f2e5417
24f7531
e673522
aa99292
7ad1802
1ff3399
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2351,11 +2351,41 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, | |
return success(); | ||
} | ||
|
||
/// Canonicalizes a vector.extract whose result holds every element of its | ||
/// source vector into a vector.shape_cast. | ||
/// BEFORE: | ||
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> | ||
/// AFTER: | ||
/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> | ||
struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
LogicalResult matchAndRewrite(vector::ExtractOp extractOp, | ||
PatternRewriter &rewriter) const override { | ||
VectorType sourceType = extractOp.getSourceVectorType(); | ||
// A result that is not a vector (e.g. a scalar) is not handled here. | ||
VectorType outType = dyn_cast<VectorType>(extractOp.getType()); | ||
if (!outType) | ||
return failure(); | ||
|
||
// Only match when every index in `position` is the constant 0. Dynamic and | ||
// non-zero indices are rejected, as are negative ones: negative values in | ||
// `position` indicate poison, which cannot be represented with a | ||
// shape_cast. | ||
if (llvm::any_of(extractOp.getMixedPosition(), | ||
[](OpFoldResult v) { return !isConstantIntValue(v, 0); })) | ||
return failure(); | ||
|
||
// Source and result must hold the same number of elements; otherwise the | ||
// extract drops data and is not a pure reshape. | ||
if (sourceType.getNumElements() != outType.getNumElements()) | ||
return failure(); | ||
|
||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType, | ||
extractOp.getVector()); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, | ||
MLIRContext *context) { | ||
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context); | ||
results | ||
.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>( | ||
context); | ||
results.add(foldExtractFromShapeCastToShapeCast); | ||
results.add(foldExtractFromFromElements); | ||
} | ||
|
@@ -2867,13 +2897,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { | |
return success(); | ||
} | ||
}; | ||
|
||
/// Canonicalizes a vector.broadcast that does not change the number of | ||
/// elements into a vector.shape_cast. | ||
/// BEFORE: | ||
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> | ||
/// AFTER: | ||
/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8> | ||
struct BroadcastToShapeCast final | ||
: public OpRewritePattern<vector::BroadcastOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
LogicalResult matchAndRewrite(vector::BroadcastOp broadcast, | ||
PatternRewriter &rewriter) const override { | ||
auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType()); | ||
if (!sourceType) { | ||
return rewriter.notifyMatchFailure( | ||
broadcast, "source is a scalar, shape_cast doesn't support scalar"); | ||
} | ||
|
||
VectorType outType = broadcast.getType(); | ||
// If the element counts differ the broadcast duplicates data, so it is | ||
// not a pure reshape and must stay a broadcast. | ||
if (sourceType.getNumElements() != outType.getNumElements()) | ||
return failure(); | ||
|
||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType, | ||
broadcast.getSource()); | ||
return success(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, | ||
MLIRContext *context) { | ||
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by | ||
// calling `populateCastAwayVectorLeadingOneDimPatterns` | ||
results.add<BroadcastFolder>(context); | ||
results.add<BroadcastFolder, BroadcastToShapeCast>(context); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
|
@@ -5991,10 +6044,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final | |
} | ||
}; | ||
|
||
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either | ||
/// i) Y = ShapeCast(X), or | ||
/// ii) Y = Broadcast(X) | ||
/// If both (i) and (ii) are possible, (i) is chosen. | ||
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X) | ||
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
|
@@ -6009,22 +6059,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> { | |
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType()); | ||
bool srcIsScalar = !srcVectorType; | ||
|
||
// Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X). | ||
// Example: | ||
// %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32> | ||
// %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32> | ||
// to | ||
// %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32> | ||
if (srcVectorType) { | ||
if (srcVectorType.getNumElements() == | ||
shapeCastOp.getResultVectorType().getNumElements()) { | ||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( | ||
shapeCastOp, shapeCastOp.getResultVectorType(), | ||
broadcastOp.getSource()); | ||
return success(); | ||
} | ||
} | ||
|
||
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X) | ||
// Example | ||
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32> | ||
|
@@ -6233,7 +6267,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { | |
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8> | ||
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8> | ||
// | ||
// Example of what NOT to fold: | ||
// Example of what not to fold: | ||
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8> | ||
// | ||
if (getSourceVectorType() == getResultVectorType() && | ||
|
@@ -6359,32 +6393,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> { | |
} | ||
}; | ||
|
||
/// Folds transpose(shape_cast) into a new shape_cast. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Author note: I've removed this, as it now happens in 2 steps during canonicalization. The first (new) step is to rewrite the transpose as a shape_cast. The second step is to fold shape_cast(shape_cast) to shape_cast. |
||
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(TransposeOp transposeOp, | ||
PatternRewriter &rewriter) const override { | ||
// Only applies when the transposed value is produced by a shape_cast. | ||
auto shapeCastOp = | ||
transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>(); | ||
if (!shapeCastOp) | ||
return failure(); | ||
// The transpose must be order preserving to be merged into the | ||
// shape_cast. | ||
if (!isOrderPreserving(transposeOp)) | ||
return failure(); | ||
|
||
VectorType resultType = transposeOp.getType(); | ||
|
||
// We don't need to check isValidShapeCast at this point, because it is | ||
// guaranteed that merging the transpose into the shape_cast is a valid | ||
// shape_cast, because the transpose just inserts/removes ones. | ||
|
||
// Cast directly from the shape_cast's source to the transpose's result | ||
// type. | ||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType, | ||
shapeCastOp.getSource()); | ||
return success(); | ||
} | ||
}; | ||
|
||
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is | ||
/// 'order preserving', where 'order preserving' means the flattened | ||
/// inputs and outputs of the transpose have identical (numerical) values. | ||
|
@@ -6480,12 +6488,35 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> { | |
} | ||
}; | ||
|
||
/// Canonicalizes an order-preserving vector.transpose into a | ||
/// vector.shape_cast. | ||
/// BEFORE: | ||
/// %0 = vector.transpose %arg0, [0, 2, 1] : | ||
/// vector<2x1x2xf32> to vector<2x2x1xf32> | ||
/// AFTER: | ||
/// %0 = vector.shape_cast %arg0 : | ||
/// vector<2x1x2xf32> to vector<2x2x1xf32> | ||
struct TransposeToShapeCast final | ||
: public OpRewritePattern<vector::TransposeOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
LogicalResult matchAndRewrite(vector::TransposeOp transpose, | ||
PatternRewriter &rewriter) const override { | ||
|
||
// Only a transpose accepted by `isOrderPreserving` is rewritten; any | ||
// other transpose is left unchanged. | ||
if (!isOrderPreserving(transpose)) { | ||
return rewriter.notifyMatchFailure( | ||
transpose, "not order preserving, so not semantically a 'copy'"); | ||
} | ||
// Same operand and same result type; only the op kind changes. | ||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( | ||
transpose, transpose.getType(), transpose.getVector()); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void vector::TransposeOp::getCanonicalizationPatterns( | ||
RewritePatternSet &results, MLIRContext *context) { | ||
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder, | ||
FoldTransposeSplat, FoldTransposeBroadcast>(context); | ||
results.add<FoldTransposeBroadcast, FoldTransposeCreateMask, | ||
FoldTransposeSplat, TransposeFolder, TransposeToShapeCast>( | ||
context); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,6 @@ | |
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/UB/IR/UBOps.h" | ||
#include "mlir/Dialect/Utils/IndexingUtils.h" | ||
|
@@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { | |
vector::VectorTransposeLowering vectorTransposeLowering; | ||
}; | ||
|
||
/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Author note: I've removed this pattern, as it is a special case of TransposeToShapeCast |
||
/// to 2D vectors with at least one unit dim. For example: | ||
/// | ||
/// Replace: | ||
/// vector.transpose %0, [1, 0] : vector<4x1xi32> to | ||
/// vector<1x4xi32> | ||
/// with: | ||
/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32> | ||
/// | ||
/// Source with leading unit dim (inverse) is also replaced. Unit dim must | ||
/// be fixed. Non-unit dim can be scalable. | ||
/// | ||
/// TODO: This pattern was introduced specifically to help lower scalable | ||
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's | ||
/// to cancel out) would be preferable: | ||
/// | ||
/// BEFORE: | ||
/// %0 = some_op | ||
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32> | ||
/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> | ||
/// AFTER: | ||
/// %0 = some_op | ||
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32> | ||
/// | ||
/// Given the context above, we may want to consider (re-)moving this pattern | ||
/// at some later time. I am leaving it for now in case there are other users | ||
/// that I am not aware of. | ||
class Transpose2DWithUnitDimToShapeCast | ||
: public OpRewritePattern<vector::TransposeOp> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
// Lets the caller choose the pattern benefit (defaults to 1). | ||
Transpose2DWithUnitDimToShapeCast(MLIRContext *context, | ||
PatternBenefit benefit = 1) | ||
: OpRewritePattern<vector::TransposeOp>(context, benefit) {} | ||
|
||
LogicalResult matchAndRewrite(vector::TransposeOp op, | ||
PatternRewriter &rewriter) const override { | ||
Value input = op.getVector(); | ||
VectorType resType = op.getResultVectorType(); | ||
|
||
// Set up convenience transposition table. | ||
ArrayRef<int64_t> transp = op.getPermutation(); | ||
|
||
// Match only a rank-2 result produced by the [1, 0] permutation whose | ||
// first or last dim is a fixed (non-scalable) unit dim. | ||
if (resType.getRank() == 2 && | ||
((resType.getShape().front() == 1 && | ||
!resType.getScalableDims().front()) || | ||
(resType.getShape().back() == 1 && | ||
!resType.getScalableDims().back())) && | ||
transp == ArrayRef<int64_t>({1, 0})) { | ||
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input); | ||
return success(); | ||
} | ||
|
||
// Any other transpose is left untouched by this pattern. | ||
return failure(); | ||
} | ||
}; | ||
|
||
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops. | ||
/// If the strategy is Shuffle1D, it will be lowered to: | ||
/// vector.shape_cast 2D -> 1D | ||
|
@@ -511,8 +452,8 @@ class TransposeOp2DToShuffleLowering | |
void mlir::vector::populateVectorTransposeLoweringPatterns( | ||
RewritePatternSet &patterns, | ||
VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) { | ||
patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(), | ||
benefit); | ||
TransposeOp::getCanonicalizationPatterns(patterns, patterns.getContext()); | ||
ShapeCastOp::getCanonicalizationPatterns(patterns, patterns.getContext()); | ||
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>( | ||
vectorTransposeLowering, patterns.getContext(), benefit); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i | |
|
||
// ----- | ||
|
||
// The pass should do nothing (and not crash). | ||
// CHECK-LABEL: @illegal_transpose_no_defining_source_op | ||
func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> | ||
// CHECK-LABEL: @transpose_no_defining_source_op | ||
func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> | ||
{ | ||
// CHECK: vector.transpose | ||
// CHECK: vector.shape_cast | ||
// CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @banach-space I'm getting back to this PR. Peephole question: is this operation ok? i.e. is
an acceptable operation to have after running There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general, yes. But I can't guarantee there's no logic that expects |
||
%0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> | ||
return %0 : vector<1x[4]xf32> | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>, | |
// ----- | ||
|
||
// CHECK-LABEL: transpose_3D_identity | ||
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>) | ||
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>) | ||
// CHECK-NEXT: return [[ARG]] | ||
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> { | ||
// CHECK-NOT: transpose | ||
%0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32> | ||
// CHECK-NEXT: return [[ARG]] | ||
return %0 : vector<4x3x2xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: transpose_0D_identity | ||
// CHECK-SAME: ([[ARG:%.*]]: vector<i8>) | ||
// CHECK-NEXT: return [[ARG]] | ||
func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> { | ||
%0 = vector.transpose %arg, [] : vector<i8> to vector<i8> | ||
return %0 : vector<i8> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: transpose_2D_sequence | ||
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>) | ||
func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> { | ||
|
@@ -753,12 +762,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>, | |
|
||
// ----- | ||
|
||
|
||
// CHECK-LABEL: negative_fold_extract_broadcast | ||
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32> | ||
// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32> | ||
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32> | ||
// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32> | ||
func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> { | ||
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32> | ||
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32> | ||
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32> | ||
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32> | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Keep both tests, one with the original shape and one with the new ones? Unrelated: it looks like we are missing a canonicalization pattern here? This should be turned into a single There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Makes sense, will do.
No because you can't broadcast <1x1xf32> to <4xf32> -- broadcasts can never reduce rank in Vector. FWIW slightly related to my comment here where this would be simpler if ops didn't do implicit shape casting. In this case if it was something like
i.e. if we constrained broadcasts and extracts to be rank-retaining, then this would be canonicalized to
which, if you have faith that the shape_casts will vanish at a later point, is simpler! p.s. I plan to reply in #145740 later today |
||
return %r : vector<4xf32> | ||
} | ||
|
||
|
@@ -797,8 +807,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, | |
// rank(extract_output) < rank(broadcast_input) | ||
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>, | ||
%idx0 : index, %idx1 : index) -> vector<4xf32> { | ||
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32> | ||
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32> | ||
%b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32> | ||
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32> | ||
Comment on lines
+844
to
+845
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why change shapes? |
||
return %r : vector<4xf32> | ||
} | ||
|
||
|
@@ -1920,12 +1930,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> { | |
|
||
// ----- | ||
|
||
// CHECK-LABEL: func @insert_extract_to_broadcast | ||
// CHECK-LABEL: func @insert_extract_to_shape_cast | ||
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>) | ||
// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32> | ||
// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> | ||
// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32> | ||
// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> | ||
// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32> | ||
func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>, | ||
func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>, | ||
%arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) { | ||
%0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32> | ||
%1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32> | ||
|
@@ -2277,7 +2287,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> { | |
|
||
// CHECK-LABEL: func @shuffle_canonicalize_0d | ||
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> { | ||
// CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32> | ||
// CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32> | ||
%shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32> | ||
return %shuffle : vector<1xi32> | ||
} | ||
|
@@ -2764,9 +2774,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf | |
// CHECK-LABEL: func.func @extract_from_broadcast | ||
func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> { | ||
%0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32> | ||
|
||
// CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32> | ||
// CHECK-NEXT: return %0 : vector<1xf32> | ||
// CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32> | ||
// CHECK-NEXT: return %[[RES]] : vector<1xf32> | ||
%1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32> | ||
return %1: vector<1xf32> | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Author note: I've removed this, as now it happens in 2 steps during canonicalization. The first converts the Broadcast to a ShapeCast. The second combines the 2 ShapeCasts.