Skip to content

Commit 6c867e2

Browse files
[mlir] Use getSingleElement/hasSingleElement in various places (llvm#131460)
This is a code cleanup. Update a few places in MLIR that should use `hasSingleElement`/`getSingleElement`. Note: `hasSingleElement` is faster than `.getSize() == 1` when it is used with linked lists etc. Depends on llvm#131508.
1 parent f402953 commit 6c867e2

File tree

22 files changed

+48
-91
lines changed

22 files changed

+48
-91
lines changed

mlir/include/mlir/Dialect/CommonFolders.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,7 @@ template <class AttrElementT,
196196
function_ref<std::optional<ElementValueT>(ElementValueT)>>
197197
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
198198
CalculationT &&calculate) {
199-
assert(operands.size() == 1 && "unary op takes one operands");
200-
if (!operands[0])
199+
if (!llvm::getSingleElement(operands))
201200
return {};
202201

203202
static_assert(
@@ -268,8 +267,7 @@ template <
268267
class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
269268
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
270269
CalculationT &&calculate) {
271-
assert(operands.size() == 1 && "Cast op takes one operand");
272-
if (!operands[0])
270+
if (!llvm::getSingleElement(operands))
273271
return {};
274272

275273
static_assert(

mlir/lib/Analysis/SliceAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ static void getBackwardSliceImpl(Operation *op,
107107
// into us. For now, just bail.
108108
if (parentOp && backwardSlice->count(parentOp) == 0) {
109109
assert(parentOp->getNumRegions() == 1 &&
110-
parentOp->getRegion(0).getBlocks().size() == 1);
110+
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
111111
getBackwardSliceImpl(parentOp, backwardSlice, options);
112112
}
113113
} else {

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -834,8 +834,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
834834
LogicalResult
835835
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
836836
ConversionPatternRewriter &rewriter) const override {
837-
assert(adaptor.getOperands().size() == 1);
838-
Type srcType = adaptor.getOperands().front().getType();
837+
Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
839838
Type dstType = this->getTypeConverter()->convertType(op.getType());
840839
if (!dstType)
841840
return getTypeConversionFailure(rewriter, op);

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,7 @@ struct WmmaConstantOpToSPIRVLowering final
101101
LogicalResult
102102
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
103103
ConversionPatternRewriter &rewriter) const override {
104-
assert(adaptor.getOperands().size() == 1);
105-
Value cst = adaptor.getOperands().front();
104+
Value cst = llvm::getSingleElement(adaptor.getOperands());
106105
auto coopType = getTypeConverter()->convertType(op.getType());
107106
if (!coopType)
108107
return rewriter.notifyMatchFailure(op, "type conversion failed");
@@ -181,8 +180,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
181180
"splat is not a composite construct");
182181
}
183182

184-
assert(cc.getConstituents().size() == 1);
185-
scalar = cc.getConstituents().front();
183+
scalar = llvm::getSingleElement(cc.getConstituents());
186184

187185
auto coopType = getTypeConverter()->convertType(op.getType());
188186
if (!coopType)

mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,13 +419,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
419419
SmallVector<Value> dynDims, dynDevice;
420420
for (auto dim : adaptor.getDimsDynamic()) {
421421
// type conversion should be 1:1 for ints
422-
assert(dim.size() == 1);
423-
dynDims.emplace_back(dim[0]);
422+
dynDims.emplace_back(llvm::getSingleElement(dim));
424423
}
425424
// same for device
426425
for (auto device : adaptor.getDeviceDynamic()) {
427-
assert(device.size() == 1);
428-
dynDevice.emplace_back(device[0]);
426+
dynDevice.emplace_back(llvm::getSingleElement(device));
429427
}
430428

431429
// To keep the code simple, convert dims/device to values when they are

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,8 +1236,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
12361236
}
12371237

12381238
applyOp->erase();
1239-
assert(foldResults.size() == 1 && "expected 1 folded result");
1240-
return foldResults.front();
1239+
return llvm::getSingleElement(foldResults);
12411240
}
12421241

12431242
OpFoldResult
@@ -1306,8 +1305,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
13061305
}
13071306

13081307
minMaxOp->erase();
1309-
assert(foldResults.size() == 1 && "expected 1 folded result");
1310-
return foldResults.front();
1308+
return llvm::getSingleElement(foldResults);
13111309
}
13121310

13131311
OpFoldResult

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,8 +1249,7 @@ struct GreedyFusion {
12491249
SmallVector<Operation *, 2> sibLoadOpInsts;
12501250
sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
12511251
// Currently findSiblingNodeToFuse searches for siblings with one load.
1252-
assert(sibLoadOpInsts.size() == 1);
1253-
Operation *sibLoadOpInst = sibLoadOpInsts[0];
1252+
Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);
12541253

12551254
// Gather 'dstNode' load ops to 'memref'.
12561255
SmallVector<Operation *, 2> dstLoadOpInsts;

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,10 +1604,8 @@ SmallVector<AffineForOp, 8> mlir::affine::tile(ArrayRef<AffineForOp> forOps,
16041604
ArrayRef<uint64_t> sizes,
16051605
AffineForOp target) {
16061606
SmallVector<AffineForOp, 8> res;
1607-
for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
1608-
assert(loops.size() == 1);
1609-
res.push_back(loops[0]);
1610-
}
1607+
for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target)))
1608+
res.push_back(llvm::getSingleElement(loops));
16111609
return res;
16121610
}
16131611

mlir/lib/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,30 +44,27 @@ struct LinalgCopyOpInterface
4444
linalg::CopyOp> {
4545
OpOperand &getSourceOperand(Operation *op) const {
4646
auto copyOp = cast<CopyOp>(op);
47-
assert(copyOp.getInputs().size() == 1 && "expected single input");
48-
return copyOp.getInputsMutable()[0];
47+
return llvm::getSingleElement(copyOp.getInputsMutable());
4948
}
5049

5150
bool
5251
isEquivalentSubset(Operation *op, Value candidate,
5352
function_ref<bool(Value, Value)> equivalenceFn) const {
5453
auto copyOp = cast<CopyOp>(op);
55-
assert(copyOp.getOutputs().size() == 1 && "expected single output");
56-
return equivalenceFn(candidate, copyOp.getOutputs()[0]);
54+
return equivalenceFn(candidate,
55+
llvm::getSingleElement(copyOp.getOutputs()));
5756
}
5857

5958
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
6059
Location loc) const {
6160
auto copyOp = cast<CopyOp>(op);
62-
assert(copyOp.getOutputs().size() == 1 && "expected single output");
63-
return copyOp.getOutputs()[0];
61+
return llvm::getSingleElement(copyOp.getOutputs());
6462
}
6563

6664
SmallVector<Value>
6765
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
6866
auto copyOp = cast<CopyOp>(op);
69-
assert(copyOp.getOutputs().size() == 1 && "expected single output");
70-
return {copyOp.getOutputs()[0]};
67+
return {llvm::getSingleElement(copyOp.getOutputs())};
7168
}
7269
};
7370
} // namespace

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,7 +471,7 @@ static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
471471
/// extending the lifetime of allocations.
472472
static bool lastNonTerminatorInRegion(Operation *op) {
473473
return op->getNextNode() == op->getBlock()->getTerminator() &&
474-
op->getParentRegion()->getBlocks().size() == 1;
474+
llvm::hasSingleElement(op->getParentRegion()->getBlocks());
475475
}
476476

477477
/// Inline an AllocaScopeOp if either the direct parent is an allocation scope

0 commit comments

Comments
 (0)