Skip to content

Commit d8fcf82

Browse files
bjacoblialan
authored andcommitted
Revert "[mlir][bufferization] Use Type instead of Value in unknown conversion (llvm#144658)"
This reverts commit a1c2a71.
1 parent c6fef1e commit d8fcf82

File tree

4 files changed

+19
-19
lines changed

4 files changed

+19
-19
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ struct BufferizationOptions {
265265
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
266266
func::FuncOp, const BufferizationOptions &)>;
267267
/// Tensor -> MemRef type converter.
268-
/// Parameters: tensor type, memory space, bufferization options
268+
/// Parameters: Value, memory space, bufferization options
269269
using UnknownTypeConverterFn = std::function<BaseMemRefType(
270-
TensorType, Attribute memorySpace, const BufferizationOptions &)>;
270+
Value, Attribute memorySpace, const BufferizationOptions &)>;
271271
// Produce a MemorySpace attribute from a tensor type
272272
using DefaultMemorySpaceFn =
273273
std::function<std::optional<Attribute>(TensorType t)>;
@@ -655,7 +655,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
655655
return newOp;
656656
}
657657

658-
/// Return a MemRefType to which the TensorType can be bufferized.
658+
/// Return a MemRefType to which the type of the given value can be bufferized.
659659
///
660660
/// If possible, op bufferization implementations should not use this function
661661
/// and instead infer precise memref types for tensor results by themselves.
@@ -667,8 +667,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
667667
/// Note: Canonicalization patterns could clean up layout maps and infer more
668668
/// precise layout maps after bufferization. However, many possible
669669
/// canonicalizations are currently not implemented.
670-
BaseMemRefType getMemRefType(TensorType tensorType,
671-
const BufferizationOptions &options,
670+
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
672671
MemRefLayoutAttrInterface layout = {},
673672
Attribute memorySpace = nullptr);
674673

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,10 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
347347
}
348348
/// Default unknown type converter: Use a fully dynamic layout map.
349349
BaseMemRefType
350-
defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
350+
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
351351
const BufferizationOptions &options) {
352-
return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
352+
return getMemRefTypeWithFullyDynamicLayout(
353+
llvm::cast<TensorType>(value.getType()), memorySpace);
353354
}
354355

355356
} // namespace
@@ -725,8 +726,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
725726
if (!memSpace.has_value())
726727
return op->emitError("could not infer memory space");
727728

728-
return getMemRefType(cast<TensorType>(value.getType()), options,
729-
/*layout=*/{}, *memSpace);
729+
return getMemRefType(value, options, /*layout=*/{}, *memSpace);
730730
}
731731

732732
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -799,10 +799,12 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
799799
// Bufferization-specific IRMapping support with debugging.
800800
//===----------------------------------------------------------------------===//
801801

802-
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
802+
BaseMemRefType bufferization::getMemRefType(Value value,
803803
const BufferizationOptions &options,
804804
MemRefLayoutAttrInterface layout,
805805
Attribute memorySpace) {
806+
auto tensorType = llvm::cast<TensorType>(value.getType());
807+
806808
// Case 1: Unranked memref type.
807809
if (auto unrankedTensorType =
808810
llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +821,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
819821
memorySpace);
820822
}
821823

822-
return options.unknownTypeConverterFn(tensorType, memorySpace, options);
824+
return options.unknownTypeConverterFn(value, memorySpace, options);
823825
}
824826

825827
BaseMemRefType
@@ -955,11 +957,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
955957
const BufferizationState &bufferizationState,
956958
SmallVector<Value> &invocationStack) {
957959
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
958-
auto tensorType = cast<TensorType>(value.getType());
959960

960961
// No further analysis is possible for a block argument.
961962
if (llvm::isa<BlockArgument>(value))
962-
return bufferization::getMemRefType(tensorType, options);
963+
return bufferization::getMemRefType(value, options);
963964

964965
// Value is an OpResult.
965966
Operation *op = getOwnerOfValue(value);
@@ -982,7 +983,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
982983
if (!memSpace.has_value())
983984
return op->emitError("could not infer memory space");
984985

985-
return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
986+
return getMemRefType(value, options, /*layout=*/{}, *memSpace);
986987
}
987988

988989
bool bufferization::detail::defaultIsRepetitiveRegion(

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ struct OneShotBufferizePass
109109
"'unknown-type-conversion'");
110110
return signalPassFailure();
111111
}
112-
opt.unknownTypeConverterFn = [=](TensorType tensorType,
113-
Attribute memorySpace,
112+
opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
114113
const BufferizationOptions &options) {
114+
auto tensorType = cast<TensorType>(value.getType());
115115
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
116116
return bufferization::getMemRefTypeWithStaticIdentityLayout(
117117
tensorType, memorySpace);

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,10 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
223223
OneShotBufferizationOptions options;
224224
options.bufferizeFunctionBoundaries = true;
225225
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
226-
options.unknownTypeConverterFn = [](TensorType tensorType,
227-
Attribute memorySpace,
226+
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
228227
const BufferizationOptions &options) {
229-
return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
228+
return getMemRefTypeWithStaticIdentityLayout(
229+
cast<TensorType>(value.getType()), memorySpace);
230230
};
231231
if (analysisOnly) {
232232
options.testAnalysisOnly = true;

0 commit comments

Comments
 (0)