Skip to content

Commit 1543f10

Browse files
committed
[CIR] Upstream ComplexRealPtrOp for ComplexType
1 parent 53183be commit 1543f10

File tree

9 files changed

+174
-13
lines changed

9 files changed

+174
-13
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2612,6 +2612,36 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
26122612
let hasFolder = 1;
26132613
}
26142614

2615+
//===----------------------------------------------------------------------===//
2616+
// ComplexRealPtrOp
2617+
//===----------------------------------------------------------------------===//
2618+
2619+
def ComplexRealPtrOp : CIR_Op<"complex.real_ptr", [Pure]> {
2620+
let summary = "Derive a pointer to the real part of a complex value";
2621+
let description = [{
2622+
`cir.complex.real_ptr` operation takes a pointer operand that points to a
2623+
complex value of type `!cir.complex` and yields a pointer to the real part
2624+
of the operand.
2625+
2626+
Example:
2627+
2628+
```mlir
2629+
%1 = cir.complex.real_ptr %0 : !cir.ptr<!cir.complex<!cir.double>>
2630+
-> !cir.ptr<!cir.double>
2631+
```
2632+
}];
2633+
2634+
let results = (outs CIR_PtrToIntOrFloatType:$result);
2635+
let arguments = (ins CIR_PtrToComplexType:$operand);
2636+
2637+
let assemblyFormat = [{
2638+
$operand `:`
2639+
qualified(type($operand)) `->` qualified(type($result)) attr-dict
2640+
}];
2641+
2642+
let hasVerifier = 1;
2643+
}
2644+
26152645
//===----------------------------------------------------------------------===//
26162646
// ComplexAddOp
26172647
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType],
159159
let cppFunctionName = "isAnyIntegerOrFloatingPointType";
160160
}
161161

162+
//===----------------------------------------------------------------------===//
163+
// Complex Type predicates
164+
//===----------------------------------------------------------------------===//
165+
166+
def CIR_AnyComplexType : CIR_TypeBase<"::cir::ComplexType", "complex type">;
167+
162168
//===----------------------------------------------------------------------===//
163169
// Pointer Type predicates
164170
//===----------------------------------------------------------------------===//
@@ -180,6 +186,17 @@ class CIR_PtrToPtrTo<code type, string summary>
180186
: CIR_ConfinedType<CIR_AnyPtrType, [CIR_IsPtrToPtrToPred<type>],
181187
"pointer to pointer to " # summary>;
182188

189+
// Pointee type constraint bases
190+
class CIR_PointeePred<Pred pred> : SubstLeaves<"$_self",
191+
"::mlir::cast<::cir::PointerType>($_self).getPointee()", pred>;
192+
193+
class CIR_PtrToAnyOf<list<Type> types, string summary = "">
194+
: CIR_ConfinedType<CIR_AnyPtrType,
195+
[Or<!foreach(type, types, CIR_PointeePred<type.predicate>)>],
196+
!if(!empty(summary),
197+
"pointer to " # CIR_TypeSummaries<types>.value,
198+
summary)>;
199+
183200
// Void pointer type constraints
184201
def CIR_VoidPtrType
185202
: CIR_PtrTo<"::cir::VoidType", "void type">,
@@ -192,6 +209,13 @@ def CIR_PtrToVoidPtrType
192209
"$_builder.getType<" # cppType # ">("
193210
"cir::VoidType::get($_builder.getContext())))">;
194211

212+
class CIR_PtrToType<Type type> : CIR_PtrToAnyOf<[type]>;
213+
214+
// Pointer to type constraints
215+
def CIR_PtrToIntOrFloatType : CIR_PtrToType<CIR_AnyIntOrFloatType>;
216+
217+
def CIR_PtrToComplexType : CIR_PtrToType<CIR_AnyComplexType>;
218+
195219
//===----------------------------------------------------------------------===//
196220
// Vector Type predicates
197221
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,20 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
364364
return create<cir::ComplexImagOp>(loc, operandTy.getElementType(), operand);
365365
}
366366

