Skip to content

Commit 6512ca7

Browse files
authored
[mlir] Add isStatic* size check for ShapedTypes. NFCI. (#147085)
The motivation is to avoid having to negate `isDynamic*` checks, avoid double negations, and allow for `ShapedType::isStaticDim` to be used in ADT functions without having to wrap it in a lambda performing the negation. Also add the new functions to C and Python bindings.
1 parent 0032148 commit 6512ca7

37 files changed

+206
-118
lines changed

mlir/include/mlir-c/BuiltinTypes.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,12 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type);
289289
/// Checks whether the given shaped type has a static shape.
290290
MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type);
291291

292-
/// Checks wither the dim-th dimension of the given shaped type is dynamic.
292+
/// Checks whether the dim-th dimension of the given shaped type is dynamic.
293293
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim);
294294

295+
/// Checks whether the dim-th dimension of the given shaped type is static.
296+
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim);
297+
295298
/// Returns the dim-th dimension of the given ranked shaped type.
296299
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type,
297300
intptr_t dim);
@@ -300,17 +303,25 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type,
300303
/// in shaped types.
301304
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size);
302305

306+
/// Checks whether the given shaped type dimension value is statically-sized.
307+
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticSize(int64_t size);
308+
303309
/// Returns the value indicating a dynamic size in a shaped type. Prefer
304-
/// mlirShapedTypeIsDynamicSize to direct comparisons with this value.
310+
/// mlirShapedTypeIsDynamicSize and mlirShapedTypeIsStaticSize to direct
311+
/// comparisons with this value.
305312
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void);
306313

307314
/// Checks whether the given value is used as a placeholder for dynamic strides
308315
/// and offsets in shaped types.
309316
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val);
310317

318+
/// Checks whether the given dimension value of a stride or an offset is
319+
/// statically-sized.
320+
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val);
321+
311322
/// Returns the value indicating a dynamic stride or offset in a shaped type.
312-
/// Prefer mlirShapedTypeGetDynamicStrideOrOffset to direct comparisons with
313-
/// this value.
323+
/// Prefer mlirShapedTypeIsDynamicStrideOrOffset and
324+
/// mlirShapedTypeIsStaticStrideOrOffset to direct comparisons with this value.
314325
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void);
315326

316327
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def VectorElementTypeInterface : TypeInterface<"VectorElementTypeInterface"> {
3636
This may change in the future, for example, to require types to provide
3737
their size or alignment given a data layout. Please post an RFC before
3838
adding this interface to additional types. Implementing this interface on
39-
downstream types is discourged, until we specified the exact properties of
39+
downstream types is discouraged, until we specified the exact properties of
4040
a vector element type in more detail.
4141
}];
4242
}
@@ -221,7 +221,17 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
221221

222222
/// Whether the given shape has any size that indicates a dynamic dimension.
223223
static bool isDynamicShape(ArrayRef<int64_t> dSizes) {
224-
return any_of(dSizes, [](int64_t dSize) { return isDynamic(dSize); });
224+
return llvm::any_of(dSizes, isDynamic);
225+
}
226+
227+
/// Whether the given dimension size indicates a statically-sized dimension.
228+
static constexpr bool isStatic(int64_t dValue) {
229+
return dValue != kDynamic;
230+
}
231+
232+
/// Whether the given shape has static dimensions only.
233+
static bool isStaticShape(ArrayRef<int64_t> dSizes) {
234+
return llvm::all_of(dSizes, isStatic);
225235
}
226236

227237
/// Return the number of elements present in the given shape.
@@ -273,11 +283,18 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
273283
return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]);
274284
}
275285

286+
/// Returns true if this dimension has a static size (for ranked types);
287+
/// aborts for unranked types.
288+
bool isStaticDim(unsigned idx) const {
289+
assert(idx < getRank() && "invalid index for shaped type");
290+
return ::mlir::ShapedType::isStatic($_type.getShape()[idx]);
291+
}
292+
276293
/// Returns if this type has a static shape, i.e. if the type is ranked and
277294
/// all dimensions have known size (>= 0).
278295
bool hasStaticShape() const {
279296
return $_type.hasRank() &&
280-
!::mlir::ShapedType::isDynamicShape($_type.getShape());
297+
::mlir::ShapedType::isStaticShape($_type.getShape());
281298
}
282299

