Skip to content

Commit 1ca7c99

Browse files
committed
[mlir][memref] Add better computeCollapsedLayoutMap support for unit collapse
1 parent 896575e commit 1ca7c99

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2422,7 +2422,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
24222422
ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
24232423
while (srcShape[ref.back()] == 1 && ref.size() > 1)
24242424
ref = ref.drop_back();
2425-
if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2425+
auto precedingRef = ref.drop_back();
2426+
bool allUnitPreceding = llvm::all_of(
2427+
precedingRef, [&srcShape](int idx) { return srcShape[idx] == 1; });
2428+
if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1 ||
2429+
allUnitPreceding) {
24262430
resultStrides.push_back(srcStrides[ref.back()]);
24272431
} else {
24282432
// Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: mlir-opt %s | FileCheck %s
2+
3+
// CHECK-LABEL: test_collapse(
4+
func.func @test_collapse(%arg0: memref<1x?xf32, strided<[5, 1]>>) {
5+
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] : memref<1x?xf32, strided<[5, 1]>> into memref<?xf32, strided<[1]>>
6+
return
7+
}
8+
9+
// CHECK-LABEL: test_collapse_5d_middle_dynamic(
10+
func.func @test_collapse_5d_middle_dynamic(%arg0: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>>) {
11+
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3, 4]]
12+
: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>> into memref<?xf32, strided<[?]>>
13+
return
14+
}
15+
16+
// CHECK-LABEL: test_collapse_5d_mostly_units(
17+
func.func @test_collapse_5d_mostly_units(%arg0: memref<1x1x1x?x1xf32, strided<[320, 80, 16, 2, 1]>>) {
18+
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3, 4]]
19+
: memref<1x1x1x?x1xf32, strided<[320, 80, 16, 2, 1]>> into memref<?xf32, strided<[2]>>
20+
return
21+
}
22+
23+
// CHECK-LABEL: test_partial_collapse_6d(
24+
func.func @test_partial_collapse_6d(%arg0: memref<1x?x1x1x5x1xf32, strided<[3360, 420, 140, 35, 7, 1]>>) {
25+
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3], [4, 5]]
26+
: memref<1x?x1x1x5x1xf32, strided<[3360, 420, 140, 35, 7, 1]>> into memref<?x5xf32, strided<[420, 7]>>
27+
return
28+
}
29+
30+
// CHECK-LABEL: test_collapse_5d_grouped(
31+
func.func @test_collapse_5d_grouped(%arg0: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>>) {
32+
%collapse_shape = memref.collapse_shape %arg0 [[0], [1, 2, 3, 4]]
33+
: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>> into memref<1x?xf32, strided<[540, ?]>>
34+
return
35+
}
36+
37+
// CHECK-LABEL: test_collapse_all_units(
38+
func.func @test_collapse_all_units(%arg0: memref<1x1x1x1x1xf32, strided<[100, 50, 25, 10, 1]>>) {
39+
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3, 4]]
40+
: memref<1x1x1x1x1xf32, strided<[100, 50, 25, 10, 1]>> into memref<1xf32, strided<[100]>>
41+
return
42+
}

0 commit comments

Comments
 (0)