Skip to content

Commit 47fa3eb

Browse files
authored
[CIR] Backport support for global ComplexType init (#1665)
Backport support global initialization for ComplexType from (llvm/llvm-project#141369)
1 parent 7b8a99b commit 47fa3eb

File tree

6 files changed

+163
-4
lines changed

6 files changed

+163
-4
lines changed

clang/include/clang/CIR/LoweringHelpers.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,9 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr,
5050
std::optional<mlir::Attribute>
5151
lowerConstVectorAttr(cir::ConstVectorAttr constArr,
5252
const mlir::TypeConverter *converter);
53+
54+
std::optional<mlir::Attribute>
55+
lowerConstComplexAttr(cir::ComplexAttr constArr,
56+
const mlir::TypeConverter *converter);
57+
5358
#endif

clang/lib/CIR/CodeGen/CIRGenExprConst.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2007,9 +2007,30 @@ mlir::Attribute ConstantEmitter::tryEmitPrivate(const APValue &Value,
20072007
case APValue::Struct:
20082008
case APValue::Union:
20092009
return ConstRecordBuilder::BuildRecord(*this, Value, DestType);
2010-
case APValue::FixedPoint:
2011-
case APValue::ComplexInt:
20122010
case APValue::ComplexFloat:
2011+
case APValue::ComplexInt: {
2012+
mlir::Type desiredType = CGM.convertType(DestType);
2013+
cir::ComplexType complexType =
2014+
mlir::dyn_cast<cir::ComplexType>(desiredType);
2015+
2016+
mlir::Type complexElemTy = complexType.getElementType();
2017+
if (isa<cir::IntType>(complexElemTy)) {
2018+
llvm::APSInt real = Value.getComplexIntReal();
2019+
llvm::APSInt imag = Value.getComplexIntImag();
2020+
return builder.getAttr<cir::ComplexAttr>(
2021+
complexType, builder.getAttr<cir::IntAttr>(complexElemTy, real),
2022+
builder.getAttr<cir::IntAttr>(complexElemTy, imag));
2023+
}
2024+
2025+
assert(isa<cir::CIRFPTypeInterface>(complexElemTy) &&
2026+
"expected floating-point type");
2027+
llvm::APFloat real = Value.getComplexFloatReal();
2028+
llvm::APFloat imag = Value.getComplexFloatImag();
2029+
return builder.getAttr<cir::ComplexAttr>(
2030+
complexType, builder.getAttr<cir::FPAttr>(complexElemTy, real),
2031+
builder.getAttr<cir::FPAttr>(complexElemTy, imag));
2032+
}
2033+
case APValue::FixedPoint:
20132034
case APValue::AddrLabelDiff:
20142035
assert(0 && "not implemented");
20152036
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,8 @@ class CirAttrToValue {
452452
.Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr,
453453
cir::ConstRecordAttr, cir::ConstArrayAttr, cir::ConstVectorAttr,
454454
cir::BoolAttr, cir::ZeroAttr, cir::UndefAttr, cir::PoisonAttr,
455-
cir::GlobalViewAttr, cir::VTableAttr, cir::TypeInfoAttr>(
456-
[&](auto attrT) { return visitCirAttr(attrT); })
455+
cir::GlobalViewAttr, cir::VTableAttr, cir::TypeInfoAttr,
456+
cir::ComplexAttr>([&](auto attrT) { return visitCirAttr(attrT); })
457457
.Default([&](auto attrT) { return mlir::Value(); });
458458
}
459459

@@ -463,6 +463,7 @@ class CirAttrToValue {
463463
mlir::Value visitCirAttr(cir::ConstRecordAttr attr);
464464
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
465465
mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
466+
mlir::Value visitCirAttr(cir::ComplexAttr attr);
466467
mlir::Value visitCirAttr(cir::BoolAttr attr);
467468
mlir::Value visitCirAttr(cir::ZeroAttr attr);
468469
mlir::Value visitCirAttr(cir::UndefAttr attr);
@@ -647,6 +648,33 @@ mlir::Value CirAttrToValue::visitCirAttr(cir::ConstVectorAttr constVec) {
647648
mlirValues));
648649
}
649650

651+
mlir::Value CirAttrToValue::visitCirAttr(cir::ComplexAttr complexAttr) {
652+
auto complexType = mlir::cast<cir::ComplexType>(complexAttr.getType());
653+
mlir::Type complexElemTy = complexType.getElementType();
654+
mlir::Type complexElemLLVMTy = converter->convertType(complexElemTy);
655+
656+
mlir::Attribute components[2];
657+
if (const auto intType = mlir::dyn_cast<cir::IntType>(complexElemTy)) {
658+
components[0] = rewriter.getIntegerAttr(
659+
complexElemLLVMTy,
660+
mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
661+
components[1] = rewriter.getIntegerAttr(
662+
complexElemLLVMTy,
663+
mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
664+
} else {
665+
components[0] = rewriter.getFloatAttr(
666+
complexElemLLVMTy,
667+
mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
668+
components[1] = rewriter.getFloatAttr(
669+
complexElemLLVMTy,
670+
mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
671+
}
672+
673+
mlir::Location loc = parentOp->getLoc();
674+
return rewriter.create<mlir::LLVM::ConstantOp>(
675+
loc, converter->convertType(complexAttr.getType()),
676+
rewriter.getArrayAttr(components));
677+
}
650678
// GlobalViewAttr visitor.
651679
mlir::Value CirAttrToValue::visitCirAttr(cir::GlobalViewAttr globalAttr) {
652680
auto module = parentOp->getParentOfType<mlir::ModuleOp>();
@@ -2419,6 +2447,9 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializer(
24192447
} else if (mlir::isa<cir::ConstVectorAttr>(init)) {
24202448
return lowerInitializerForConstVector(rewriter, op, init,
24212449
useInitializerRegion);
2450+
} else if (mlir::isa<cir::ComplexAttr>(init)) {
2451+
return lowerInitializerForConstComplex(rewriter, op, init,
2452+
useInitializerRegion);
24222453
} else if (auto dataMemberAttr = mlir::dyn_cast<cir::DataMemberAttr>(init)) {
24232454
assert(lowerMod && "lower module is not available");
24242455
mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>());
@@ -2455,6 +2486,19 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstVector(
24552486
return mlir::failure();
24562487
}
24572488

2489+
mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstComplex(
2490+
mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
2491+
mlir::Attribute &init, bool &useInitializerRegion) const {
2492+
auto constVec = mlir::cast<cir::ComplexAttr>(init);
2493+
if (auto val = lowerConstComplexAttr(constVec, getTypeConverter());
2494+
val.has_value()) {
2495+
init = val.value();
2496+
useInitializerRegion = false;
2497+
} else
2498+
useInitializerRegion = true;
2499+
return mlir::success();
2500+
}
2501+
24582502
mlir::LogicalResult CIRToLLVMGlobalOpLowering::lowerInitializerForConstArray(
24592503
mlir::ConversionPatternRewriter &rewriter, cir::GlobalOp op,
24602504
mlir::Attribute &init, bool &useInitializerRegion) const {

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,11 @@ class CIRToLLVMGlobalOpLowering
628628
cir::GlobalOp op, mlir::Attribute &init,
629629
bool &useInitializerRegion) const;
630630

631+
mlir::LogicalResult
632+
lowerInitializerForConstComplex(mlir::ConversionPatternRewriter &rewriter,
633+
cir::GlobalOp op, mlir::Attribute &init,
634+
bool &useInitializerRegion) const;
635+
631636
mlir::LogicalResult
632637
lowerInitializerDirect(mlir::ConversionPatternRewriter &rewriter,
633638
cir::GlobalOp op, mlir::Type llvmType,

clang/lib/CIR/Lowering/LoweringHelpers.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,33 @@ void convertToDenseElementsAttrImpl(
126126
}
127127
}
128128

129+
template <typename AttrTy, typename StorageTy>
130+
void convertToDenseElementsAttrImpl(
131+
cir::ComplexAttr attr, llvm::SmallVectorImpl<StorageTy> &values,
132+
const llvm::SmallVectorImpl<int64_t> &currentDims, int64_t dimIndex,
133+
int64_t currentIndex) {
134+
dimIndex++;
135+
std::size_t elementsSizeInCurrentDim = 1;
136+
for (std::size_t i = dimIndex; i < currentDims.size(); i++)
137+
elementsSizeInCurrentDim *= currentDims[i];
138+
139+
auto attrArray =
140+
mlir::ArrayAttr::get(attr.getContext(), {attr.getImag(), attr.getReal()});
141+
for (auto eltAttr : attrArray) {
142+
if (auto valueAttr = mlir::dyn_cast<AttrTy>(eltAttr)) {
143+
values[currentIndex++] = valueAttr.getValue();
144+
continue;
145+
}
146+
147+
if (mlir::isa<cir::ZeroAttr, cir::UndefAttr>(eltAttr)) {
148+
currentIndex += elementsSizeInCurrentDim;
149+
continue;
150+
}
151+
152+
llvm_unreachable("unknown element in ComplexAttr");
153+
}
154+
}
155+
129156
template <typename AttrTy, typename StorageTy>
130157
mlir::DenseElementsAttr convertToDenseElementsAttr(
131158
cir::ConstArrayAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
@@ -158,6 +185,20 @@ mlir::DenseElementsAttr convertToDenseElementsAttr(
158185
llvm::ArrayRef(values));
159186
}
160187

188+
template <typename AttrTy, typename StorageTy>
189+
mlir::DenseElementsAttr convertToDenseElementsAttr(
190+
cir::ComplexAttr attr, const llvm::SmallVectorImpl<int64_t> &dims,
191+
mlir::Type elementType, mlir::Type convertedElementType) {
192+
unsigned array_size = 2;
193+
auto values = llvm::SmallVector<StorageTy, 8>(
194+
array_size, getZeroInitFromType<StorageTy>(elementType));
195+
convertToDenseElementsAttrImpl<AttrTy>(attr, values, dims, /*currentDim=*/0,
196+
/*initialIndex=*/0);
197+
return mlir::DenseElementsAttr::get(
198+
mlir::RankedTensorType::get(dims, convertedElementType),
199+
llvm::ArrayRef(values));
200+
}
201+
161202
std::optional<mlir::Attribute>
162203
lowerConstArrayAttr(cir::ConstArrayAttr constArr,
163204
const mlir::TypeConverter *converter) {
@@ -191,6 +232,27 @@ lowerConstArrayAttr(cir::ConstArrayAttr constArr,
191232
return std::nullopt;
192233
}
193234

235+
std::optional<mlir::Attribute>
236+
lowerConstComplexAttr(cir::ComplexAttr constComplex,
237+
const mlir::TypeConverter *converter) {
238+
239+
// Ensure ComplexAttr has a type.
240+
auto typedConstArr = mlir::dyn_cast<mlir::TypedAttr>(constComplex);
241+
assert(typedConstArr && "cir::ComplexAttr is not a mlir::TypedAttr");
242+
243+
mlir::Type type = constComplex.getType();
244+
auto dims = llvm::SmallVector<int64_t, 2>{2};
245+
246+
if (mlir::isa<cir::IntType>(type))
247+
return convertToDenseElementsAttr<cir::IntAttr, mlir::APInt>(
248+
constComplex, dims, type, converter->convertType(type));
249+
if (mlir::isa<cir::CIRFPTypeInterface>(type))
250+
return convertToDenseElementsAttr<cir::FPAttr, mlir::APFloat>(
251+
constComplex, dims, type, converter->convertType(type));
252+
253+
return std::nullopt;
254+
}
255+
194256
std::optional<mlir::Attribute>
195257
lowerConstVectorAttr(cir::ConstVectorAttr constArr,
196258
const mlir::TypeConverter *converter) {
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CHECK
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -Wno-unused-value -fclangir -emit-llvm %s -o %t-cir.ll
4+
// RUN: FileCheck --input-file=%t-cir.ll %s -check-prefix=LLVM
5+
6+
int _Complex gci;
7+
8+
float _Complex gcf;
9+
10+
int _Complex gci2 = { 1, 2 };
11+
12+
float _Complex gcf2 = { 1.0f, 2.0f };
13+
14+
// CHECK: cir.global external {{.*}} = #cir.zero : !cir.complex<!s32i>
15+
// CHECK: cir.global external {{.*}} = #cir.zero : !cir.complex<!cir.float>
16+
// CHECK: cir.global external {{.*}} = #cir.complex<#cir.int<1> : !s32i, #cir.int<2> : !s32i> : !cir.complex<!s32i>
17+
// CHECK: cir.global external {{.*}} = #cir.complex<#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00> : !cir.float> : !cir.complex<!cir.float>
18+
19+
// LLVM: {{.*}} = global { i32, i32 } zeroinitializer, align 4
20+
// LLVM: {{.*}} = global { float, float } zeroinitializer, align 4
21+
// LLVM: {{.*}} = global { i32, i32 } { i32 1, i32 2 }, align 4
22+
// LLVM: {{.*}} = global { float, float } { float 1.000000e+00, float 2.000000e+00 }, align 4

0 commit comments

Comments
 (0)