@@ -674,6 +674,93 @@ struct LinearizeVectorCreateMask final
674
674
}
675
675
};
676
676
677
+ // / This pattern linearizes vector.load from vector<1x1x...xN> to vector<N>
678
+ // / It currently supports linearization where all but the last dimension are 1
679
+ // / The following,
680
+ // / vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
681
+ // / is converted to:
682
+ // / vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
683
+ // / vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
684
+ // / For generic cases, the vector unroll pass should be used to unroll the load
685
+ // / to vector<1x1x...xN> form and then linearized
686
+ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
687
+ using OpConversionPattern::OpConversionPattern;
688
+ LinearizeVectorLoad (const TypeConverter &typeConverter, MLIRContext *context,
689
+ PatternBenefit benefit = 1 )
690
+ : OpConversionPattern(typeConverter, context, benefit) {}
691
+
692
+ LogicalResult
693
+ matchAndRewrite (vector::LoadOp loadOp, OpAdaptor adaptor,
694
+ ConversionPatternRewriter &rewriter) const override {
695
+ VectorType vecTy = loadOp.getType ();
696
+ if (!vecTy)
697
+ return rewriter.notifyMatchFailure (loadOp, " expected vector type" );
698
+
699
+ auto shape = vecTy.getShape ();
700
+ auto scalableDims = vecTy.getScalableDims ();
701
+ // All but the last dim must be 1, and only the last dim may be scalable (if
702
+ // any).
703
+ if (!llvm::all_of (shape.drop_back (1 ), [](auto d) { return d == 1 ; }))
704
+ return rewriter.notifyMatchFailure (loadOp,
705
+ " only vector<1x1x...xN> supported" );
706
+
707
+ if (llvm::any_of (scalableDims.drop_back (1 ), [](bool s) { return s; }))
708
+ return rewriter.notifyMatchFailure (loadOp,
709
+ " only innermost dim may be scalable" );
710
+
711
+ auto linearTy = typeConverter->convertType <VectorType>(vecTy);
712
+
713
+ auto newLoad = rewriter.create <vector::LoadOp>(
714
+ loadOp.getLoc (), linearTy, adaptor.getBase (), adaptor.getIndices ());
715
+ rewriter.replaceOp (loadOp, newLoad.getResult ());
716
+ return success ();
717
+ }
718
+ };
719
+
720
+ // / This pattern linearizes vector.store from vector<1x1x...xN> to vector<N>
721
+ // / It currently supports linearization where all but the last dimension are 1
722
+ // / The following,
723
+ // / vector.store %arg0, %arg1[%c0, %c0]s
724
+ // / : vector<1x4xf32>, memref<1x4xf32>
725
+ // / is converted to:
726
+ // / vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
727
+ // / vector.store %arg0, %arg1[%c0, %c0]
728
+ // / : vector<4xf32>, memref<1x4xf32>
729
+ // / For generic cases, the vector unroll pass should be used to unroll the store
730
+ // / to vector<1x1x...xN> form and then linearized
731
+ struct LinearizeVectorStore final
732
+ : public OpConversionPattern<vector::StoreOp> {
733
+ using OpConversionPattern::OpConversionPattern;
734
+ LinearizeVectorStore (const TypeConverter &typeConverter, MLIRContext *context,
735
+ PatternBenefit benefit = 1 )
736
+ : OpConversionPattern(typeConverter, context, benefit) {}
737
+
738
+ LogicalResult
739
+ matchAndRewrite (vector::StoreOp storeOp, OpAdaptor adaptor,
740
+ ConversionPatternRewriter &rewriter) const override {
741
+ VectorType vecTy = storeOp.getValueToStore ().getType ();
742
+ if (!vecTy)
743
+ return rewriter.notifyMatchFailure (storeOp, " expected vector type" );
744
+
745
+ auto shape = vecTy.getShape ();
746
+ auto scalableDims = vecTy.getScalableDims ();
747
+ // All but the last dim must be 1, and only the last dim may be scalable (if
748
+ // any).
749
+ if (!llvm::all_of (shape.drop_back (1 ), [](auto d) { return d == 1 ; }))
750
+ return rewriter.notifyMatchFailure (storeOp,
751
+ " only vector<1x1x...xN> supported" );
752
+
753
+ if (llvm::any_of (scalableDims.drop_back (1 ), [](bool s) { return s; }))
754
+ return rewriter.notifyMatchFailure (storeOp,
755
+ " only innermost dim may be scalable" );
756
+
757
+ rewriter.replaceOpWithNewOp <vector::StoreOp>(
758
+ storeOp, adaptor.getValueToStore (), adaptor.getBase (),
759
+ adaptor.getIndices ());
760
+ return success ();
761
+ }
762
+ };
763
+
677
764
} // namespace
678
765
679
766
// / This method defines the set of operations that are linearizable, and hence
@@ -765,8 +852,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
765
852
RewritePatternSet &patterns) {
766
853
patterns
767
854
.add <LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
768
- LinearizeVectorSplat, LinearizeVectorCreateMask>(
769
- typeConverter, patterns.getContext ());
855
+ LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
856
+ LinearizeVectorStore>( typeConverter, patterns.getContext ());
770
857
}
771
858
772
859
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns (
0 commit comments