
[MLIR] [Vector] Linearization patterns for vector.load and vector.store #145115


Open · wants to merge 5 commits into main

Conversation

nbpatel
Contributor

@nbpatel nbpatel commented Jun 20, 2025

This PR adds linearization patterns for vector.load and vector.store. It is a follow-up to #143420 (comment)

@nbpatel
Contributor Author

nbpatel commented Jun 20, 2025

@newling following up on #143420 (comment)

For 2, is this what you meant? I created a draft PR because that would be quicker.

@newling
Contributor

newling commented Jun 20, 2025

@newling following up on #143420 (comment)

For 2, is this what you meant? I created a draft PR because that would be quicker.

Yeah, this is what I meant 👍

Ideally we'll eventually have something like the flattening of transfer_read done here, i.e. linearize even when there is more than one dimension of size > 1, and it needn't be the innermost dim. But I guess that can wait. FWIW, IMO that transfer_read code should be in VectorLinearize too; I mentioned that at the bottom of this comment. The vector.load linearization code could probably then reuse some of it. Something for the future, maybe!
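For illustration, the more general flattening described above might look roughly like the following (a hypothetical sketch: the shapes, SSA names, and the use of memref.collapse_shape mirror the existing transfer_read flattening, and are not part of this PR):

```mlir
// Hypothetical: a rank-2 load with no unit dims, flattened through a
// collapsed memref view and a single rank-1 load.
%view = memref.collapse_shape %arg0 [[0, 1]] : memref<2x8xf32> into memref<16xf32>
%flat = vector.load %view[%c0] : memref<16xf32>, vector<16xf32>
%res  = vector.shape_cast %flat : vector<16xf32> to vector<2x8xf32>
```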

@nbpatel
Contributor Author

nbpatel commented Jun 23, 2025

@newling can you review it as well?

@llvmbot
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/145115.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+69-2)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+23)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..f0b77da5acd02 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final
   }
 };
 
+/// This pattern linearizes vector.load from vector<1xN> to vector<N>.
+/// It currently supports only linearization of <1xN> to <N>.
+/// For example,
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+///   vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = loadOp.getType();
+    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported");
+    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+                                    vecTy.isScalable());
+    auto newLoad = rewriter.create<vector::LoadOp>(
+        loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+    auto shapeCast = rewriter.create<vector::ShapeCastOp>(
+        loadOp.getLoc(), vecTy, newLoad.getResult());
+    rewriter.replaceOp(loadOp, shapeCast.getResult());
+    return success();
+  }
+};
+
+/// This pattern linearizes vector.store from vector<1xN> to vector<N>.
+/// It currently supports only linearization of <1xN> to <N>.
+/// For example,
+///   vector.store %arg0, %arg1[%c0, %c0]
+///     : vector<1x4xf32>, memref<1x4xf32>
+/// is converted to:
+///   vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///   vector.store %arg0, %arg1[%c0, %c0]
+///     : vector<4xf32>, memref<1x4xf32>
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = storeOp.getValueToStore().getType();
+    if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1)
+      return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported");
+    auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(),
+                                    vecTy.isScalable());
+
+    Value valueToStore = adaptor.getValueToStore();
+    if (valueToStore.getType() != linearTy) {
+      valueToStore = rewriter.create<vector::ShapeCastOp>(
+          storeOp.getLoc(), linearTy, valueToStore);
+    }
+
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices());
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
     RewritePatternSet &patterns) {
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+           LinearizeVectorStore>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..fa0436792d3f0 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
   %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
   return %0 : vector<1x[16]xi1>
 }
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+  // CHECK: return %[[CAST]] : vector<1x4xf32>
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) {
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+  return
+}

@llvmbot
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir

Contributor

@newling newling left a comment


Thanks! I added 2 small comments about simplifying and generalizing a bit.

@nbpatel
Contributor Author

nbpatel commented Jun 23, 2025

Thanks! I added 2 small comments about simplifying and generalizing a bit.

Addressed the feedback, thanks :)

@banach-space
Contributor

What should happen if one of the outer dims is != 1? This is a general design question.

Could you add some negative tests? Also, this should work for scalable vectors provided that there is only one scalable dim. Would you mind taking care of that as well? Thanks!
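Since the pattern forwards the vector type's scalability flag when building the rank-1 type, the scalable case requested above would presumably rewrite along these lines (a hypothetical sketch; the shapes are invented for illustration and are not taken from the PR's tests):

```mlir
// Hypothetical scalable case: the unit leading dim is dropped and the
// single scalable trailing dim is preserved.
%0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
// becomes:
%1 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<[4]xf32>
%2 = vector.shape_cast %1 : vector<[4]xf32> to vector<1x[4]xf32>
```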

@nbpatel
Contributor Author

nbpatel commented Jul 7, 2025

What should be happening if one of the outer dims != 1? This is a general design question.

Could you add some negative tests? Also, this should work for scalable vectors provided that there is only one scalable dim. Would you mind taking care of that as well? Thanks!

Hey Andrzej, I was on vacation for the past two weeks, hence the delay in responding.

  1. My understanding is that the user can use the unrolling patterns for load and store to get to the vector<1x...xN> form if one of the outer dims is not 1, and then use the linearization patterns to linearize it.

  2. For negative cases, the tests will currently error out with a legalization error, so I'm not sure of the correct way to add them. I don't see any other ops having negative tests. @newling ?

  3. I added the check and a test case for the scalable dim. Please take a look.

Thanks.
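The unroll-then-linearize pipeline described in (1) might look roughly like this (a hypothetical sketch; shapes and SSA names are invented for illustration):

```mlir
// Hypothetical: a vector<2x4xf32> load is first unrolled (via the
// patterns from #143420) into per-row vector<1x4xf32> loads ...
%a = vector.load %m[%c0, %c0] : memref<2x4xf32>, vector<1x4xf32>
%b = vector.load %m[%c1, %c0] : memref<2x4xf32>, vector<1x4xf32>
// ... which this PR's patterns then linearize to rank-1 loads:
%a1 = vector.load %m[%c0, %c0] : memref<2x4xf32>, vector<4xf32>
%b1 = vector.load %m[%c1, %c0] : memref<2x4xf32>, vector<4xf32>
```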

@nbpatel nbpatel requested a review from newling July 7, 2025 20:59
@newling
Contributor

newling commented Jul 7, 2025

  1. For negative cases, the tests will currently error out with a legalization error, so I'm not sure of the correct way to add them. I don't see any other ops having negative tests. @newling ?

That is the expected error. It's difficult to provide good error diagnostics with legalization, but note that I am currently working on a major overhaul to use simple rewrite patterns (#146030). I'm quite close to posting a PR, but please don't wait for it.

@dcaballe
Contributor

dcaballe commented Jul 8, 2025

Hey, could we add more context about what we are trying to do here, preferably in the PR description? Why are only unit dims supported? What is the long-term plan for supporting more generic cases? What are vector linearization's responsibilities versus applying drop-unit-dims or other transformations? I think elaborating on this aspect of the design would help clarify where this is going and set users' expectations. Thanks!

@nbpatel
Contributor Author

nbpatel commented Jul 9, 2025

Hey, could we add more context about what we are trying to do here, preferably in the PR description? Why are only unit dims supported? What is the long-term plan for supporting more generic cases? What are vector linearization's responsibilities versus applying drop-unit-dims or other transformations? I think elaborating on this aspect of the design would help clarify where this is going and set users' expectations. Thanks!

The motivation for this was to have rank-1 loads/stores. Initially there was a more generic version (PR), but the feedback was to add unrolling patterns (#143420) so that the user can unroll to some vector<1x..xN> form and then linearize it to 1D using these patterns. We can support the generic cases in a follow-up PR, I think.

@banach-space
Contributor

the feedback was to add unrolling patterns (#143420) so that the user can unroll to some vector<1x..xN> form and then linearize it to 1D using these patterns.

Yes, that's what @newling proposed, and it made sense to me back then and it still does :) To make sure we continue on this path, please add comments in the code redirecting people to unrolling for the more general cases.

We can support generic cases in an iterative PR I think.

No need to, IMHO, unless there's a specific need.

5 participants