367+
/// Create a cir.complex.real_ptr operation that derives a pointer to the real
368+
/// part of the complex value pointed to by the specified pointer value.
369+
mlir::Value createComplexRealPtr(mlir::Location loc, mlir::Value value) {
370+
auto srcPtrTy = mlir::cast<cir::PointerType>(value.getType());
371+
auto srcComplexTy = mlir::cast<cir::ComplexType>(srcPtrTy.getPointee());
372+
return create<cir::ComplexRealPtrOp>(
373+
loc, getPointerTo(srcComplexTy.getElementType()), value);
374+
}
375+
376+
Address createComplexRealPtr(mlir::Location loc, Address addr) {
377+
return Address{createComplexRealPtr(loc, addr.getPointer()),
378+
addr.getAlignment()};
379+
}
380+
367381
/// Create a cir.ptr_stride operation to get access to an array element.
368382
/// \p idx is the index of the element to access, \p shouldDecay is true if
369383
/// the result should decay to a pointer to the element type.

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -637,8 +637,29 @@ LValue CIRGenFunction::emitUnaryOpLValue(const UnaryOperator *e) {
637637
}
638638
case UO_Real:
639639
case UO_Imag: {
640-
cgm.errorNYI(e->getSourceRange(), "UnaryOp real/imag");
641-
return LValue();
640+
if (op == UO_Imag) {
641+
cgm.errorNYI(e->getSourceRange(), "UnaryOp real/imag");
642+
return LValue();
643+
}
644+
645+
LValue lv = emitLValue(e->getSubExpr());
646+
assert(lv.isSimple() && "real/imag on non-ordinary l-value");
647+
648+
// __real is valid on scalars. This is a faster way of testing that.
649+
// __imag can only produce an rvalue on scalars.
650+
if (e->getOpcode() == UO_Real &&
651+
!mlir::isa<cir::ComplexType>(lv.getAddress().getElementType())) {
652+
assert(e->getSubExpr()->getType()->isArithmeticType());
653+
return lv;
654+
}
655+
656+
QualType exprTy = getContext().getCanonicalType(e->getSubExpr()->getType());
657+
QualType elemTy = exprTy->castAs<clang::ComplexType>()->getElementType();
658+
mlir::Location loc = getLoc(e->getExprLoc());
659+
Address component = builder.createComplexRealPtr(loc, lv.getAddress());
660+
LValue elemLV = makeAddrLValue(component, elemTy);
661+
elemLV.getQuals().addQualifiers(lv.getQuals());
662+
return elemLV;
642663
}
643664
case UO_PreInc:
644665
case UO_PreDec: {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,6 +2097,24 @@ OpFoldResult cir::ComplexImagOp::fold(FoldAdaptor adaptor) {
20972097
return complex ? complex.getImag() : nullptr;
20982098
}
20992099

2100+
//===----------------------------------------------------------------------===//
2101+
// ComplexRealPtrOp
2102+
//===----------------------------------------------------------------------===//
2103+
2104+
LogicalResult cir::ComplexRealPtrOp::verify() {
2105+
mlir::Type resultPointeeTy = getType().getPointee();
2106+
cir::PointerType operandPtrTy = getOperand().getType();
2107+
auto operandPointeeTy =
2108+
mlir::cast<cir::ComplexType>(operandPtrTy.getPointee());
2109+
2110+
if (resultPointeeTy != operandPointeeTy.getElementType()) {
2111+
emitOpError() << ": result type does not match operand type";
2112+
return failure();
2113+
}
2114+
2115+
return success();
2116+
}
2117+
21002118
//===----------------------------------------------------------------------===//
21012119
// TableGen'd op method definitions
21022120
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
20522052
CIRToLLVMComplexCreateOpLowering,
20532053
CIRToLLVMComplexImagOpLowering,
20542054
CIRToLLVMComplexRealOpLowering,
2055+
CIRToLLVMComplexRealPtrOpLowering,
20552056
CIRToLLVMConstantOpLowering,
20562057
CIRToLLVMExpectOpLowering,
20572058
CIRToLLVMFuncOpLowering,
@@ -2526,6 +2527,23 @@ mlir::LogicalResult CIRToLLVMSetBitfieldOpLowering::matchAndRewrite(
25262527
return mlir::success();
25272528
}
25282529

2530+
mlir::LogicalResult CIRToLLVMComplexRealPtrOpLowering::matchAndRewrite(
2531+
cir::ComplexRealPtrOp op, OpAdaptor adaptor,
2532+
mlir::ConversionPatternRewriter &rewriter) const {
2533+
cir::PointerType operandTy = op.getOperand().getType();
2534+
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
2535+
mlir::Type elementLLVMTy =
2536+
getTypeConverter()->convertType(operandTy.getPointee());
2537+
2538+
mlir::LLVM::GEPArg gepIndices[2] = {0, 0};
2539+
mlir::LLVM::GEPNoWrapFlags inboundsNuw =
2540+
mlir::LLVM::GEPNoWrapFlags::inbounds | mlir::LLVM::GEPNoWrapFlags::nuw;
2541+
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
2542+
op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
2543+
inboundsNuw);
2544+
return mlir::success();
2545+
}
2546+
25292547
mlir::LogicalResult CIRToLLVMGetBitfieldOpLowering::matchAndRewrite(
25302548
cir::GetBitfieldOp op, OpAdaptor adaptor,
25312549
mlir::ConversionPatternRewriter &rewriter) const {

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,26 @@ class CIRToLLVMComplexImagOpLowering
513513
mlir::ConversionPatternRewriter &) const override;
514514
};
515515

