Skip to content

Commit 754a11a

Browse files
authored
[CIR][NFC] Pass enum kind directly to complex cast helpers (#1757)
Backporting passing enum kind directly to complex cast helpers
1 parent 2a126d2 commit 754a11a

File tree

1 file changed

+53
-74
lines changed

1 file changed

+53
-74
lines changed

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 53 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -480,110 +480,89 @@ void LoweringPreparePass::lowerBinOp(BinOp op) {
480480
op.erase();
481481
}
482482

483-
static mlir::Value lowerScalarToComplexCast(MLIRContext &ctx, CastOp op) {
484-
CIRBaseBuilderTy builder(ctx);
483+
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
484+
cir::CastOp op) {
485+
cir::CIRBaseBuilderTy builder(ctx);
485486
builder.setInsertionPoint(op);
486487

487-
auto src = op.getSrc();
488-
auto imag = builder.getNullValue(src.getType(), op.getLoc());
488+
mlir::Value src = op.getSrc();
489+
mlir::Value imag = builder.getNullValue(src.getType(), op.getLoc());
489490
return builder.createComplexCreate(op.getLoc(), src, imag);
490491
}
491492

492-
static mlir::Value lowerComplexToScalarCast(MLIRContext &ctx, CastOp op) {
493-
CIRBaseBuilderTy builder(ctx);
493+
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
494+
cir::CastOp op,
495+
cir::CastKind elemToBoolKind) {
496+
cir::CIRBaseBuilderTy builder(ctx);
494497
builder.setInsertionPoint(op);
495498

496-
auto src = op.getSrc();
497-
499+
mlir::Value src = op.getSrc();
498500
if (!mlir::isa<cir::BoolType>(op.getType()))
499501
return builder.createComplexReal(op.getLoc(), src);
500502

501503
// Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
502-
auto srcReal = builder.createComplexReal(op.getLoc(), src);
503-
auto srcImag = builder.createComplexImag(op.getLoc(), src);
504-
505-
cir::CastKind elemToBoolKind;
506-
if (op.getKind() == cir::CastKind::float_complex_to_bool)
507-
elemToBoolKind = cir::CastKind::float_to_bool;
508-
else if (op.getKind() == cir::CastKind::int_complex_to_bool)
509-
elemToBoolKind = cir::CastKind::int_to_bool;
510-
else
511-
llvm_unreachable("invalid complex to bool cast kind");
504+
mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
505+
mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
512506

513-
auto boolTy = builder.getBoolTy();
514-
auto srcRealToBool =
507+
cir::BoolType boolTy = builder.getBoolTy();
508+
mlir::Value srcRealToBool =
515509
builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy);
516-
auto srcImagToBool =
510+
mlir::Value srcImagToBool =
517511
builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy);
518-
519-
// srcRealToBool || srcImagToBool
520512
return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool);
521513
}
522514

523-
static mlir::Value lowerComplexToComplexCast(MLIRContext &ctx, CastOp op) {
515+
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
516+
cir::CastOp op,
517+
cir::CastKind scalarCastKind) {
524518
CIRBaseBuilderTy builder(ctx);
525519
builder.setInsertionPoint(op);
526520

527-
auto src = op.getSrc();
521+
mlir::Value src = op.getSrc();
528522
auto dstComplexElemTy =
529523
mlir::cast<cir::ComplexType>(op.getType()).getElementType();
530524

531-
auto srcReal = builder.createComplexReal(op.getLoc(), src);
532-
auto srcImag = builder.createComplexReal(op.getLoc(), src);
525+
mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
526+
mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
533527

534-
cir::CastKind scalarCastKind;
535-
switch (op.getKind()) {
536-
case cir::CastKind::float_complex:
537-
scalarCastKind = cir::CastKind::floating;
538-
break;
539-
case cir::CastKind::float_complex_to_int_complex:
540-
scalarCastKind = cir::CastKind::float_to_int;
541-
break;
542-
case cir::CastKind::int_complex:
543-
scalarCastKind = cir::CastKind::integral;
544-
break;
545-
case cir::CastKind::int_complex_to_float_complex:
546-
scalarCastKind = cir::CastKind::int_to_float;
547-
break;
548-
default:
549-
llvm_unreachable("invalid complex to complex cast kind");
550-
}
551-
552-
auto dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal,
553-
dstComplexElemTy);
554-
auto dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag,
555-
dstComplexElemTy);
528+
mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal,
529+
dstComplexElemTy);
530+
mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag,
531+
dstComplexElemTy);
556532
return builder.createComplexCreate(op.getLoc(), dstReal, dstImag);
557533
}
558534

559535
void LoweringPreparePass::lowerCastOp(CastOp op) {
560-
mlir::Value loweredValue;
561-
switch (op.getKind()) {
562-
case cir::CastKind::float_to_complex:
563-
case cir::CastKind::int_to_complex:
564-
loweredValue = lowerScalarToComplexCast(getContext(), op);
565-
break;
566-
567-
case cir::CastKind::float_complex_to_real:
568-
case cir::CastKind::int_complex_to_real:
569-
case cir::CastKind::float_complex_to_bool:
570-
case cir::CastKind::int_complex_to_bool:
571-
loweredValue = lowerComplexToScalarCast(getContext(), op);
572-
break;
573-
574-
case cir::CastKind::float_complex:
575-
case cir::CastKind::float_complex_to_int_complex:
576-
case cir::CastKind::int_complex:
577-
case cir::CastKind::int_complex_to_float_complex:
578-
loweredValue = lowerComplexToComplexCast(getContext(), op);
579-
break;
536+
mlir::MLIRContext &ctx = getContext();
537+
mlir::Value loweredValue = [&]() -> mlir::Value {
538+
switch (op.getKind()) {
539+
case cir::CastKind::float_to_complex:
540+
case cir::CastKind::int_to_complex:
541+
return lowerScalarToComplexCast(ctx, op);
542+
case cir::CastKind::float_complex_to_real:
543+
case cir::CastKind::int_complex_to_real:
544+
return lowerComplexToScalarCast(ctx, op, op.getKind());
545+
case cir::CastKind::float_complex_to_bool:
546+
return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
547+
case cir::CastKind::int_complex_to_bool:
548+
return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
549+
case cir::CastKind::float_complex:
550+
return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
551+
case cir::CastKind::float_complex_to_int_complex:
552+
return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
553+
case cir::CastKind::int_complex:
554+
return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
555+
case cir::CastKind::int_complex_to_float_complex:
556+
return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
557+
default:
558+
return nullptr;
559+
}
560+
}();
580561

581-
default:
582-
return;
562+
if (loweredValue) {
563+
op.replaceAllUsesWith(loweredValue);
564+
op.erase();
583565
}
584-
585-
op.replaceAllUsesWith(loweredValue);
586-
op.erase();
587566
}
588567

589568
static mlir::Value buildComplexBinOpLibCall(

0 commit comments

Comments
 (0)