Skip to content

Commit ddf9b91

Browse files
authored
[mlir][Vector] Add vector.shuffle tree transformation (#145740)
This PR adds a new transformation that turns sequences of `vector.to_elements` and `vector.from_elements` into a binary tree of `vector.shuffle` operations. (Related RFC: https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779). Example: ``` %0:4 = vector.to_elements %a : vector<4xf32> %1:4 = vector.to_elements %b : vector<4xf32> %2:4 = vector.to_elements %c : vector<4xf32> %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 : vector<12xf32> ==> %0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32> %1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32> %2 = vector.shuffle %0, %1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> ``` The algorithm leverages the structured extraction/insertion information of `vector.to_elements` and `vector.from_elements` operations and builds a set of intervals to determine the vector length that should be used at each level of the tree to combine the level inputs in pairs. There are a few improvements that can be implemented in the future, such as shuffle mask compression to avoid unnecessarily large vector lengths with poison values, but I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along.
1 parent 7f3afab commit ddf9b91

File tree

6 files changed

+1182
-0
lines changed

6 files changed

+1182
-0
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
297297
/// n > 1.
298298
void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
299299

300+
/// Populate patterns to rewrite sequences of `vector.to_elements` +
301+
/// `vector.from_elements` operations into a tree of `vector.shuffle`
302+
/// operations.
303+
void populateVectorToFromElementsToShuffleTreePatterns(
304+
RewritePatternSet &patterns, PatternBenefit benefit = 1);
305+
300306
} // namespace vector
301307
} // namespace mlir
308+
302309
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H

mlir/include/mlir/Dialect/Vector/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
1010
#define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
1111

12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1213
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1314
#include "mlir/Pass/Pass.h"
1415

mlir/include/mlir/Dialect/Vector/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
3434
];
3535
}
3636

37+
def LowerVectorToFromElementsToShuffleTree
38+
: Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> {
39+
let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations";
40+
}
41+
3742
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
1010
LowerVectorScan.cpp
1111
LowerVectorShapeCast.cpp
1212
LowerVectorStep.cpp
13+
LowerVectorToFromElementsToShuffleTree.cpp
1314
LowerVectorTransfer.cpp
1415
LowerVectorTranspose.cpp
1516
SubsetOpInterfaceImpl.cpp

0 commit comments

Comments
 (0)