@@ -347,9 +347,10 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
347
347
}
348
348
// / Default unknown type converter: Use a fully dynamic layout map.
349
349
BaseMemRefType
350
- defaultUnknownTypeConverter (TensorType tensorType , Attribute memorySpace,
350
+ defaultUnknownTypeConverter (Value value , Attribute memorySpace,
351
351
const BufferizationOptions &options) {
352
- return getMemRefTypeWithFullyDynamicLayout (tensorType, memorySpace);
352
+ return getMemRefTypeWithFullyDynamicLayout (
353
+ llvm::cast<TensorType>(value.getType ()), memorySpace);
353
354
}
354
355
355
356
} // namespace
@@ -725,8 +726,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
725
726
if (!memSpace.has_value ())
726
727
return op->emitError (" could not infer memory space" );
727
728
728
- return getMemRefType (cast<TensorType>(value.getType ()), options,
729
- /* layout=*/ {}, *memSpace);
729
+ return getMemRefType (value, options, /* layout=*/ {}, *memSpace);
730
730
}
731
731
732
732
bool bufferization::hasTensorSemantics (Operation *op) {
@@ -799,10 +799,12 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
799
799
// Bufferization-specific IRMapping support with debugging.
800
800
// ===----------------------------------------------------------------------===//
801
801
802
- BaseMemRefType bufferization::getMemRefType (TensorType tensorType ,
802
+ BaseMemRefType bufferization::getMemRefType (Value value ,
803
803
const BufferizationOptions &options,
804
804
MemRefLayoutAttrInterface layout,
805
805
Attribute memorySpace) {
806
+ auto tensorType = llvm::cast<TensorType>(value.getType ());
807
+
806
808
// Case 1: Unranked memref type.
807
809
if (auto unrankedTensorType =
808
810
llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +821,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
819
821
memorySpace);
820
822
}
821
823
822
- return options.unknownTypeConverterFn (tensorType , memorySpace, options);
824
+ return options.unknownTypeConverterFn (value , memorySpace, options);
823
825
}
824
826
825
827
BaseMemRefType
@@ -955,11 +957,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
955
957
const BufferizationState &bufferizationState,
956
958
SmallVector<Value> &invocationStack) {
957
959
assert (llvm::isa<TensorType>(value.getType ()) && " expected tensor type" );
958
- auto tensorType = cast<TensorType>(value.getType ());
959
960
960
961
// No further analysis is possible for a block argument.
961
962
if (llvm::isa<BlockArgument>(value))
962
- return bufferization::getMemRefType (tensorType , options);
963
+ return bufferization::getMemRefType (value , options);
963
964
964
965
// Value is an OpResult.
965
966
Operation *op = getOwnerOfValue (value);
@@ -982,7 +983,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
982
983
if (!memSpace.has_value ())
983
984
return op->emitError (" could not infer memory space" );
984
985
985
- return getMemRefType (tensorType , options, /* layout=*/ {}, *memSpace);
986
+ return getMemRefType (value , options, /* layout=*/ {}, *memSpace);
986
987
}
987
988
988
989
bool bufferization::detail::defaultIsRepetitiveRegion (
0 commit comments