Skip to content

Commit db20cd3

Browse files
authored
Merge branch 'llvm:main' into main
2 parents 4100dfe + 3eb9e77 commit db20cd3

File tree

9 files changed

+953
-13
lines changed

9 files changed

+953
-13
lines changed

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ struct MissingFeatures {
120120
static bool opUnaryPromotionType() { return false; }
121121

122122
// SwitchOp handling
123-
static bool foldCascadingCases() { return false; }
124123
static bool foldRangeCase() { return false; }
125124

126125
// Clang early optimizations or things defered to LLVM lowering.

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -533,12 +533,6 @@ mlir::LogicalResult CIRGenFunction::emitCaseStmt(const CaseStmt &s,
533533
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal),
534534
cir::IntAttr::get(condType, endVal)});
535535
kind = cir::CaseOpKind::Range;
536-
537-
// We don't currently fold case range statements with other case statements.
538-
// TODO(cir): Add this capability. Folding these cases is going to be
539-
// implemented in CIRSimplify when it is upstreamed.
540-
assert(!cir::MissingFeatures::foldRangeCase());
541-
assert(!cir::MissingFeatures::foldCascadingCases());
542536
} else {
543537
value = builder.getArrayAttr({cir::IntAttr::get(condType, intVal)});
544538
kind = cir::CaseOpKind::Equal;

clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp

Lines changed: 104 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
159159
}
160160
};
161161

162+
/// Simplify `cir.switch` operations by folding cascading cases
163+
/// into a single `cir.case` with the `anyof` kind.
164+
///
165+
/// This pattern identifies cascading cases within a `cir.switch` operation.
166+
/// Cascading cases are defined as consecutive `cir.case` operations of kind
167+
/// `equal`, each containing a single `cir.yield` operation in their body.
168+
///
169+
/// The pattern merges these cascading cases into a single `cir.case` operation
170+
/// with kind `anyof`, aggregating all the case values.
171+
///
172+
/// The merging process continues until a `cir.case` with a different body
173+
/// (e.g., containing `cir.break` or compound stmt) is encountered, which
174+
/// breaks the chain.
175+
///
176+
/// Example:
177+
///
178+
/// Before:
179+
/// cir.case equal, [#cir.int<0> : !s32i] {
180+
/// cir.yield
181+
/// }
182+
/// cir.case equal, [#cir.int<1> : !s32i] {
183+
/// cir.yield
184+
/// }
185+
/// cir.case equal, [#cir.int<2> : !s32i] {
186+
/// cir.break
187+
/// }
188+
///
189+
/// After applying SimplifySwitch:
190+
/// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
191+
/// !s32i] {
192+
/// cir.break
193+
/// }
194+
struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
195+
using OpRewritePattern<SwitchOp>::OpRewritePattern;
196+
LogicalResult matchAndRewrite(SwitchOp op,
197+
PatternRewriter &rewriter) const override {
198+
199+
LogicalResult changed = mlir::failure();
200+
SmallVector<CaseOp, 8> cases;
201+
SmallVector<CaseOp, 4> cascadingCases;
202+
SmallVector<mlir::Attribute, 4> cascadingCaseValues;
203+
204+
op.collectCases(cases);
205+
if (cases.empty())
206+
return mlir::failure();
207+
208+
auto flushMergedOps = [&]() {
209+
for (CaseOp &c : cascadingCases)
210+
rewriter.eraseOp(c);
211+
cascadingCases.clear();
212+
cascadingCaseValues.clear();
213+
};
214+
215+
auto mergeCascadingInto = [&](CaseOp &target) {
216+
rewriter.modifyOpInPlace(target, [&]() {
217+
target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
218+
target.setKind(CaseOpKind::Anyof);
219+
});
220+
changed = mlir::success();
221+
};
222+
223+
for (CaseOp c : cases) {
224+
cir::CaseOpKind kind = c.getKind();
225+
if (kind == cir::CaseOpKind::Equal &&
226+
isa<YieldOp>(c.getCaseRegion().front().front())) {
227+
// If the case contains only a YieldOp, collect it for cascading merge
228+
cascadingCases.push_back(c);
229+
cascadingCaseValues.push_back(c.getValue()[0]);
230+
} else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
231+
// merge previously collected cascading cases
232+
cascadingCaseValues.push_back(c.getValue()[0]);
233+
mergeCascadingInto(c);
234+
flushMergedOps();
235+
} else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
236+
// If a Default, Anyof or Range case is found and there are previous
237+
// cascading cases, merge all of them into the last cascading case.
238+
// We don't currently fold case range statements with other case
239+
// statements.
240+
assert(!cir::MissingFeatures::foldRangeCase());
241+
CaseOp lastCascadingCase = cascadingCases.back();
242+
mergeCascadingInto(lastCascadingCase);
243+
cascadingCases.pop_back();
244+
flushMergedOps();
245+
} else {
246+
cascadingCases.clear();
247+
cascadingCaseValues.clear();
248+
}
249+
}
250+
251+
// Edge case: all cases are simple cascading cases
252+
if (cascadingCases.size() == cases.size()) {
253+
CaseOp lastCascadingCase = cascadingCases.back();
254+
mergeCascadingInto(lastCascadingCase);
255+
cascadingCases.pop_back();
256+
flushMergedOps();
257+
}
258+
259+
return changed;
260+
}
261+
};
262+
162263
//===----------------------------------------------------------------------===//
163264
// CIRSimplifyPass
164265
//===----------------------------------------------------------------------===//
@@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
173274
// clang-format off
174275
patterns.add<
175276
SimplifyTernary,
176-
SimplifySelect
277+
SimplifySelect,
278+
SimplifySwitch
177279
>(patterns.getContext());
178280
// clang-format on
179281
}
@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() {
186288
// Collect operations to apply patterns.
187289
llvm::SmallVector<Operation *, 16> ops;
188290
getOperation()->walk([&](Operation *op) {
189-
if (isa<TernaryOp, SelectOp>(op))
291+
if (isa<TernaryOp, SelectOp, SwitchOp>(op))
190292
ops.push_back(op);
191293
});
192294

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,33 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
10651065
return mlir::success();
10661066
}
10671067