283300
/// Returns if this type has a static shape and the shape is equal to

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,15 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
544544
nb::arg("dim"),
545545
"Returns whether the dim-th dimension of the given shaped type is "
546546
"dynamic.");
547+
c.def(
548+
"is_static_dim",
549+
[](PyShapedType &self, intptr_t dim) -> bool {
550+
self.requireHasRank();
551+
return mlirShapedTypeIsStaticDim(self, dim);
552+
},
553+
nb::arg("dim"),
554+
"Returns whether the dim-th dimension of the given shaped type is "
555+
"static.");
547556
c.def(
548557
"get_dim_size",
549558
[](PyShapedType &self, intptr_t dim) {
@@ -558,6 +567,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
558567
nb::arg("dim_size"),
559568
"Returns whether the given dimension size indicates a dynamic "
560569
"dimension.");
570+
c.def_static(
571+
"is_static_size",
572+
[](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); },
573+
nb::arg("dim_size"),
574+
"Returns whether the given dimension size indicates a static "
575+
"dimension.");
561576
c.def(
562577
"is_dynamic_stride_or_offset",
563578
[](PyShapedType &self, int64_t val) -> bool {
@@ -567,6 +582,15 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
567582
nb::arg("dim_size"),
568583
"Returns whether the given value is used as a placeholder for dynamic "
569584
"strides and offsets in shaped types.");
585+
c.def(
586+
"is_static_stride_or_offset",
587+
[](PyShapedType &self, int64_t val) -> bool {
588+
self.requireHasRank();
589+
return mlirShapedTypeIsStaticStrideOrOffset(val);
590+
},
591+
nb::arg("dim_size"),
592+
"Returns whether the given shaped type stride or offset value is "
593+
"statically-sized.");
570594
c.def_prop_ro(
571595
"shape",
572596
[](PyShapedType &self) {

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,11 @@ bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
332332
.isDynamicDim(static_cast<unsigned>(dim));
333333
}
334334

335+
bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim) {
336+
return llvm::cast<ShapedType>(unwrap(type))
337+
.isStaticDim(static_cast<unsigned>(dim));
338+
}
339+
335340
int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
336341
return llvm::cast<ShapedType>(unwrap(type))
337342
.getDimSize(static_cast<unsigned>(dim));
@@ -343,10 +348,18 @@ bool mlirShapedTypeIsDynamicSize(int64_t size) {
343348
return ShapedType::isDynamic(size);
344349
}
345350

351+
bool mlirShapedTypeIsStaticSize(int64_t size) {
352+
return ShapedType::isStatic(size);
353+
}
354+
346355
bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
347356
return ShapedType::isDynamic(val);
348357
}
349358

359+
bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val) {
360+
return ShapedType::isStatic(val);
361+
}
362+
350363
int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
351364
return ShapedType::kDynamic;
352365
}

mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
5353

5454
// Extract all strides and offsets and verify they are static.
5555
auto [strides, offset] = type.getStridesAndOffset();
56-
assert(!ShapedType::isDynamic(offset) && "expected static offset");
56+
assert(ShapedType::isStatic(offset) && "expected static offset");
5757
assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
5858
"expected static strides");
5959

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
609609
if (ShapedType::isDynamic(stride))
610610
return false;
611611

612-
return !ShapedType::isDynamic(offset);
612+
return ShapedType::isStatic(offset);
613613
}
614614

615615
/// Convert a memref type to a bare pointer to the memref element type.

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
4343
namespace {
4444

4545
static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
46-
return !ShapedType::isDynamic(strideOrOffset);
46+
return ShapedType::isStatic(strideOrOffset);
4747
}
4848

4949
static FailureOr<LLVM::LLVMFuncOp>
@@ -1468,7 +1468,7 @@ struct MemRefReshapeOpLowering
14681468
Value stride = nullptr;
14691469
int64_t targetRank = targetMemRefType.getRank();
14701470
for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
1471-
if (!ShapedType::isDynamic(strides[i])) {
1471+
if (ShapedType::isStatic(strides[i])) {
14721472
// If the stride for this dimension is dynamic, then use the product
14731473
// of the sizes of the inner dimensions.
14741474
stride =
@@ -1722,7 +1722,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
17221722
ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
17231723
Type indexType) const {
17241724
assert(idx < shape.size());
1725-
if (!ShapedType::isDynamic(shape[idx]))
1725+
if (ShapedType::isStatic(shape[idx]))
17261726
return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
17271727
// Count the number of dynamic dims in range [0, idx]
17281728
unsigned nDynamic =
@@ -1738,7 +1738,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
17381738
ArrayRef<int64_t> strides, Value nextSize,
17391739
Value runningStride, unsigned idx, Type indexType) const {
17401740
assert(idx < strides.size());
1741-
if (!ShapedType::isDynamic(strides[idx]))
1741+
if (ShapedType::isStatic(strides[idx]))
17421742
return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
17431743
if (nextSize)
17441744
return runningStride

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
757757
// dimension greater than 1 with a different value is undefined behavior.
758758
for (auto operand : operands) {
759759
auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
760-
if (!ShapedType::isDynamic(size) && size > 1)
760+
if (ShapedType::isStatic(size) && size > 1)
761761
return {rewriter.getIndexAttr(size), operand};
762762
}
763763

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ TensorType inferReshapeExpandedType(TensorType inputType,
8383
return totalSize / totalSizeNoPlaceholder;
8484
});
8585

86-
bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);
86+
bool resultIsStatic = ShapedType::isStaticShape(resultShape);
8787

8888
// A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
8989
// shaped input from being reshaped into a statically shaped result. We may
@@ -305,7 +305,7 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
305305
int64_t size = i.value();
306306
size_t index = i.index();
307307
sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
308-
if (!ShapedType::isDynamic(sizes.back()))
308+
if (ShapedType::isStatic(sizes.back()))
309309
continue;
310310

311311
auto dim = rewriter.create<tensor::DimOp>(loc, input, index);

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
4444
failed(target.getStridesAndOffset(targetStrides, targetOffset)))
4545
return false;
4646
auto dynamicToStatic = [](int64_t a, int64_t b) {
47-
return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
47+
return ShapedType::isDynamic(a) && ShapedType::isStatic(b);
4848
};
4949
if (dynamicToStatic(sourceOffset, targetOffset))
5050
return false;

0 commit comments

Comments
 (0)