diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp index d25ddb41aa4eb..020d08ddda408 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp @@ -81,10 +81,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern { for (auto &&[opOffset, sourceOffset, sourceStride, opSize] : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(), sourceOp.getMixedStrides(), op.getMixedSizes())) { - // We only support static sizes. - if (isa(opSize)) { - return failure(); - } sizes.push_back(opSize); Attribute opOffsetAttr = llvm::dyn_cast_if_present(opOffset), sourceOffsetAttr = diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir index 53fbb8a356def..ec82562cd24f8 100644 --- a/mlir/test/Transforms/compose-subview.mlir +++ b/mlir/test/Transforms/compose-subview.mlir @@ -101,3 +101,29 @@ func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strid %1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>> return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>> } + +// ----- + +// CHECK-LABEL: func.func @single_dynamic_size_subview( +// CHECK-SAME: %[[SRC:.*]]: memref<256x?xf32, strided<[?, ?], offset: ?>>, +// CHECK-SAME: %{{.*}}: index, +// CHECK-SAME: %[[SIZE:.*]]: index) -> memref<8x?xf32, strided<[?, ?], offset: ?>> { +func.func @single_dynamic_size_subview(%arg0: memref<256x?xf32, strided<[?, ?], offset: ?>>, %arg1 : index, %arg2 : index) -> memref<8x?xf32, strided<[?, ?], offset: ?>>{ + %subview = memref.subview %arg0[0, 0][8, %arg1][1, 1] : memref<256x?xf32, strided<[?, ?], offset: ?>> to memref<8x?xf32, strided<[?, ?], offset: ?>> + %subview_1 = memref.subview %subview[0, 0][8, %arg2][1, 1] : memref<8x?xf32, strided<[?, ?], offset: ?>> to memref<8x?xf32, strided<[?, ?], offset: ?>> + // CHECK: %{{.*}} = memref.subview %[[SRC]][0, 0] [8, %[[SIZE]]] [1, 1] : memref<256x?xf32, strided<[?, ?], offset: ?>> to memref<8x?xf32, strided<[?, ?], offset: ?>> + return %subview_1 : memref<8x?xf32, strided<[?, ?], offset: ?>> +} + +// ----- + +// CHECK-LABEL: func.func @all_dynamic_size_subview( +// CHECK-SAME: %[[SRC:.*]]: memref<256x?xf32, strided<[?, ?], offset: ?>>, +// CHECK-SAME: %{{.*}}: index, +// CHECK-SAME: %[[SIZE:.*]]: index) -> memref> { +func.func @all_dynamic_size_subview(%arg0: memref<256x?xf32, strided<[?, ?], offset: ?>>, %arg1 : index, %arg2 : index) -> memref>{ + %subview = memref.subview %arg0[0, 0][%arg1, %arg1][1, 1] : memref<256x?xf32, strided<[?, ?], offset: ?>> to memref> + %subview_1 = memref.subview %subview[0, 0][%arg2, %arg2][1, 1] : memref> to memref> + // CHECK: {{.*}} = memref.subview %[[SRC]][0, 0] {{\[}}%[[SIZE]], %[[SIZE]]] [1, 1] : memref<256x?xf32, strided<[?, ?], offset: ?>> to memref> + return %subview_1 : memref> +}