Skip to content

[mlir][vector] Folder: shape_cast(extract) -> extract #146368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 62 additions & 64 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1696,59 +1696,71 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}

/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
/// 1s (i.e. add leading unit dimensions), are considered 'broadcastlike'.
/// Returns true if `op` is such an operation, false otherwise.
static bool isBroadcastLike(Operation *op) {
  // Broadcast and splat are broadcast-like by definition.
  if (isa<BroadcastOp, SplatOp>(op))
    return true;

  auto shapeCast = dyn_cast<ShapeCastOp>(op);
  if (!shapeCast)
    return false;

  // Check that the shape_cast just prepends 1s, like (2,3) -> (1,1,2,3).
  // Condition 1: dst rank is greater than or equal to src rank.
  // Condition 2: src shape is a suffix of dst shape.
  VectorType srcType = shapeCast.getSourceVectorType();
  ArrayRef<int64_t> srcShape = srcType.getShape();
  uint64_t srcRank = srcType.getRank();
  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
}

/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))

Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
return Value();

Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
return source;
auto getRank = [](Type type) {
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
: 0;
};
Value src = broadcastLikeOp->getOperand(0);

// Replace extract(broadcast(X)) with X
if (extractOp.getType() == src.getType())
return src;

// If splat or broadcast from a scalar, just return the source scalar.
unsigned broadcastSrcRank = getRank(source.getType());
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
return source;
// Get required types and ranks in the chain
// src -> broadcastDst -> dst
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess ExtractOp is missing somewhere here?

auto srcType = llvm::dyn_cast<VectorType>(src.getType());
auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
unsigned srcRank = srcType ? srcType.getRank() : 0;
unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
unsigned dstRank = dstType ? dstType.getRank() : 0;

unsigned extractResultRank = getRank(extractOp.getType());
if (extractResultRank > broadcastSrcRank)
// Cannot do without the broadcast if overall the rank increases.
if (dstRank > srcRank)
return Value();
// Check that the dimension of the result haven't been broadcasted.
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
if (extractVecType && broadcastVecType &&
extractVecType.getShape() !=
broadcastVecType.getShape().take_back(extractResultRank))

assert(srcType && "src must be a vector type because of previous checks");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, this doesn't quite agree with:

  unsigned srcRank = srcType ? srcType.getRank() : 0;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the intervening rank check which allows this assertion. Code is like

 if (extractOp.getType() == src.getType())
    return src;
[...]
  unsigned srcRank = srcType ? srcType.getRank() : 0;
[...]
 if (dstRank > srcRank)
    return Value();
[...]
 assert(srcType && "src must be a vector type because of previous checks");

Suppose src is scalar at the point of assertion.
Then srcRank is 0, so dstRank is 0.
If dstRank is 0, then dst is scalar.
If they're both scalar, we would have returned early (same types).
Contradiction -- src is not scalar.

TBH this reasoning is probably too complicated, and could be replaced with an `if (...) return Value()`

if (!srcType) return Value();


ArrayRef<int64_t> srcShape = srcType.getShape();
if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment explaining what case this is?

return Value();

auto broadcastOp = cast<vector::BroadcastOp>(defOp);
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
// Replace extract(broadcast(X)) with extract(X).
// First, determine the new extraction position.
unsigned deltaOverall = srcRank - dstRank;
unsigned deltaBroadcast = broadcastDstRank - srcRank;

// Detect all the positions that come from "dim-1" broadcasting.
// These dimensions correspond to "dim-1" broadcasted dims; set the matching
// extract position to `0` when extracting from the source operand.
llvm::SetVector<int64_t> broadcastedUnitDims =
broadcastOp.computeBroadcastedUnitDims();
SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
OpBuilder b(extractOp.getContext());
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
if (broadcastedUnitDims.contains(i))
extractPos[i] = b.getIndexAttr(0);
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
// matching extract position when extracting from the source operand.
int64_t rankDiff = broadcastSrcRank - extractResultRank;
extractPos.erase(extractPos.begin(),
std::next(extractPos.begin(), extractPos.size() - rankDiff));
// OpBuilder is only used as a helper to build an I64ArrayAttr.
auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
SmallVector<OpFoldResult> newPositions(deltaOverall);
IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
}
auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
extractOp->setOperands(
llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
extractOp.setStaticPosition(staticPos);
return extractOp.getResult();
}
Expand Down Expand Up @@ -2193,32 +2205,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {

LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return failure();

Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
return failure();
auto getRank = [](Type type) {
return llvm::isa<VectorType>(type)
? llvm::cast<VectorType>(type).getRank()
: 0;
};
unsigned broadcastSrcRank = getRank(source.getType());
unsigned extractResultRank = getRank(extractOp.getType());
// We only consider the case where the rank of the source is less than or
// equal to the rank of the extract dst. The other cases are handled in the
// folding patterns.
if (extractResultRank < broadcastSrcRank)
return failure();
// For scalar result, the input can only be a rank-0 vector, which will
// be handled by the folder.
if (extractResultRank == 0)

Value source = broadcastLikeOp->getOperand(0);
if (isBroadcastableTo(source.getType(), outType) !=
BroadcastableToResult::Success)
return failure();

rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
return success();
}
};
Expand Down
46 changes: 42 additions & 4 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -764,17 +764,27 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32

// -----

// CHECK-LABEL: fold_extract_splat
// CHECK-LABEL: fold_extract_scalar_from_splat
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.splat %a : vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
}

// -----

// Extracting a sub-vector from a splat canonicalizes to a direct broadcast
// of the splatted scalar to the extracted vector type.
// CHECK-LABEL: fold_extract_vector_from_splat
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.splat %a : vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}

// -----

// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
Expand Down Expand Up @@ -804,6 +814,21 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,

// -----

// Test where the shape_cast is broadcast-like.
// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
// CHECK: return %[[B]] : vector<4xf32>
func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}

// -----

// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
Expand Down Expand Up @@ -831,6 +856,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde

// -----

// The shape_cast only prepends unit dims, so it is broadcast-like: the
// extract is rewritten as a broadcast from the original (pre-cast) source.
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
// CHECK: return %[[R]] : vector<1x1xf32>
func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
-> vector<1x1xf32> {
%s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
%r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
return %r : vector<1x1xf32>
}

// -----

// CHECK-LABEL: @fold_extract_shuffle
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
// CHECK-NOT: vector.shuffle
Expand Down Expand Up @@ -1549,7 +1587,7 @@ func.func @negative_store_to_load_tensor_memref(
%arg0 : tensor<?x?xf32>,
%arg1 : memref<?x?xf32>,
%v0 : vector<4x2xf32>
) -> vector<4x2xf32>
) -> vector<4x2xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
Expand Down Expand Up @@ -1606,7 +1644,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
// CHECK: vector.transfer_read
func.func @negative_store_to_load_tensor_broadcast_masked(
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
-> vector<4x2x6xf32>
-> vector<4x2x6xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
Expand Down