Skip to content

Commit fc760c0

Browse files
committed
[mlir][vector] Fold cancelling vector.shape_cast(vector.broadcast)
vector.broadcast can inject all size one dimensions. If it's followed by a vector.shape_cast to the original type, we can cancel the op pair, like cancelling consecutive shape_cast ops. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D124094
1 parent f693280 commit fc760c0

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4088,7 +4088,7 @@ LogicalResult ShapeCastOp::verify() {
40884088
}
40894089

40904090
OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
4091-
// Nop shape cast.
4091+
// No-op shape cast.
40924092
if (getSource().getType() == getResult().getType())
40934093
return getSource();
40944094

@@ -4113,6 +4113,13 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
41134113
setOperand(otherOp.getSource());
41144114
return getResult();
41154115
}
4116+
4117+
// Cancelling broadcast and shape cast ops.
4118+
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
4119+
if (bcastOp.getSourceType() == getType())
4120+
return bcastOp.getSource();
4121+
}
4122+
41164123
return {};
41174124
}
41184125

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,39 @@ func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf3
635635

636636
// -----
637637

638+
// CHECK-LABEL: func @fold_broadcast_shapecast
639+
// CHECK-SAME: (%[[V:.+]]: vector<4xf32>)
640+
// CHECK: return %[[V]]
641+
func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
642+
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x1x4xf32>
643+
%1 = vector.shape_cast %0 : vector<1x1x4xf32> to vector<4xf32>
644+
return %1 : vector<4xf32>
645+
}
646+
647+
// -----
648+
649+
// CHECK-LABEL: func @dont_fold_broadcast_shapecast_scalar
650+
// CHECK: vector.broadcast
651+
// CHECK: vector.shape_cast
652+
func @dont_fold_broadcast_shapecast_scalar(%arg0: f32) -> vector<1xf32> {
653+
%0 = vector.broadcast %arg0 : f32 to vector<1x1x1xf32>
654+
%1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1xf32>
655+
return %1 : vector<1xf32>
656+
}
657+
658+
// -----
659+
660+
// CHECK-LABEL: func @dont_fold_broadcast_shapecast_diff_shape
661+
// CHECK: vector.broadcast
662+
// CHECK: vector.shape_cast
663+
func @dont_fold_broadcast_shapecast_diff_shape(%arg0: vector<4xf32>) -> vector<8xf32> {
664+
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x2x4xf32>
665+
%1 = vector.shape_cast %0 : vector<1x2x4xf32> to vector<8xf32>
666+
return %1 : vector<8xf32>
667+
}
668+
669+
// -----
670+
638671
// CHECK-LABEL: fold_vector_transfers
639672
func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
640673
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)