1068+
mlir::LogicalResult CIRToLLVMSwitchFlatOpLowering::matchAndRewrite(
1069+
cir::SwitchFlatOp op, OpAdaptor adaptor,
1070+
mlir::ConversionPatternRewriter &rewriter) const {
1071+
1072+
llvm::SmallVector<mlir::APInt, 8> caseValues;
1073+
for (mlir::Attribute val : op.getCaseValues()) {
1074+
auto intAttr = cast<cir::IntAttr>(val);
1075+
caseValues.push_back(intAttr.getValue());
1076+
}
1077+
1078+
llvm::SmallVector<mlir::Block *, 8> caseDestinations;
1079+
llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
1080+
1081+
for (mlir::Block *x : op.getCaseDestinations())
1082+
caseDestinations.push_back(x);
1083+
1084+
for (mlir::OperandRange x : op.getCaseOperands())
1085+
caseOperands.push_back(x);
1086+
1087+
// Set switch op to branch to the newly created blocks.
1088+
rewriter.setInsertionPoint(op);
1089+
rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
1090+
op, adaptor.getCondition(), op.getDefaultDestination(),
1091+
op.getDefaultOperands(), caseValues, caseDestinations, caseOperands);
1092+
return mlir::success();
1093+
}
1094+
10681095
mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
10691096
cir::UnaryOp op, OpAdaptor adaptor,
10701097
mlir::ConversionPatternRewriter &rewriter) const {
@@ -1681,6 +1708,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
16811708
CIRToLLVMGetGlobalOpLowering,
16821709
CIRToLLVMGetMemberOpLowering,
16831710
CIRToLLVMSelectOpLowering,
1711+
CIRToLLVMSwitchFlatOpLowering,
16841712
CIRToLLVMShiftOpLowering,
16851713
CIRToLLVMStackSaveOpLowering,
16861714
CIRToLLVMStackRestoreOpLowering,

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,16 @@ class CIRToLLVMFuncOpLowering : public mlir::OpConversionPattern<cir::FuncOp> {
149149
mlir::ConversionPatternRewriter &) const override;
150150
};
151151

152+
class CIRToLLVMSwitchFlatOpLowering
153+
: public mlir::OpConversionPattern<cir::SwitchFlatOp> {
154+
public:
155+
using mlir::OpConversionPattern<cir::SwitchFlatOp>::OpConversionPattern;
156+
157+
mlir::LogicalResult
158+
matchAndRewrite(cir::SwitchFlatOp op, OpAdaptor,
159+
mlir::ConversionPatternRewriter &) const override;
160+
};
161+
152162
class CIRToLLVMGetGlobalOpLowering
153163
: public mlir::OpConversionPattern<cir::GetGlobalOp> {
154164
public:

0 commit comments

Comments
 (0)