516+
class CIRToLLVMComplexRealPtrOpLowering
517+
: public mlir::OpConversionPattern<cir::ComplexRealPtrOp> {
518+
public:
519+
using mlir::OpConversionPattern<cir::ComplexRealPtrOp>::OpConversionPattern;
520+
521+
mlir::LogicalResult
522+
matchAndRewrite(cir::ComplexRealPtrOp op, OpAdaptor,
523+
mlir::ConversionPatternRewriter &) const override;
524+
};
525+
526+
class CIRToLLVMComplexAddOpLowering
527+
: public mlir::OpConversionPattern<cir::ComplexAddOp> {
528+
public:
529+
using mlir::OpConversionPattern<cir::ComplexAddOp>::OpConversionPattern;
530+
531+
mlir::LogicalResult
532+
matchAndRewrite(cir::ComplexAddOp op, OpAdaptor,
533+
mlir::ConversionPatternRewriter &) const override;
534+
};
535+
516536
class CIRToLLVMSetBitfieldOpLowering
517537
: public mlir::OpConversionPattern<cir::SetBitfieldOp> {
518538
public:
@@ -533,16 +553,6 @@ class CIRToLLVMGetBitfieldOpLowering
533553
mlir::ConversionPatternRewriter &) const override;
534554
};
535555

536-
class CIRToLLVMComplexAddOpLowering
537-
: public mlir::OpConversionPattern<cir::ComplexAddOp> {
538-
public:
539-
using mlir::OpConversionPattern<cir::ComplexAddOp>::OpConversionPattern;
540-
541-
mlir::LogicalResult
542-
matchAndRewrite(cir::ComplexAddOp op, OpAdaptor,
543-
mlir::ConversionPatternRewriter &) const override;
544-
};
545-
546556
} // namespace direct
547557
} // namespace cir
548558

clang/test/CIR/CodeGen/complex.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,20 @@ void foo9(double a, double b) {
216216
// OGCG: store double %[[TMP_A]], ptr %[[C_REAL_PTR]], align 8
217217
// OGCG: store double %[[TMP_B]], ptr %[[C_IMAG_PTR]], align 8
218218

219+
void foo10() {
220+
double _Complex c;
221+
double *realPtr = &__real__ c;
222+
}
223+
224+
// CIR: %[[COMPLEX:.*]] = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["c"]
225+
// CIR: %[[REAL_PTR:.*]] = cir.complex.real_ptr %[[COMPLEX]] : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.double>
226+
227+
// LLVM: %[[COMPLEX:.*]] = alloca { double, double }, i64 1, align 8
228+
// LLVM: %[[REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX]], i32 0, i32 0
229+
230+
// OGCG: %[[COMPLEX:.*]] = alloca { double, double }, align 8
231+
// OGCG: %[[REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX]], i32 0, i32 0
232+
219233
void foo12() {
220234
double _Complex c;
221235
double imag = __imag__ c;
@@ -751,4 +765,4 @@ void foo29() {
751765
// OGCG: %[[INIT_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[INIT]], i32 0, i32 0
752766
// OGCG: %[[INIT_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[INIT]], i32 0, i32 1
753767
// OGCG: store i32 0, ptr %[[INIT_REAL_PTR]], align 4
754-
// OGCG: store i32 0, ptr %[[INIT_IMAG_PTR]], align 4
768+
// OGCG: store i32 0, ptr %[[INIT_IMAG_PTR]], align 4

clang/test/CIR/IR/invalid-complex.cir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,15 @@ module {
4545
cir.return
4646
}
4747
}
48+
49+
50+
// -----
51+
52+
module {
53+
cir.func @complex_real_ptr_invalid_result_type() -> !cir.double {
54+
%0 = cir.alloca !cir.complex<!cir.double>, !cir.ptr<!cir.complex<!cir.double>>, ["c"]
55+
// expected-error @below {{result type does not match operand type}}
56+
%1 = cir.complex.real_ptr %0 : !cir.ptr<!cir.complex<!cir.double>> -> !cir.ptr<!cir.float>
57+
cir.return
58+
}
59+
}

0 commit comments

Comments
 (0)