Skip to content

Commit 6033544

Browse files
authored
[mlir][linalg] Fix memref type verification in CollapseLinalgDimensions (#147245)
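When collapsing the iteration dimensions of a linalg op, we check whether its memref operands are guaranteed to be collapsible. However, that check currently assumes that each operand's indexing map is the identity map, so it tests the folded iteration dimensions directly against the memref. This commit fixes that: the folded iteration dimensions are first translated through the operand's matching indexing map, and the memref is then checked for collapsibility on the resulting operand dimensions.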
1 parent 755b8f9 commit 6033544

File tree

2 files changed

+60
-10
lines changed

2 files changed

+60
-10
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,26 +1717,30 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
17171717
}))
17181718
return failure();
17191719

1720+
CollapsingInfo collapsingInfo;
1721+
if (failed(
1722+
collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1723+
return rewriter.notifyMatchFailure(
1724+
op, "illegal to collapse specified dimensions");
1725+
}
1726+
17201727
bool hasPureBufferSemantics = op.hasPureBufferSemantics();
17211728
if (hasPureBufferSemantics &&
1722-
!llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
1723-
MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
1729+
!llvm::all_of(op->getOpOperands(), [&](OpOperand &opOperand) -> bool {
1730+
MemRefType memRefToCollapse =
1731+
dyn_cast<MemRefType>(opOperand.get().getType());
17241732
if (!memRefToCollapse)
17251733
return true;
17261734

1735+
AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1736+
SmallVector<ReassociationIndices> operandReassociation =
1737+
getOperandReassociation(indexingMap, collapsingInfo);
17271738
return memref::CollapseShapeOp::isGuaranteedCollapsible(
1728-
memRefToCollapse, foldedIterationDims);
1739+
memRefToCollapse, operandReassociation);
17291740
}))
17301741
return rewriter.notifyMatchFailure(op,
17311742
"memref is not guaranteed collapsible");
17321743

1733-
CollapsingInfo collapsingInfo;
1734-
if (failed(
1735-
collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) {
1736-
return rewriter.notifyMatchFailure(
1737-
op, "illegal to collapse specified dimensions");
1738-
}
1739-
17401744
// Bail on non-canonical ranges.
17411745
SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
17421746
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {

mlir/test/Dialect/Linalg/collapse-dim.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,35 @@ func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memre
100100

101101
// -----
102102

103+
// CHECK-DAG: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 10 + d3)>
104+
// CHECK-DAG: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
105+
106+
// CHECK-LABEL: func.func @collapsable_memref_projected_ops(
107+
// CHECK-SAME: %[[ARG0:.*]]: memref<1x24x32x8xf32>, %[[ARG1:.*]]: memref<1x24x32x8xf32>, %[[ARG2:.*]]: memref<1x24x32x8xf32, #[[$ATTR_0]]>) {
108+
// CHECK: %[[VAL_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32> into memref<1x768x8xf32>
109+
// CHECK: %[[VAL_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32> into memref<1x768x8xf32>
110+
// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[\[}}0], [1, 2], [3]] : memref<1x24x32x8xf32, #[[$ATTR_0]]> into memref<1x768x8xf32, strided<[7680, 10, 1]>>
111+
// CHECK: linalg.generic {indexing_maps = [#[[$ATTR_1]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]], %[[VAL_1]] : memref<1x768x8xf32>, memref<1x768x8xf32>) outs(%[[VAL_2]] : memref<1x768x8xf32, strided<[7680, 10, 1]>>) {
112+
// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32, %[[VAL_5:.*]]: f32):
113+
// CHECK: %[[VAL_6:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
114+
// CHECK: linalg.yield %[[VAL_6]] : f32
115+
// CHECK: }
116+
// CHECK: return
117+
// CHECK: }
118+
119+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
120+
#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 10 + d3)>
121+
func.func @collapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>, %arg2: memref<1x24x32x8xf32, #map1>) {
122+
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%arg2 : memref<1x24x32x8xf32, #map1>) {
123+
^bb0(%in: f32, %in_0: f32, %out: f32):
124+
%0 = arith.addf %in, %in_0 : f32
125+
linalg.yield %0 : f32
126+
}
127+
return
128+
}
129+
130+
// -----
131+
103132
// CHECK-LABEL: func @uncollapsable_strided_memref(
104133
// CHECK: linalg.generic
105134
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -119,6 +148,23 @@ func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: mem
119148

120149
// -----
121150

151+
// CHECK-LABEL: func @uncollapsable_memref_projected_ops(
152+
// CHECK: linalg.generic
153+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
154+
155+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
156+
#map1 = affine_map<(d0, d1, d2, d3) -> (d0 * 7680 + d1 * 320 + d2 * 8 + d3)>
157+
func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>, %arg2: memref<1x24x32x8xf32, #map1>) {
158+
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%arg2 : memref<1x24x32x8xf32, #map1>) {
159+
^bb0(%in: f32, %in_0: f32, %out: f32):
160+
%0 = arith.addf %in, %in_0 : f32
161+
linalg.yield %0 : f32
162+
}
163+
return
164+
}
165+
166+
// -----
167+
122168
// CHECK-LABEL: func.func @linalg_copy(
123169
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
124170
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {

0 commit comments

Comments
 (0)