Skip to content

Commit fdb9e6a

Browse files
[mlir][bufferization] Fix crash in EmptyTensorElimination
Differential Revision: https://reviews.llvm.org/D144389
1 parent: 2502dc8 · commit: fdb9e6a

File tree

3 files changed

+20
-4
lines changed

3 files changed

+20
-4
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -482,11 +482,17 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
482482

483483
while (!workingSet.empty()) {
484484
Value value = workingSet.pop_back_val();
485-
if (condition(value) || value.isa<BlockArgument>()) {
485+
if (condition(value)) {
486486
result.insert(value);
487487
continue;
488488
}
489489

490+
if (value.isa<BlockArgument>()) {
491+
if (alwaysIncludeLeaves)
492+
result.insert(value);
493+
continue;
494+
}
495+
490496
OpResult opResult = value.cast<OpResult>();
491497
BufferizableOpInterface bufferizableOp =
492498
options.dynCastBufferizableOp(opResult.getDefiningOp());

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -220,4 +220,14 @@ func.func @regression_do_not_eliminate_non_empty(
220220
%2 = tensor.insert_slice %1 into %t2[1] [5] [1]
221221
: tensor<5xf32> into tensor<10xf32>
222222
return %2 : tensor<10xf32>
223-
}
223+
}
224+
225+
// -----
226+
227+
// This is a regression test. Make sure that there is no crash.
228+
229+
// CHECK-LABEL: func.func @regression_insert_of_bbarg(
230+
func.func @regression_insert_of_bbarg(%t0: tensor<5xf32>, %t1: tensor<10xf32>) -> tensor<10xf32> {
231+
%0 = tensor.insert_slice %t0 into %t1 [2] [5] [1] : tensor<5xf32> into tensor<10xf32>
232+
return %0 : tensor<10xf32>
233+
}

mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -707,11 +707,11 @@ func.func @read_of_bbarg_in_repetitive_region(
707707
scf.for %iv = %a to %b step %c {
708708
// Must bufferize out-of-place because definition of read is in a different
709709
// repetitive region.
710-
// CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["false"]}
710+
// CHECK: tensor.extract_slice {{.*}} {__inplace_operands_attr__ = ["true"]}
711711
%2 = tensor.extract_slice %t[0][4][1] : tensor<10xf32> to tensor<4xf32>
712712
%3 = tensor.extract %2[%a] : tensor<4xf32>
713713
vector.print %3 : f32
714-
// CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]}
714+
// CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "false", "none"]}
715715
%4 = tensor.insert %cst into %2[%a] : tensor<4xf32>
716716
%5 = tensor.extract %4[%a] : tensor<4xf32>
717717
vector.print %5 : f32

0 commit comments

Comments
 (0)