[MLIR] [Vector] Linearization patterns for vector.load and vector.store #145115
Conversation
@newling following up on #143420 (comment). For 2,
Yeah, this is what I meant 👍 Ideally we'll eventually have something like the flattening of transfer_read done here, i.e. linearize even when there is more than one dimension of size > 1, and it needn't be the innermost dim. But I guess that can wait. FWIW, IMO that transfer_read code should be in VectorLinearize too; I mentioned that at the bottom of this comment. The vector.load linearization code could probably then reuse some of it. Something for the future, maybe!
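For illustration, that more general flattening (hypothetical — this PR does not implement it; the sketch mirrors what the existing transfer_read flattening produces for a contiguous memref) might look like:

// Before: both dims > 1, so the patterns in this PR do not apply.
%0 = vector.load %arg0[%c0, %c0] : memref<2x4xf32>, vector<2x4xf32>

// After: collapse the contiguous memref, load 1-D, and cast back.
%m = memref.collapse_shape %arg0 [[0, 1]] : memref<2x4xf32> into memref<8xf32>
%l = vector.load %m[%c0] : memref<8xf32>, vector<8xf32>
%0 = vector.shape_cast %l : vector<8xf32> to vector<2x4xf32>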
@newling can you review it as well?
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/145115.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..f0b77da5acd02 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
}
};
+/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
+/// It currently supports only linearization of <1xN> to <N>.
+/// For example:
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+/// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vecTy = loadOp.getType();
+ if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+ vecTy.isScalable());
+ auto newLoad = rewriter.create<vector::LoadOp>(
+ loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+ auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+ loadOp.getLoc(), vecTy, newLoad.getResult());
+ rewriter.replaceOp(loadOp, shapeCast.getResult());
+ return success();
+ }
+};
+
+/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
+/// It currently supports only linearization of <1xN> to <N>.
+/// For example:
+/// vector.store %arg0, %arg1[%c0, %c0]
+/// : vector<1x4xf32>, memref<1x4xf32>
+/// is converted to:
+/// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// vector.store %arg0, %arg1[%c0, %c0]
+/// : vector<4xf32>, memref<1x4xf32>
+struct LinearizeVectorStore final
+ : public OpConversionPattern<vector::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType vecTy = storeOp.getValueToStore().getType();
+ if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+ return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
+ auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+ vecTy.isScalable());
+
+ Value valueToStore = adaptor.getValueToStore();
+ if (valueToStore.getType() != linearTy) {
+ valueToStore = rewriter.create<vector::ShapeCastOp>(
+ storeOp.getLoc(), linearTy, valueToStore);
+ }
+
+ rewriter.replaceOpWithNewOp<vector::StoreOp>(
+ storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+ return success();
+ }
+};
+
} // namespace
/// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
- LinearizeVectorSplat, LinearizeVectorCreateMask>(
- typeConverter, patterns.getContext());
+ LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+ LinearizeVectorStore>(typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..fa0436792d3f0 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
%0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
return %0 : vector<1x[16]xi1>
}
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+ // CHECK: return %[[CAST]] : vector<1x4xf32>
+ %c0 = arith.constant 0 : index
+ %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+ // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+ %c0 = arith.constant 0 : index
+ vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+ return
+}
Thanks! I added 2 small comments about simplifying and generalizing a bit.
Addressed the feedback, thanks :)
What should happen if one of the outer dims is != 1? This is a general design question. Could you add some negative tests? Also, this should work for scalable vectors provided that there is only one scalable dim. Would you mind taking care of that as well? Thanks!
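For reference, a scalable-vector test along those lines (a hypothetical sketch, not taken from the PR) could mirror the existing tests; a load with an outer dim != 1 would exercise the notifyMatchFailure path, which under the conversion driver currently surfaces as a legalization error (see below):

// CHECK-LABEL: linearize_scalable_vector_load
// CHECK: vector.load {{.*}} : memref<2x8xf32>, vector<[4]xf32>
// CHECK: vector.shape_cast {{.*}} : vector<[4]xf32> to vector<1x[4]xf32>
func.func @linearize_scalable_vector_load(%arg0: memref<2x8xf32>) -> vector<1x[4]xf32> {
  %c0 = arith.constant 0 : index
  %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
  return %0 : vector<1x[4]xf32>
}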
Hey Andrzej, I was on vacation for the past two weeks, hence the delay in responding.
Thanks.
That is the expected error. It's difficult to provide good error diagnostics with legalization, but note that I am currently working on a major overhaul to use simple rewrite patterns (#146030). I'm quite close to posting a PR, but please don't wait for it.
Hey, could we add more context about what we are trying to do here, preferably to the PR description? Why are only unit dims supported? What is the long-term plan for supporting more generic cases? What are vector linearization's responsibilities vs. applying drop-unit-dims or other transformations? I think elaborating on this aspect of the design would be helpful to understand where this is going and to set users' expectations. Thanks!
The motivation for this was to have rank-1 loads/stores. Initially there was a more generic version (PR), but the feedback was to add these to unrolling (#143420) so that users can unroll to some form of <1x..xN> and then linearize to 1-D using these patterns. We can support the generic cases in a follow-up PR, I think.
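As a sketch of that two-step pipeline (hypothetical IR — %base and the per-row index %i stand in for whatever the unrolling configuration actually produces):

// Original rank-2 load; not directly handled by these patterns:
%0 = vector.load %base[%c0, %c0] : memref<4x8xf32>, vector<4x8xf32>

// Step 1: unrolling (#143420) with native shape <1x8> produces, per row %i:
%row = vector.load %base[%i, %c0] : memref<4x8xf32>, vector<1x8xf32>

// Step 2: the patterns in this PR flatten each unit-outer-dim piece:
%flat = vector.load %base[%i, %c0] : memref<4x8xf32>, vector<8xf32>
%row2 = vector.shape_cast %flat : vector<8xf32> to vector<1x8xf32>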
Yes, that's what @newling proposed, and it made sense to me back then and it still does :) To make sure that we continue on this path, please add comments in the code redirecting people to unrolling for more general cases.
No need to, IMHO, unless there's a specific need.
This PR adds linearization patterns for vector.load and vector.store. It is a follow-up to #143420 (comment).
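For context, this is roughly how a downstream pass might pull these patterns in (a minimal sketch, not part of this PR: the driver boilerplate is standard MLIR conversion usage, the include path for the populate* declaration is assumed, and the type converter and target are assumed to be configured for linearization elsewhere):

#include "mlir/Transforms/DialectConversion.h"
// Assumed header for populateVectorLinearizeBasePatterns; the exact path may
// differ per tree.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"

using namespace mlir;

static LogicalResult linearizeVectorOps(Operation *root,
                                        const TypeConverter &typeConverter,
                                        ConversionTarget &target) {
  RewritePatternSet patterns(root->getContext());
  // Registers LinearizeVectorLoad / LinearizeVectorStore from this PR along
  // with the other base linearization patterns.
  vector::populateVectorLinearizeBasePatterns(typeConverter, patterns);
  return applyPartialConversion(root, target, std::move(patterns));
}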