diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp index a0a81d4add712..eb8055ea9aa79 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -283,7 +283,9 @@ class BufferDeallocation : public BufferPlacementTransformationBase { if (!dominators.dominates(definingBlock, parentBlock) || (definingBlock == parentBlock && isa(value))) { toProcess.emplace_back(value, parentBlock); - valuesToFree.insert(value); + if (isa(value.getType())) { + valuesToFree.insert(value); + } } else if (visitedValues.insert(std::make_tuple(value, definingBlock)) .second) toProcess.emplace_back(value, definingBlock); diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir index 3fbe3913c6549..7bcf32ef1441f 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir @@ -1271,6 +1271,27 @@ func.func @while_two_arg(%arg0: index) { // ----- +// CHECK-LABEL: func @while_fun +func.func @while_fun() { + %c0 = arith.constant 0 : i1 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref + %4 = scf.while (%arg1 = %c0) : (i1) -> (i1) { + scf.condition(%c0) %arg1 : i1 + } do { + ^bb0(%arg1: i1): + %7 = func.call @foo(%alloc_1, %arg1) : (memref, i1) -> (i1) + scf.yield %7#0: i1 + } + return +} + +func.func private @foo(%arg1: memref, %arg2: i1) -> (i1) { + return %arg2 : i1 +} + +// ----- + +// CHECK-LABEL: func @while_three_arg func.func @while_three_arg(%arg0: index) { // CHECK: %[[ALLOC:.*]] = memref.alloc %a = memref.alloc(%arg0) : memref