Skip to content

Commit 08f57cd

Browse files
committed
[mlir][memref] Add better computeCollapsedLayoutMap support for unit collapse
1 parent 896575e commit 08f57cd

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2422,7 +2422,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
24222422
ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
24232423
while (srcShape[ref.back()] == 1 && ref.size() > 1)
24242424
ref = ref.drop_back();
2425-
if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2425+
auto precedingRef = ref.drop_back();
2426+
bool allUnitPreceding = llvm::all_of(
2427+
precedingRef, [&srcShape](int idx) { return srcShape[idx] == 1; });
2428+
if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1 ||
2429+
allUnitPreceding) {
24262430
resultStrides.push_back(srcStrides[ref.back()]);
24272431
} else {
24282432
// Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: mlir-opt %s | FileCheck %s
2+
3+
// CHECK-LABEL: test_collapse(
4+
func.func @test_collapse(%arg0: memref<20x5xf32>, %arg1: index) {
5+
%subview = memref.subview %arg0[0, 0] [1, %arg1] [1, 1] : memref<20x5xf32> to memref<1x?xf32, strided<[5, 1]>>
6+
%collapse_shape = memref.collapse_shape %subview [[0, 1]] : memref<1x?xf32, strided<[5, 1]>> into memref<?xf32, strided<[1]>>
7+
return
8+
}
9+
10+
// CHECK-LABEL: test_collapse_5d_middle_dynamic(
11+
func.func @test_collapse_5d_middle_dynamic(%arg0: memref<8x5x6x9x2xf32>, %arg1: index) {
12+
%subview = memref.subview %arg0[0, 0, 0, 0, 0] [1, 5, 1, %arg1, 1] [1, 1, 1, 1, 1]
13+
: memref<8x5x6x9x2xf32> to memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>>
14+
%collapse_shape = memref.collapse_shape %subview [[0, 1, 2, 3, 4]]
15+
: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>> into memref<?xf32, strided<[?]>>
16+
return
17+
}
18+
19+
// CHECK-LABEL: test_collapse_5d_mostly_units(
20+
func.func @test_collapse_5d_mostly_units(%arg0: memref<3x4x5x8x2xf32>, %arg1: index) {
21+
%subview = memref.subview %arg0[0, 0, 0, 0, 0] [1, 1, 1, %arg1, 1] [1, 1, 1, 1, 1]
22+
: memref<3x4x5x8x2xf32> to memref<1x1x1x?x1xf32, strided<[320, 80, 16, 2, 1]>>
23+
%collapse_shape = memref.collapse_shape %subview [[0, 1, 2, 3, 4]]
24+
: memref<1x1x1x?x1xf32, strided<[320, 80, 16, 2, 1]>> into memref<?xf32, strided<[2]>>
25+
return
26+
}
27+
28+
// CHECK-LABEL: test_partial_collapse_6d(
29+
func.func @test_partial_collapse_6d(%arg0: memref<10x8x3x4x5x7xf32>, %arg1: index) {
30+
%subview = memref.subview %arg0[0, 0, 0, 0, 0, 0] [1, %arg1, 1, 1, 5, 1] [1, 1, 1, 1, 1, 1]
31+
: memref<10x8x3x4x5x7xf32> to memref<1x?x1x1x5x1xf32, strided<[3360, 420, 140, 35, 7, 1]>>
32+
%collapse_shape = memref.collapse_shape %subview [[0, 1, 2, 3], [4, 5]]
33+
: memref<1x?x1x1x5x1xf32, strided<[3360, 420, 140, 35, 7, 1]>> into memref<?x5xf32, strided<[420, 7]>>
34+
return
35+
}
36+
37+
// CHECK-LABEL: test_collapse_5d_grouped(
38+
func.func @test_collapse_5d_grouped(%arg0: memref<8x5x6x9x2xf32>, %arg1: index) {
39+
%subview = memref.subview %arg0[0, 0, 0, 0, 0] [1, 5, 1, %arg1, 1] [1, 1, 1, 1, 1]
40+
: memref<8x5x6x9x2xf32> to memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>>
41+
%collapse_shape = memref.collapse_shape %subview [[0], [1, 2, 3, 4]]
42+
: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>> into memref<1x?xf32, strided<[540, ?]>>
43+
return
44+
}

0 commit comments

Comments
 (0)