From 1ca7c9954553ee30e9e4b824f9c6bd90bd1593fe Mon Sep 17 00:00:00 2001 From: Hocky Yudhiono Date: Thu, 10 Jul 2025 21:48:52 +0800 Subject: [PATCH] [mlir][memref] Add better computeCollapsedLayoutMap support for unit collapse --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 6 ++- .../test/Dialect/MemRef/collapse-strided.mlir | 42 +++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/MemRef/collapse-strided.mlir diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index d1a9920aa66c5..ac8451ba0c45c 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2422,7 +2422,11 @@ computeCollapsedLayoutMap(MemRefType srcType, ArrayRef ref = llvm::ArrayRef(reassoc); while (srcShape[ref.back()] == 1 && ref.size() > 1) ref = ref.drop_back(); - if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) { + auto precedingRef = ref.drop_back(); + bool allUnitPreceding = llvm::all_of( + precedingRef, [&srcShape](int idx) { return srcShape[idx] == 1; }); + if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1 || + allUnitPreceding) { resultStrides.push_back(srcStrides[ref.back()]); } else { // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so diff --git a/mlir/test/Dialect/MemRef/collapse-strided.mlir b/mlir/test/Dialect/MemRef/collapse-strided.mlir new file mode 100644 index 0000000000000..c6c624aba7f2a --- /dev/null +++ b/mlir/test/Dialect/MemRef/collapse-strided.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s | FileCheck %s + +// CHECK-LABEL: test_collapse( +func.func @test_collapse(%arg0: memref<1x?xf32, strided<[5, 1]>>) { + %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] : memref<1x?xf32, strided<[5, 1]>> into memref> + return +} + +// CHECK-LABEL: test_collapse_5d_middle_dynamic( +func.func @test_collapse_5d_middle_dynamic(%arg0: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>>) { + %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3, 4]] + : memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>> into memref> + return +} + +// CHECK-LABEL: test_collapse_5d_mostly_units( +func.func @test_collapse_5d_mostly_units(%arg0: memref<1x1x1x?x1xf32, strided<[320, 80, 16, 2, 1]>>) { + %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3, 4]] + : memref<1x1x1x?x1xf32, strided<[320, 80, 16, 2, 1]>> into memref> + return +} + +// CHECK-LABEL: test_partial_collapse_6d( +func.func @test_partial_collapse_6d(%arg0: memref<1x?x1x1x5x1xf32, strided<[3360, 420, 140, 35, 7, 1]>>) { + %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3], [4, 5]] + : memref<1x?x1x1x5x1xf32, strided<[3360, 420, 140, 35, 7, 1]>> into memref> + return +} + +// CHECK-LABEL: test_collapse_5d_grouped( +func.func @test_collapse_5d_grouped(%arg0: memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>>) { + %collapse_shape = memref.collapse_shape %arg0 [[0], [1, 2, 3, 4]] + : memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>> into memref<1x?xf32, strided<[540, ?]>> + return +} + +// CHECK-LABEL: test_collapse_all_units( +func.func @test_collapse_all_units(%arg0: memref<1x1x1x1x1xf32, strided<[100, 50, 25, 10, 1]>>) { + %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3, 4]] + : memref<1x1x1x1x1xf32, strided<[100, 50, 25, 10, 1]>> into memref<1xf32, strided<[100]>> + return +} \ No newline at end of file