@@ -159,6 +159,107 @@ struct SimplifySelect : public OpRewritePattern<SelectOp> {
159
159
}
160
160
};
161
161
162
+ // / Simplify `cir.switch` operations by folding cascading cases
163
+ // / into a single `cir.case` with the `anyof` kind.
164
+ // /
165
+ // / This pattern identifies cascading cases within a `cir.switch` operation.
166
+ // / Cascading cases are defined as consecutive `cir.case` operations of kind
167
+ // / `equal`, each containing a single `cir.yield` operation in their body.
168
+ // /
169
+ // / The pattern merges these cascading cases into a single `cir.case` operation
170
+ // / with kind `anyof`, aggregating all the case values.
171
+ // /
172
+ // / The merging process continues until a `cir.case` with a different body
173
+ // / (e.g., containing `cir.break` or compound stmt) is encountered, which
174
+ // / breaks the chain.
175
+ // /
176
+ // / Example:
177
+ // /
178
+ // / Before:
179
+ // / cir.case equal, [#cir.int<0> : !s32i] {
180
+ // / cir.yield
181
+ // / }
182
+ // / cir.case equal, [#cir.int<1> : !s32i] {
183
+ // / cir.yield
184
+ // / }
185
+ // / cir.case equal, [#cir.int<2> : !s32i] {
186
+ // / cir.break
187
+ // / }
188
+ // /
189
+ // / After applying SimplifySwitch:
190
+ // / cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
191
+ // / !s32i] {
192
+ // / cir.break
193
+ // / }
194
+ struct SimplifySwitch : public OpRewritePattern <SwitchOp> {
195
+ using OpRewritePattern<SwitchOp>::OpRewritePattern;
196
+ LogicalResult matchAndRewrite (SwitchOp op,
197
+ PatternRewriter &rewriter) const override {
198
+
199
+ LogicalResult changed = mlir::failure ();
200
+ SmallVector<CaseOp, 8 > cases;
201
+ SmallVector<CaseOp, 4 > cascadingCases;
202
+ SmallVector<mlir::Attribute, 4 > cascadingCaseValues;
203
+
204
+ op.collectCases (cases);
205
+ if (cases.empty ())
206
+ return mlir::failure ();
207
+
208
+ auto flushMergedOps = [&]() {
209
+ for (CaseOp &c : cascadingCases)
210
+ rewriter.eraseOp (c);
211
+ cascadingCases.clear ();
212
+ cascadingCaseValues.clear ();
213
+ };
214
+
215
+ auto mergeCascadingInto = [&](CaseOp &target) {
216
+ rewriter.modifyOpInPlace (target, [&]() {
217
+ target.setValueAttr (rewriter.getArrayAttr (cascadingCaseValues));
218
+ target.setKind (CaseOpKind::Anyof);
219
+ });
220
+ changed = mlir::success ();
221
+ };
222
+
223
+ for (CaseOp c : cases) {
224
+ cir::CaseOpKind kind = c.getKind ();
225
+ if (kind == cir::CaseOpKind::Equal &&
226
+ isa<YieldOp>(c.getCaseRegion ().front ().front ())) {
227
+ // If the case contains only a YieldOp, collect it for cascading merge
228
+ cascadingCases.push_back (c);
229
+ cascadingCaseValues.push_back (c.getValue ()[0 ]);
230
+ } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty ()) {
231
+ // merge previously collected cascading cases
232
+ cascadingCaseValues.push_back (c.getValue ()[0 ]);
233
+ mergeCascadingInto (c);
234
+ flushMergedOps ();
235
+ } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size () > 1 ) {
236
+ // If a Default, Anyof or Range case is found and there are previous
237
+ // cascading cases, merge all of them into the last cascading case.
238
+ // We don't currently fold case range statements with other case
239
+ // statements.
240
+ assert (!cir::MissingFeatures::foldRangeCase ());
241
+ CaseOp lastCascadingCase = cascadingCases.back ();
242
+ mergeCascadingInto (lastCascadingCase);
243
+ cascadingCases.pop_back ();
244
+ flushMergedOps ();
245
+ } else {
246
+ cascadingCases.clear ();
247
+ cascadingCaseValues.clear ();
248
+ }
249
+ }
250
+
251
+ // Edge case: all cases are simple cascading cases
252
+ if (cascadingCases.size () == cases.size ()) {
253
+ CaseOp lastCascadingCase = cascadingCases.back ();
254
+ mergeCascadingInto (lastCascadingCase);
255
+ cascadingCases.pop_back ();
256
+ flushMergedOps ();
257
+ }
258
+
259
+ return changed;
260
+ }
261
+ };
262
+
162
263
// ===----------------------------------------------------------------------===//
163
264
// CIRSimplifyPass
164
265
// ===----------------------------------------------------------------------===//
@@ -173,7 +274,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
173
274
// clang-format off
174
275
patterns.add <
175
276
SimplifyTernary,
176
- SimplifySelect
277
+ SimplifySelect,
278
+ SimplifySwitch
177
279
>(patterns.getContext ());
178
280
// clang-format on
179
281
}
@@ -186,7 +288,7 @@ void CIRSimplifyPass::runOnOperation() {
186
288
// Collect operations to apply patterns.
187
289
llvm::SmallVector<Operation *, 16 > ops;
188
290
getOperation ()->walk ([&](Operation *op) {
189
- if (isa<TernaryOp, SelectOp>(op))
291
+ if (isa<TernaryOp, SelectOp, SwitchOp >(op))
190
292
ops.push_back (op);
191
293
});
192
294
0 commit comments