
Commit 834591e

[MLIR] [Vector] Linearization patterns for vector.load and vector.store (#145115)
This PR adds linearization patterns for vector.load and vector.store. It is a follow-up to PR #143420 (comment).
1 parent 45fa0b2 commit 834591e
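For readers who want to see how these patterns are exercised: they are registered together with the other base linearization patterns (see the populateVectorLinearizeBasePatterns change at the end of the first diff below) and applied through the dialect conversion driver. The following is a minimal sketch, not code from this commit: the helper name linearizeLoadsAndStores is invented, the pattern classes actually live in an anonymous namespace inside VectorLinearize.cpp, and the TypeConverter and ConversionTarget are assumed to be configured for vector linearization by the surrounding pass.

#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Minimal sketch: register the two new patterns and run a partial conversion.
// Assumes `typeConverter` and `target` are already set up for linearization.
static LogicalResult linearizeLoadsAndStores(Operation *op,
                                             const TypeConverter &typeConverter,
                                             ConversionTarget &target) {
  RewritePatternSet patterns(op->getContext());
  // This mirrors what populateVectorLinearizeBasePatterns now does for these
  // two ops (it also adds the other base linearization patterns).
  patterns.add<LinearizeVectorLoad, LinearizeVectorStore>(typeConverter,
                                                          op->getContext());
  return applyPartialConversion(op, target, std::move(patterns));
}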

2 files changed: +135, -2 lines


mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 89 additions & 2 deletions
@@ -674,6 +674,93 @@ struct LinearizeVectorCreateMask final
   }
 };
 
+/// This pattern linearizes vector.load from vector<1x1x...xN> to vector<N>
+/// It currently supports linearization where all but the last dimension are 1
+/// The following,
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+///   vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+/// For generic cases, the vector unroll pass should be used to unroll the load
+/// to vector<1x1x...xN> form and then linearized
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = loadOp.getType();
+    if (!vecTy)
+      return rewriter.notifyMatchFailure(loadOp, "expected vector type");
+
+    auto shape = vecTy.getShape();
+    auto scalableDims = vecTy.getScalableDims();
+    // All but the last dim must be 1, and only the last dim may be scalable (if
+    // any).
+    if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "only vector<1x1x...xN> supported");
+
+    if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "only innermost dim may be scalable");
+
+    auto linearTy = typeConverter->convertType<VectorType>(vecTy);
+
+    auto newLoad = rewriter.create<vector::LoadOp>(
+        loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+    rewriter.replaceOp(loadOp, newLoad.getResult());
+    return success();
+  }
+};
+
+/// This pattern linearizes vector.store from vector<1x1x...xN> to vector<N>
+/// It currently supports linearization where all but the last dimension are 1
+/// The following,
+///   vector.store %arg0, %arg1[%c0, %c0]
+///     : vector<1x4xf32>, memref<1x4xf32>
+/// is converted to:
+///   vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///   vector.store %arg0, %arg1[%c0, %c0]
+///     : vector<4xf32>, memref<1x4xf32>
+/// For generic cases, the vector unroll pass should be used to unroll the store
+/// to vector<1x1x...xN> form and then linearized
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = storeOp.getValueToStore().getType();
+    if (!vecTy)
+      return rewriter.notifyMatchFailure(storeOp, "expected vector type");
+
+    auto shape = vecTy.getShape();
+    auto scalableDims = vecTy.getScalableDims();
+    // All but the last dim must be 1, and only the last dim may be scalable (if
+    // any).
+    if (!llvm::all_of(shape.drop_back(1), [](auto d) { return d == 1; }))
+      return rewriter.notifyMatchFailure(storeOp,
+                                         "only vector<1x1x...xN> supported");
+
+    if (llvm::any_of(scalableDims.drop_back(1), [](bool s) { return s; }))
+      return rewriter.notifyMatchFailure(storeOp,
+                                         "only innermost dim may be scalable");
+
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        storeOp, adaptor.getValueToStore(), adaptor.getBase(),
+        adaptor.getIndices());
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -765,8 +852,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
     RewritePatternSet &patterns) {
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+           LinearizeVectorStore>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
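The doc comments above defer general shapes to unrolling: unroll first to vector<1x1x...xN>, then let these patterns flatten the trailing dimension. Below is a rough sketch of the unrolling half of that pipeline, not code from this commit: the {1, 4} native shape and the helper name unrollLoadsAndStores are invented examples, and unroll coverage for vector.load/vector.store is assumed to come from the referenced PR #143420. The linearization half would then run as in the sketch near the top of this page.

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <optional>

using namespace mlir;

// Rough sketch: unroll vector.load/vector.store so that only the innermost
// dimension is non-unit, leaving them in a shape the linearization patterns
// accept. The native shape below is an assumed example.
static LogicalResult unrollLoadsAndStores(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  vector::populateVectorUnrollPatterns(
      patterns,
      vector::UnrollVectorOptions().setNativeShapeFn(
          [](Operation *target) -> std::optional<SmallVector<int64_t>> {
            // Only unroll loads and stores; leave every other op unchanged.
            if (isa<vector::LoadOp, vector::StoreOp>(target))
              return SmallVector<int64_t>{1, 4};
            return std::nullopt;
          }));
  // Rewrite greedily until the loads/stores are in vector<1xN> form.
  return applyPatternsGreedily(op, std::move(patterns));
}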

mlir/test/Dialect/Vector/linearize.mlir

Lines changed: 46 additions & 0 deletions
@@ -478,3 +478,49 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
   %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
   return %0 : vector<1x[16]xi1>
 }
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<2x8xf32>) -> vector<1x4xf32> {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+  // CHECK: return %[[CAST]] : vector<1x4xf32>
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<2x8xf32>, %arg1: vector<1x4xf32>) {
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x4xf32> to vector<4xf32>
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32>
+  return
+}
+
+// CHECK-LABEL: linearize_vector_load_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>) -> vector<1x[4]xf32>
+func.func @linearize_vector_load_scalable(%arg0: memref<2x8xf32>) -> vector<1x[4]xf32> {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<[4]xf32> to vector<1x[4]xf32>
+  // CHECK: return %[[CAST]] : vector<1x[4]xf32>
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
+  return %0 : vector<1x[4]xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>, %[[ARG1:.*]]: vector<1x[4]xf32>)
+func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector<1x[4]xf32>) {
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %arg1 : vector<1x[4]xf32> to vector<[4]xf32>
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<[4]xf32>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
+  return
+}
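A note on the CHECK lines above: the vector.shape_cast ops they match are typically not created by the new patterns themselves; they are materializations inserted by the linearizing type converter when a flattened value (for example vector<4xf32>) replaces a value of the original type (vector<1x4xf32>). The sketch below shows what such a flattening rule with shape_cast materializations could look like; it is an illustration for fixed-size vectors only, not the converter used upstream, and the name configureFlatteningTypeConverter is invented.

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"

#include <optional>

using namespace mlir;

// Illustrative sketch: flatten rank >= 2 fixed-size vector types to 1-D and
// reconcile any resulting type mismatches with vector.shape_cast.
static void configureFlatteningTypeConverter(TypeConverter &converter) {
  // Leave non-vector types (and anything not matched below) unchanged.
  converter.addConversion([](Type type) { return type; });
  // Flatten fixed-size n-D vectors to their 1-D equivalent.
  converter.addConversion([](VectorType type) -> std::optional<Type> {
    if (type.getRank() <= 1 || type.isScalable())
      return std::nullopt; // fall back to the identity rule above
    return VectorType::get({type.getNumElements()}, type.getElementType());
  });
  // The shape_cast ops seen in the CHECK lines come from materializations
  // like this one, bridging the original and the flattened vector types.
  auto castMaterialization = [](OpBuilder &builder, Type resultType,
                                ValueRange inputs, Location loc) -> Value {
    return builder.create<vector::ShapeCastOp>(loc, resultType, inputs.front());
  };
  converter.addSourceMaterialization(castMaterialization);
  converter.addTargetMaterialization(castMaterialization);
}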
