Commit a583616
[mlir][vector] Fix error handling in VectorizationState::initState
This function used to create new ops even when vectorization failed; those ops were then immediately folded away. Because every fold counts as an IR modification, the GreedyPatternRewriter scheduled one more iteration each time and no longer terminated. The fix is to check for a dynamic canonical vector shape before any ops are created.

Differential Revision: https://reviews.llvm.org/D140286
1 parent 7ccbb4d commit a583616
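For context, the failure mode looks roughly like this. The sketch below is hypothetical (the pattern name, the loop structure, and the choice of tensor.dim as the created op are illustrative assumptions, not the upstream code), but it shows why creating ops before the precondition check breaks the greedy driver:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
// Hypothetical pattern reproducing the pre-fix structure of initState.
struct SketchVectorizePattern : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Ops are created first: tensor.dim ops materializing the dynamic
    // sizes of the iteration space.
    for (OpOperand &operand : op->getOpOperands()) {
      auto tensorType = operand.get().getType().dyn_cast<RankedTensorType>();
      if (!tensorType)
        continue;
      for (int64_t dim = 0, e = tensorType.getRank(); dim < e; ++dim)
        if (tensorType.isDynamicDim(dim))
          rewriter.create<tensor::DimOp>(op.getLoc(), operand.get(), dim);
    }
    // Only then does the precondition check run. Returning failure() here
    // leaves the tensor.dim ops in the IR; the driver folds them away on
    // its next iteration, counts that as a modification, and keeps
    // iterating.
    return failure();
  }
};
} // namespace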

File tree

2 files changed: +34 −2 lines

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 3 additions & 2 deletions
@@ -179,6 +179,9 @@ VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
   LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << "\n");
 
+  if (ShapedType::isDynamicShape(canonicalVecShape))
+    return failure();
+
   // Initialize iteration space static sizes.
   initIterSpaceStaticSizes(linalgOp);
 
@@ -187,8 +190,6 @@ VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
   if (failed(precomputeIterSpaceDynamicSizes(rewriter, linalgOp)))
     return failure();
 
-  if (ShapedType::isDynamicShape(canonicalVecShape))
-    return failure();
   return success();
 }

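The symptom shows up one level higher, in the greedy driver that applies the vectorization patterns. A minimal sketch of the caller's view, assuming the standard applyPatternsAndFoldGreedily entry point (the wrapper function here is hypothetical):

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical wrapper: before this fix, a pattern that created ops and then
// returned failure() made every driver iteration "modify" the IR (the stray
// ops were folded away each time), so the greedy rewrite never reached a
// fixpoint.
LogicalResult runVectorizationPatterns(Operation *root,
                                       RewritePatternSet &&patterns) {
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

Moving the isDynamicShape check ahead of initIterSpaceStaticSizes and precomputeIterSpaceDynamicSizes means initState now bails out before any op is created, so a failed vectorization leaves the IR untouched and the driver converges.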
mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 31 additions & 0 deletions
@@ -1750,3 +1750,34 @@ transform.sequence failures(propagate) {
   transform.structured.masked_vectorize %0 vector_sizes [4, 8]
 }
 
+// -----
+
+// This is a regression test. This IR cannot be vectorized, but
+// structured.vectorize should nevertheless succeed.
+
+#map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @not_vectorizable
+func.func @not_vectorizable(%arg0: tensor<1x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<1x128xf32> {
+  %0 = tensor.empty() : tensor<1x128xf32>
+  %1 = scf.for %arg5 = %arg2 to %arg1 step %arg3 iter_args(%arg6 = %0) -> (tensor<1x128xf32>) {
+    %extracted_slice = tensor.extract_slice %arg6[0, 0] [1, %arg1] [1, 1] : tensor<1x128xf32> to tensor<?xf32>
+    %expanded = tensor.expand_shape %extracted_slice [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg0[0, %arg3] [1, %arg2] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
+    %extracted_slice_1 = tensor.extract_slice %expanded[0, %arg3] [1, %arg2] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
+    %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_0 : tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %3 = arith.addf %in, %out : f32
+      linalg.yield %3 : f32
+    } -> tensor<?xf32>
+    %inserted_slice = tensor.insert_slice %2 into %expanded[0, %arg3] [1, %arg2] [1, 1] : tensor<?xf32> into tensor<1x?xf32>
+    %collapsed = tensor.collapse_shape %inserted_slice [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
+    %inserted_slice_2 = tensor.insert_slice %collapsed into %arg6[0, 0] [1, %arg1] [1, 1] : tensor<?xf32> into tensor<1x128xf32>
+    scf.yield %inserted_slice_2 : tensor<1x128xf32>
+  }
+  return %1 : tensor<1x128xf32>
+}
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg0
+  %1 = transform.structured.vectorize %0
+}
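To check the new case locally, the test file can be run through lit on its own; the exact RUN line lives at the top of vectorization.mlir. Assuming a CMake build directory named build/:

build/bin/llvm-lit -v mlir/test/Dialect/Linalg/vectorization.mlir

Before the fix, the greedy rewrite inside transform.structured.vectorize did not terminate on this input; with the fix, the transform succeeds while the linalg.generic simply stays unvectorized.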
