@@ -480,110 +480,89 @@ void LoweringPreparePass::lowerBinOp(BinOp op) {
480
480
op.erase ();
481
481
}
482
482
483
- static mlir::Value lowerScalarToComplexCast (MLIRContext &ctx, CastOp op) {
484
- CIRBaseBuilderTy builder (ctx);
483
+ static mlir::Value lowerScalarToComplexCast (mlir::MLIRContext &ctx,
484
+ cir::CastOp op) {
485
+ cir::CIRBaseBuilderTy builder (ctx);
485
486
builder.setInsertionPoint (op);
486
487
487
- auto src = op.getSrc ();
488
- auto imag = builder.getNullValue (src.getType (), op.getLoc ());
488
+ mlir::Value src = op.getSrc ();
489
+ mlir::Value imag = builder.getNullValue (src.getType (), op.getLoc ());
489
490
return builder.createComplexCreate (op.getLoc (), src, imag);
490
491
}
491
492
492
- static mlir::Value lowerComplexToScalarCast (MLIRContext &ctx, CastOp op) {
493
- CIRBaseBuilderTy builder (ctx);
493
+ static mlir::Value lowerComplexToScalarCast (mlir::MLIRContext &ctx,
494
+ cir::CastOp op,
495
+ cir::CastKind elemToBoolKind) {
496
+ cir::CIRBaseBuilderTy builder (ctx);
494
497
builder.setInsertionPoint (op);
495
498
496
- auto src = op.getSrc ();
497
-
499
+ mlir::Value src = op.getSrc ();
498
500
if (!mlir::isa<cir::BoolType>(op.getType ()))
499
501
return builder.createComplexReal (op.getLoc (), src);
500
502
501
503
// Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
502
- auto srcReal = builder.createComplexReal (op.getLoc (), src);
503
- auto srcImag = builder.createComplexImag (op.getLoc (), src);
504
-
505
- cir::CastKind elemToBoolKind;
506
- if (op.getKind () == cir::CastKind::float_complex_to_bool)
507
- elemToBoolKind = cir::CastKind::float_to_bool;
508
- else if (op.getKind () == cir::CastKind::int_complex_to_bool)
509
- elemToBoolKind = cir::CastKind::int_to_bool;
510
- else
511
- llvm_unreachable (" invalid complex to bool cast kind" );
504
+ mlir::Value srcReal = builder.createComplexReal (op.getLoc (), src);
505
+ mlir::Value srcImag = builder.createComplexImag (op.getLoc (), src);
512
506
513
- auto boolTy = builder.getBoolTy ();
514
- auto srcRealToBool =
507
+ cir::BoolType boolTy = builder.getBoolTy ();
508
+ mlir::Value srcRealToBool =
515
509
builder.createCast (op.getLoc (), elemToBoolKind, srcReal, boolTy);
516
- auto srcImagToBool =
510
+ mlir::Value srcImagToBool =
517
511
builder.createCast (op.getLoc (), elemToBoolKind, srcImag, boolTy);
518
-
519
- // srcRealToBool || srcImagToBool
520
512
return builder.createLogicalOr (op.getLoc (), srcRealToBool, srcImagToBool);
521
513
}
522
514
523
- static mlir::Value lowerComplexToComplexCast (MLIRContext &ctx, CastOp op) {
515
+ static mlir::Value lowerComplexToComplexCast (mlir::MLIRContext &ctx,
516
+ cir::CastOp op,
517
+ cir::CastKind scalarCastKind) {
524
518
CIRBaseBuilderTy builder (ctx);
525
519
builder.setInsertionPoint (op);
526
520
527
- auto src = op.getSrc ();
521
+ mlir::Value src = op.getSrc ();
528
522
auto dstComplexElemTy =
529
523
mlir::cast<cir::ComplexType>(op.getType ()).getElementType ();
530
524
531
- auto srcReal = builder.createComplexReal (op.getLoc (), src);
532
- auto srcImag = builder.createComplexReal (op.getLoc (), src);
525
+ mlir::Value srcReal = builder.createComplexReal (op.getLoc (), src);
526
+ mlir::Value srcImag = builder.createComplexImag (op.getLoc (), src);
533
527
534
- cir::CastKind scalarCastKind;
535
- switch (op.getKind ()) {
536
- case cir::CastKind::float_complex:
537
- scalarCastKind = cir::CastKind::floating;
538
- break ;
539
- case cir::CastKind::float_complex_to_int_complex:
540
- scalarCastKind = cir::CastKind::float_to_int;
541
- break ;
542
- case cir::CastKind::int_complex:
543
- scalarCastKind = cir::CastKind::integral;
544
- break ;
545
- case cir::CastKind::int_complex_to_float_complex:
546
- scalarCastKind = cir::CastKind::int_to_float;
547
- break ;
548
- default :
549
- llvm_unreachable (" invalid complex to complex cast kind" );
550
- }
551
-
552
- auto dstReal = builder.createCast (op.getLoc (), scalarCastKind, srcReal,
553
- dstComplexElemTy);
554
- auto dstImag = builder.createCast (op.getLoc (), scalarCastKind, srcImag,
555
- dstComplexElemTy);
528
+ mlir::Value dstReal = builder.createCast (op.getLoc (), scalarCastKind, srcReal,
529
+ dstComplexElemTy);
530
+ mlir::Value dstImag = builder.createCast (op.getLoc (), scalarCastKind, srcImag,
531
+ dstComplexElemTy);
556
532
return builder.createComplexCreate (op.getLoc (), dstReal, dstImag);
557
533
}
558
534
559
535
void LoweringPreparePass::lowerCastOp (CastOp op) {
560
- mlir::Value loweredValue;
561
- switch (op.getKind ()) {
562
- case cir::CastKind::float_to_complex:
563
- case cir::CastKind::int_to_complex:
564
- loweredValue = lowerScalarToComplexCast (getContext (), op);
565
- break ;
566
-
567
- case cir::CastKind::float_complex_to_real:
568
- case cir::CastKind::int_complex_to_real:
569
- case cir::CastKind::float_complex_to_bool:
570
- case cir::CastKind::int_complex_to_bool:
571
- loweredValue = lowerComplexToScalarCast (getContext (), op);
572
- break ;
573
-
574
- case cir::CastKind::float_complex:
575
- case cir::CastKind::float_complex_to_int_complex:
576
- case cir::CastKind::int_complex:
577
- case cir::CastKind::int_complex_to_float_complex:
578
- loweredValue = lowerComplexToComplexCast (getContext (), op);
579
- break ;
536
+ mlir::MLIRContext &ctx = getContext ();
537
+ mlir::Value loweredValue = [&]() -> mlir::Value {
538
+ switch (op.getKind ()) {
539
+ case cir::CastKind::float_to_complex:
540
+ case cir::CastKind::int_to_complex:
541
+ return lowerScalarToComplexCast (ctx, op);
542
+ case cir::CastKind::float_complex_to_real:
543
+ case cir::CastKind::int_complex_to_real:
544
+ return lowerComplexToScalarCast (ctx, op, op.getKind ());
545
+ case cir::CastKind::float_complex_to_bool:
546
+ return lowerComplexToScalarCast (ctx, op, cir::CastKind::float_to_bool);
547
+ case cir::CastKind::int_complex_to_bool:
548
+ return lowerComplexToScalarCast (ctx, op, cir::CastKind::int_to_bool);
549
+ case cir::CastKind::float_complex:
550
+ return lowerComplexToComplexCast (ctx, op, cir::CastKind::floating);
551
+ case cir::CastKind::float_complex_to_int_complex:
552
+ return lowerComplexToComplexCast (ctx, op, cir::CastKind::float_to_int);
553
+ case cir::CastKind::int_complex:
554
+ return lowerComplexToComplexCast (ctx, op, cir::CastKind::integral);
555
+ case cir::CastKind::int_complex_to_float_complex:
556
+ return lowerComplexToComplexCast (ctx, op, cir::CastKind::int_to_float);
557
+ default :
558
+ return nullptr ;
559
+ }
560
+ }();
580
561
581
- default :
582
- return ;
562
+ if (loweredValue) {
563
+ op.replaceAllUsesWith (loweredValue);
564
+ op.erase ();
583
565
}
584
-
585
- op.replaceAllUsesWith (loweredValue);
586
- op.erase ();
587
566
}
588
567
589
568
static mlir::Value buildComplexBinOpLibCall (
0 commit comments