From 8c3dbf88a7a190c3134992bb4cb3f4bcf133cfe8 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 19 Jun 2025 19:40:07 +0000 Subject: [PATCH 1/6] Linearization patterns for vector.load and vector.store --- .../Vector/Transforms/VectorLinearize.cpp | 71 ++++++++++++++++++- mlir/test/Dialect/Vector/linearize.mlir | 23 ++++++ 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 678a88627ca82..f0b77da5acd02 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -623,6 +623,73 @@ struct LinearizeVectorCreateMask final } }; +/// This pattern linearizes vector.load from vector<1xN> to vector. +/// It currently supports only lineariztion of <1XN> to +/// Following, +/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32> +/// is converted to: +/// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32> +/// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32> +struct LinearizeVectorLoad final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType vecTy = loadOp.getType(); + if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1) + return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported"); + auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(), + vecTy.isScalable()); + auto newLoad = rewriter.create( + loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices()); + auto shapeCast = rewriter.create( + loadOp.getLoc(), vecTy, newLoad.getResult()); + rewriter.replaceOp(loadOp, shapeCast.getResult()); + return success(); + } +}; + +/// This pattern linearizes vector.store from vector<1xN> to vector. +/// It currently supports only lineariztion of <1XN> to +/// Following, +/// vector.store %arg0, %arg1[%c0, %c0] +/// : vector<1x4xf32>, memref<1x4xf32> +/// is converted to: +/// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> +/// vector.store %arg0, %arg1[%c0, %%c0] +/// : vector<4xf32>, memref<1x4xf32> +struct LinearizeVectorStore final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType vecTy = storeOp.getValueToStore().getType(); + if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1) + return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported"); + auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(), + vecTy.isScalable()); + + Value valueToStore = adaptor.getValueToStore(); + if (valueToStore.getType() != linearTy) { + valueToStore = rewriter.create( + storeOp.getLoc(), linearTy, valueToStore); + } + + rewriter.replaceOpWithNewOp( + storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices()); + return success(); + } +}; + } // namespace /// This method defines the set of operations that are linearizable, and hence @@ -714,8 +781,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( RewritePatternSet &patterns) { patterns .add( - typeConverter, patterns.getContext()); + LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad, + LinearizeVectorStore>(typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 9cbf319ffddb2..fa0436792d3f0 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1> return %0 : vector<1x[16]xi1> } + +// CHECK-LABEL: linearize_vector_load +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32> +func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> { + // CHECK: %[[CST0:.*]] = arith.constant 0 : index + // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32> + // CHECK: return %[[CAST]] : vector<1x4xf32> + %c0 = arith.constant 0 : index + %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32> + return %0 : vector<1x4xf32> +} + +// CHECK-LABEL: linearize_vector_store +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>) +func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) { + // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32> + // CHECK: %[[CST0:.*]] = arith.constant 0 : index + // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32> + %c0 = arith.constant 0 : index + vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32> + return +} From 92b299f07846ad55821075f08df5cba598c1766f Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 23 Jun 2025 19:14:00 +0000 Subject: [PATCH 2/6] Address comments --- .../Vector/Transforms/VectorLinearize.cpp | 48 ++++++++----------- mlir/test/Dialect/Vector/linearize.mlir | 16 +++---- 2 files changed, 29 insertions(+), 35 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index f0b77da5acd02..890d882ea2129 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -623,9 +623,9 @@ struct LinearizeVectorCreateMask final } }; -/// This pattern linearizes vector.load from vector<1xN> to vector. -/// It currently supports only lineariztion of <1XN> to -/// Following, +/// This pattern linearizes vector.load from vector<1x1x...xN> to vector +/// It currently supports linearization where all but the last dimension are 1 +/// The following, /// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32> /// is converted to: /// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32> @@ -640,27 +640,27 @@ struct LinearizeVectorLoad final : public OpConversionPattern { matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType vecTy = loadOp.getType(); - if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1) - return rewriter.notifyMatchFailure(loadOp, "only vector<1xN> supported"); - auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(), - vecTy.isScalable()); + if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1), + [](auto d) { return d == 1; })) + return rewriter.notifyMatchFailure(loadOp, + "only vector<1x1x...xN> supported"); + auto linearTy = VectorType::get(vecTy.getShape().back(), + vecTy.getElementType(), vecTy.isScalable()); auto newLoad = rewriter.create( loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices()); - auto shapeCast = rewriter.create( - loadOp.getLoc(), vecTy, newLoad.getResult()); - rewriter.replaceOp(loadOp, shapeCast.getResult()); + rewriter.replaceOp(loadOp, newLoad.getResult()); return success(); } }; -/// This pattern linearizes vector.store from vector<1xN> to vector. -/// It currently supports only lineariztion of <1XN> to -/// Following, -/// vector.store %arg0, %arg1[%c0, %c0] +/// This pattern linearizes vector.store from vector<1x1x...xN> to vector +/// It currently supports linearization where all but the last dimension are 1 +/// The following, +/// vector.store %arg0, %arg1[%c0, %c0]s /// : vector<1x4xf32>, memref<1x4xf32> /// is converted to: /// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> -/// vector.store %arg0, %arg1[%c0, %%c0] +/// vector.store %arg0, %arg1[%c0, %c0] /// : vector<4xf32>, memref<1x4xf32> struct LinearizeVectorStore final : public OpConversionPattern { @@ -673,19 +673,13 @@ struct LinearizeVectorStore final matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType vecTy = storeOp.getValueToStore().getType(); - if (!vecTy || vecTy.getRank() != 2 || vecTy.getShape()[0] != 1) - return rewriter.notifyMatchFailure(storeOp, "only vector<1xN> supported"); - auto linearTy = VectorType::get(vecTy.getShape()[1], vecTy.getElementType(), - vecTy.isScalable()); - - Value valueToStore = adaptor.getValueToStore(); - if (valueToStore.getType() != linearTy) { - valueToStore = rewriter.create( - storeOp.getLoc(), linearTy, valueToStore); - } - + if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1), + [](auto d) { return d == 1; })) + return rewriter.notifyMatchFailure(storeOp, + "only vector<1x1x...xN> supported"); rewriter.replaceOpWithNewOp( - storeOp, valueToStore, adaptor.getBase(), adaptor.getIndices()); + storeOp, adaptor.getValueToStore(), adaptor.getBase(), + adaptor.getIndices()); return success(); } }; diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index fa0436792d3f0..9a017ceedcebe 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -466,24 +466,24 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto } // CHECK-LABEL: linearize_vector_load -// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>) -> vector<1x4xf32> -func.func @linearize_vector_load(%arg0: memref<1x4xf32>) -> vector<1x4xf32> { +// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>) -> vector<1x4xf32> +func.func @linearize_vector_load(%arg0: memref<2x8xf32>) -> vector<1x4xf32> { // CHECK: %[[CST0:.*]] = arith.constant 0 : index - // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32> + // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32> // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32> // CHECK: return %[[CAST]] : vector<1x4xf32> %c0 = arith.constant 0 : index - %0 = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32> + %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32> return %0 : vector<1x4xf32> } // CHECK-LABEL: linearize_vector_store -// CHECK-SAME: (%[[ARG0:.*]]: memref<1x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>) -func.func @linearize_vector_store(%arg0: memref<1x4xf32>, %arg1: vector<1x4xf32>) { +// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>) +func.func @linearize_vector_store(%arg0: memref<2x8xf32>, %arg1: vector<1x4xf32>) { // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32> // CHECK: %[[CST0:.*]] = arith.constant 0 : index - // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<1x4xf32>, vector<4xf32> + // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32> %c0 = arith.constant 0 : index - vector.store %arg1, %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32> + vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32> return } From f700fe6aa6dd06d85216f147fd33108e1ae5f352 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 23 Jun 2025 19:24:14 +0000 Subject: [PATCH 3/6] Fix --- mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 890d882ea2129..99d18fec18120 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -644,8 +644,7 @@ struct LinearizeVectorLoad final : public OpConversionPattern { [](auto d) { return d == 1; })) return rewriter.notifyMatchFailure(loadOp, "only vector<1x1x...xN> supported"); - auto linearTy = VectorType::get(vecTy.getShape().back(), - vecTy.getElementType(), vecTy.isScalable()); + auto linearTy = typeConverter->convertType(loadOp.getType()); auto newLoad = rewriter.create( loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices()); rewriter.replaceOp(loadOp, newLoad.getResult()); From bcb5306d3a3379530ad58e4b4927ad230451b732 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 3 Jul 2025 15:16:40 +0000 Subject: [PATCH 4/6] Address feedback --- .../Vector/Transforms/VectorLinearize.cpp | 33 ++++++++++++++++--- mlir/test/Dialect/Vector/linearize.mlir | 23 +++++++++++++ 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 99d18fec18120..fa5f9bbd2dbf6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -640,11 +640,23 @@ struct LinearizeVectorLoad final : public OpConversionPattern { matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType vecTy = loadOp.getType(); - if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1), - [](auto d) { return d == 1; })) + if (!vecTy) + return rewriter.notifyMatchFailure(loadOp, "expected vector type"); + + auto shape = vecTy.getShape(); + auto scalableDims = vecTy.getScalableDims(); + // All but the last dim must be 1, and only the last dim may be scalable (if + // any). + if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; })) return rewriter.notifyMatchFailure(loadOp, "only vector<1x1x...xN> supported"); - auto linearTy = typeConverter->convertType(loadOp.getType()); + + if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; })) + return rewriter.notifyMatchFailure(loadOp, + "only innermost dim may be scalable"); + + auto linearTy = typeConverter->convertType(vecTy); + auto newLoad = rewriter.create( loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices()); rewriter.replaceOp(loadOp, newLoad.getResult()); @@ -672,10 +684,21 @@ struct LinearizeVectorStore final matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType vecTy = storeOp.getValueToStore().getType(); - if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1), - [](auto d) { return d == 1; })) + if (!vecTy) + return rewriter.notifyMatchFailure(storeOp, "expected vector type"); + + auto shape = vecTy.getShape(); + auto scalableDims = vecTy.getScalableDims(); + // All but the last dim must be 1, and only the last dim may be scalable (if + // any). + if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; })) return rewriter.notifyMatchFailure(storeOp, "only vector<1x1x...xN> supported"); + + if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; })) + return rewriter.notifyMatchFailure(storeOp, + "only innermost dim may be scalable"); + rewriter.replaceOpWithNewOp( storeOp, adaptor.getValueToStore(), adaptor.getBase(), adaptor.getIndices()); diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 9a017ceedcebe..11780abfc6141 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -487,3 +487,26 @@ func.func @linearize_vector_store(%arg0: memref<2x8xf32>, %arg1: vector<1x4xf32> vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32> return } + +// CHECK-LABEL: linearize_vector_load_scalable +// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>) -> vector<1x[4]xf32> +func.func @linearize_vector_load_scalable(%arg0: memref<2x8xf32>) -> vector<1x[4]xf32> { + // CHECK: %[[CST0:.*]] = arith.constant 0 : index + // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<[4]xf32> to vector<1x[4]xf32> + // CHECK: return %[[CAST]] : vector<1x[4]xf32 + %c0 = arith.constant 0 : index + %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32> + return %0 : vector<1x[4]xf32> +} + +// CHECK-LABEL: linearize_vector_store_scalable +// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>, %[[ARG1:.*]]: vector<1x[4]xf32>) +func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector<1x[4]xf32>) { + // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x[4]xf32> to vector<[4]xf32> + // CHECK: %[[CST0:.*]] = arith.constant 0 : index + // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32> + %c0 = arith.constant 0 : index + vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32> + return +} From 62fc4473a9feaddd4239c70cd349ced2fe220d60 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 7 Jul 2025 20:58:34 +0000 Subject: [PATCH 5/6] Missing brace --- mlir/test/Dialect/Vector/linearize.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 11780abfc6141..1ad08b9387b08 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -494,7 +494,7 @@ func.func @linearize_vector_load_scalable(%arg0: memref<2x8xf32>) -> vector<1x[4 // CHECK: %[[CST0:.*]] = arith.constant 0 : index // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32> // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<[4]xf32> to vector<1x[4]xf32> - // CHECK: return %[[CAST]] : vector<1x[4]xf32 + // CHECK: return %[[CAST]] : vector<1x[4]xf32> %c0 = arith.constant 0 : index %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32> return %0 : vector<1x[4]xf32> From 351b0eeabba5fdbe015318ded44a019f8e6b1689 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 11 Jul 2025 14:33:14 +0000 Subject: [PATCH 6/6] Add comments --- mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index fa5f9bbd2dbf6..0ebb477038e86 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -630,6 +630,8 @@ struct LinearizeVectorCreateMask final /// is converted to: /// vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32> /// vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32> +/// For generic cases, the vector unroll pass should be used to unroll the load +/// to vector<1x1x...xN> form and then linearized struct LinearizeVectorLoad final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context, @@ -673,6 +675,8 @@ struct LinearizeVectorLoad final : public OpConversionPattern { /// vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> /// vector.store %arg0, %arg1[%c0, %c0] /// : vector<4xf32>, memref<1x4xf32> +/// For generic cases, the vector unroll pass should be used to unroll the store +/// to vector<1x1x...xN> form and then linearized struct LinearizeVectorStore final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern;