Skip to content

Commit 21bc23e

Browse files
authored
[CIR] Upstream support for accessing structure members (#136383)
This adds ClangIR support for accessing structure members. Access to union members is deferred to a later change.
1 parent 89a792e commit 21bc23e

File tree

18 files changed

+470
-24
lines changed

18 files changed

+470
-24
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,12 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
191191
return create<cir::StoreOp>(loc, val, dst);
192192
}
193193

194+
cir::GetMemberOp createGetMember(mlir::Location loc, mlir::Type resultTy,
195+
mlir::Value base, llvm::StringRef name,
196+
unsigned index) {
197+
return create<cir::GetMemberOp>(loc, resultTy, base, name, index);
198+
}
199+
194200
mlir::Value createDummyValue(mlir::Location loc, mlir::Type type,
195201
clang::CharUnits alignment) {
196202
auto addr = createAlloca(loc, getPointerTo(type), type, {},

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,6 +1380,70 @@ def GetGlobalOp : CIR_Op<"get_global",
13801380
let hasVerifier = 0;
13811381
}
13821382

1383+
//===----------------------------------------------------------------------===//
1384+
// GetMemberOp
1385+
//===----------------------------------------------------------------------===//
1386+
1387+
def GetMemberOp : CIR_Op<"get_member"> {
1388+
let summary = "Get the address of a member of a record";
1389+
let description = [{
1390+
The `cir.get_member` operation gets the address of a particular named
1391+
member from the input record.
1392+
1393+
It expects a pointer to the base record as well as the name of the member
1394+
and its field index.
1395+
1396+
Example:
1397+
```mlir
1398+
// Suppose we have a record with multiple members.
1399+
!s32i = !cir.int<s, 32>
1400+
!s8i = !cir.int<s, 8>
1401+
!ty_B = !cir.record<"struct.B" {!s32i, !s8i}>
1402+
1403+
// Get the address of the member at index 1.
1404+
%1 = cir.get_member %0[1] {name = "i"} : (!cir.ptr<!ty_B>) -> !cir.ptr<!s8i>
1405+
```
1406+
}];
1407+
1408+
let arguments = (ins
1409+
Arg<CIR_PointerType, "the address to load from", [MemRead]>:$addr,
1410+
StrAttr:$name,
1411+
IndexAttr:$index_attr);
1412+
1413+
let results = (outs Res<CIR_PointerType, "">:$result);
1414+
1415+
let assemblyFormat = [{
1416+
$addr `[` $index_attr `]` attr-dict
1417+
`:` qualified(type($addr)) `->` qualified(type($result))
1418+
}];
1419+
1420+
let builders = [
1421+
OpBuilder<(ins "mlir::Type":$type,
1422+
"mlir::Value":$value,
1423+
"llvm::StringRef":$name,
1424+
"unsigned":$index),
1425+
[{
1426+
mlir::APInt fieldIdx(64, index);
1427+
build($_builder, $_state, type, value, name, fieldIdx);
1428+
}]>
1429+
];
1430+
1431+
let extraClassDeclaration = [{
1432+
/// Return the index of the record member being accessed.
1433+
uint64_t getIndex() { return getIndexAttr().getZExtValue(); }
1434+
1435+
/// Return the record type pointed by the base pointer.
1436+
cir::PointerType getAddrTy() { return getAddr().getType(); }
1437+
1438+
/// Return the result type.
1439+
cir::PointerType getResultTy() {
1440+
return mlir::cast<cir::PointerType>(getResult().getType());
1441+
}
1442+
}];
1443+
1444+
let hasVerifier = 1;
1445+
}
1446+
13831447
//===----------------------------------------------------------------------===//
13841448
// FuncOp
13851449
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,9 @@ def CIR_RecordType : CIR_Type<"Record", "record",
511511
void complete(llvm::ArrayRef<mlir::Type> members, bool packed,
512512
bool isPadded);
513513

514+
uint64_t getElementOffset(const mlir::DataLayout &dataLayout,
515+
unsigned idx) const;
516+
514517
private:
515518
unsigned computeStructSize(const mlir::DataLayout &dataLayout) const;
516519
uint64_t computeStructAlignment(const mlir::DataLayout &dataLayout) const;

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ struct MissingFeatures {
157157
static bool emitCheckedInBoundsGEP() { return false; }
158158
static bool preservedAccessIndexRegion() { return false; }
159159
static bool bitfields() { return false; }
160+
static bool typeChecks() { return false; }
161+
static bool lambdaFieldToName() { return false; }
160162

161163
// Missing types
162164
static bool dataMemberType() { return false; }

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,38 @@ using namespace clang;
2727
using namespace clang::CIRGen;
2828
using namespace cir;
2929

30+
/// Get the address of a zero-sized field within a record. The resulting address
31+
/// doesn't necessarily have the right type.
32+
Address CIRGenFunction::emitAddrOfFieldStorage(Address base,
33+
const FieldDecl *field,
34+
llvm::StringRef fieldName,
35+
unsigned fieldIndex) {
36+
if (field->isZeroSize(getContext())) {
37+
cgm.errorNYI(field->getSourceRange(),
38+
"emitAddrOfFieldStorage: zero-sized field");
39+
return Address::invalid();
40+
}
41+
42+
mlir::Location loc = getLoc(field->getLocation());
43+
44+
mlir::Type fieldType = convertType(field->getType());
45+
auto fieldPtr = cir::PointerType::get(builder.getContext(), fieldType);
46+
// For most cases fieldName is the same as field->getName() but for lambdas,
47+
// which do not currently carry the name, so it can be passed down from the
48+
// CaptureStmt.
49+
cir::GetMemberOp memberAddr = builder.createGetMember(
50+
loc, fieldPtr, base.getPointer(), fieldName, fieldIndex);
51+
52+
// Retrieve layout information, compute alignment and return the final
53+
// address.
54+
const RecordDecl *rec = field->getParent();
55+
const CIRGenRecordLayout &layout = cgm.getTypes().getCIRGenRecordLayout(rec);
56+
unsigned idx = layout.getCIRFieldNo(field);
57+
CharUnits offset = CharUnits::fromQuantity(
58+
layout.getCIRType().getElementOffset(cgm.getDataLayout().layout, idx));
59+
return Address(memberAddr, base.getAlignment().alignmentAtOffset(offset));
60+
}
61+
3062
/// Given an expression of pointer type, try to
3163
/// derive a more accurate bound on the alignment of the pointer.
3264
Address CIRGenFunction::emitPointerWithAlignment(const Expr *expr,
@@ -264,6 +296,66 @@ mlir::Value CIRGenFunction::emitStoreThroughBitfieldLValue(RValue src,
264296
return {};
265297
}
266298

299+
LValue CIRGenFunction::emitLValueForField(LValue base, const FieldDecl *field) {
300+
LValueBaseInfo baseInfo = base.getBaseInfo();
301+
302+
if (field->isBitField()) {
303+
cgm.errorNYI(field->getSourceRange(), "emitLValueForField: bitfield");
304+
return LValue();
305+
}
306+
307+
QualType fieldType = field->getType();
308+
const RecordDecl *rec = field->getParent();
309+
AlignmentSource baseAlignSource = baseInfo.getAlignmentSource();
310+
LValueBaseInfo fieldBaseInfo(getFieldAlignmentSource(baseAlignSource));
311+
assert(!cir::MissingFeatures::opTBAA());
312+
313+
Address addr = base.getAddress();
314+
if (auto *classDef = dyn_cast<CXXRecordDecl>(rec)) {
315+
cgm.errorNYI(field->getSourceRange(), "emitLValueForField: C++ class");
316+
return LValue();
317+
}
318+
319+
unsigned recordCVR = base.getVRQualifiers();
320+
if (rec->isUnion()) {
321+
cgm.errorNYI(field->getSourceRange(), "emitLValueForField: union");
322+
return LValue();
323+
}
324+
325+
assert(!cir::MissingFeatures::preservedAccessIndexRegion());
326+
llvm::StringRef fieldName = field->getName();
327+
const CIRGenRecordLayout &layout =
328+
cgm.getTypes().getCIRGenRecordLayout(field->getParent());
329+
unsigned fieldIndex = layout.getCIRFieldNo(field);
330+
331+
assert(!cir::MissingFeatures::lambdaFieldToName());
332+
333+
addr = emitAddrOfFieldStorage(addr, field, fieldName, fieldIndex);
334+
335+
// If this is a reference field, load the reference right now.
336+
if (fieldType->isReferenceType()) {
337+
cgm.errorNYI(field->getSourceRange(), "emitLValueForField: reference type");
338+
return LValue();
339+
}
340+
341+
if (field->hasAttr<AnnotateAttr>()) {
342+
cgm.errorNYI(field->getSourceRange(), "emitLValueForField: AnnotateAttr");
343+
return LValue();
344+
}
345+
346+
LValue lv = makeAddrLValue(addr, fieldType, fieldBaseInfo);
347+
lv.getQuals().addCVRQualifiers(recordCVR);
348+
349+
// __weak attribute on a field is ignored.
350+
if (lv.getQuals().getObjCGCAttr() == Qualifiers::Weak) {
351+
cgm.errorNYI(field->getSourceRange(),
352+
"emitLValueForField: __weak attribute");
353+
return LValue();
354+
}
355+
356+
return lv;
357+
}
358+
267359
mlir::Value CIRGenFunction::emitToMemory(mlir::Value value, QualType ty) {
268360
// Bool has a different representation in memory than in registers,
269361
// but in ClangIR, it is simply represented as a cir.bool value.
@@ -608,6 +700,48 @@ CIRGenFunction::emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e) {
608700
return lv;
609701
}
610702

703+
LValue CIRGenFunction::emitMemberExpr(const MemberExpr *e) {
704+
if (auto *vd = dyn_cast<VarDecl>(e->getMemberDecl())) {
705+
cgm.errorNYI(e->getSourceRange(), "emitMemberExpr: VarDecl");
706+
return LValue();
707+
}
708+
709+
Expr *baseExpr = e->getBase();
710+
// If this is s.x, emit s as an lvalue. If it is s->x, emit s as a scalar.
711+
LValue baseLV;
712+
if (e->isArrow()) {
713+
LValueBaseInfo baseInfo;
714+
assert(!cir::MissingFeatures::opTBAA());
715+
Address addr = emitPointerWithAlignment(baseExpr, &baseInfo);
716+
QualType ptrTy = baseExpr->getType()->getPointeeType();
717+
assert(!cir::MissingFeatures::typeChecks());
718+
baseLV = makeAddrLValue(addr, ptrTy, baseInfo);
719+
} else {
720+
assert(!cir::MissingFeatures::typeChecks());
721+
baseLV = emitLValue(baseExpr);
722+
}
723+
724+
const NamedDecl *nd = e->getMemberDecl();
725+
if (auto *field = dyn_cast<FieldDecl>(nd)) {
726+
LValue lv = emitLValueForField(baseLV, field);
727+
assert(!cir::MissingFeatures::setObjCGCLValueClass());
728+
if (getLangOpts().OpenMP) {
729+
// If the member was explicitly marked as nontemporal, mark it as
730+
// nontemporal. If the base lvalue is marked as nontemporal, mark access
731+
// to children as nontemporal too.
732+
cgm.errorNYI(e->getSourceRange(), "emitMemberExpr: OpenMP");
733+
}
734+
return lv;
735+
}
736+
737+
if (const auto *fd = dyn_cast<FunctionDecl>(nd)) {
738+
cgm.errorNYI(e->getSourceRange(), "emitMemberExpr: FunctionDecl");
739+
return LValue();
740+
}
741+
742+
llvm_unreachable("Unhandled member declaration!");
743+
}
744+
611745
LValue CIRGenFunction::emitBinaryOperatorLValue(const BinaryOperator *e) {
612746
// Comma expressions just emit their LHS then their RHS as an l-value.
613747
if (e->getOpcode() == BO_Comma) {

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
168168
return emitLoadOfLValue(e);
169169
}
170170

171+
mlir::Value VisitMemberExpr(MemberExpr *e);
172+
171173
mlir::Value VisitExplicitCastExpr(ExplicitCastExpr *e) {
172174
return VisitCastExpr(e);
173175
}
@@ -1520,6 +1522,19 @@ mlir::Value ScalarExprEmitter::VisitCallExpr(const CallExpr *e) {
15201522
return v;
15211523
}
15221524

1525+
mlir::Value ScalarExprEmitter::VisitMemberExpr(MemberExpr *e) {
1526+
// TODO(cir): The classic codegen calls tryEmitAsConstant() here. Folding
1527+
// constants sound like work for MLIR optimizers, but we'll keep an assertion
1528+
// for now.
1529+
assert(!cir::MissingFeatures::tryEmitAsConstant());
1530+
Expr::EvalResult result;
1531+
if (e->EvaluateAsInt(result, cgf.getContext(), Expr::SE_AllowSideEffects)) {
1532+
cgf.cgm.errorNYI(e->getSourceRange(), "Constant interger member expr");
1533+
// Fall through to emit this as a non-constant access.
1534+
}
1535+
return emitLoadOfLValue(e);
1536+
}
1537+
15231538
mlir::Value CIRGenFunction::emitScalarConversion(mlir::Value src,
15241539
QualType srcTy, QualType dstTy,
15251540
SourceLocation loc) {

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,8 @@ LValue CIRGenFunction::emitLValue(const Expr *e) {
513513
return emitArraySubscriptExpr(cast<ArraySubscriptExpr>(e));
514514
case Expr::UnaryOperatorClass:
515515
return emitUnaryOpLValue(cast<UnaryOperator>(e));
516+
case Expr::MemberExprClass:
517+
return emitMemberExpr(cast<MemberExpr>(e));
516518
case Expr::BinaryOperatorClass:
517519
return emitBinaryOperatorLValue(cast<BinaryOperator>(e));
518520
case Expr::ParenExprClass:

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,10 @@ class CIRGenFunction : public CIRGenTypeCache {
423423
clang::CharUnits alignment);
424424

425425
public:
426+
Address emitAddrOfFieldStorage(Address base, const FieldDecl *field,
427+
llvm::StringRef fieldName,
428+
unsigned fieldIndex);
429+
426430
mlir::Value emitAlloca(llvm::StringRef name, mlir::Type ty,
427431
mlir::Location loc, clang::CharUnits alignment,
428432
bool insertIntoFnEntryBlock,
@@ -551,6 +555,9 @@ class CIRGenFunction : public CIRGenTypeCache {
551555
/// of the expression.
552556
/// FIXME: document this function better.
553557
LValue emitLValue(const clang::Expr *e);
558+
LValue emitLValueForField(LValue base, const clang::FieldDecl *field);
559+
560+
LValue emitMemberExpr(const MemberExpr *e);
554561

555562
/// Given an expression with a pointer type, emit the value and compute our
556563
/// best estimate of the alignment of the pointee.

clang/lib/CIR/CodeGen/CIRGenModule.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "CIRGenValue.h"
2020

2121
#include "clang/AST/CharUnits.h"
22+
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
2223
#include "clang/CIR/Dialect/IR/CIRDialect.h"
2324

2425
#include "TargetInfo.h"
@@ -95,6 +96,12 @@ class CIRGenModule : public CIRGenTypeCache {
9596
const clang::LangOptions &getLangOpts() const { return langOpts; }
9697
mlir::MLIRContext &getMLIRContext() { return *builder.getContext(); }
9798

99+
const cir::CIRDataLayout getDataLayout() const {
100+
// FIXME(cir): instead of creating a CIRDataLayout every time, set it as an
101+
// attribute for the CIRModule class.
102+
return cir::CIRDataLayout(theModule);
103+
}
104+
98105
/// -------
99106
/// Handling globals
100107
/// -------

clang/lib/CIR/CodeGen/CIRGenRecordLayout.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ class CIRGenRecordLayout {
4242
cir::RecordType getCIRType() const { return completeObjectType; }
4343

4444
/// Return cir::RecordType element number that corresponds to the field FD.
45-
unsigned getCIRFieldNo(const clang::FieldDecl *FD) const {
46-
FD = FD->getCanonicalDecl();
47-
assert(fieldInfo.count(FD) && "Invalid field for record!");
48-
return fieldInfo.lookup(FD);
45+
unsigned getCIRFieldNo(const clang::FieldDecl *fd) const {
46+
fd = fd->getCanonicalDecl();
47+
assert(fieldInfo.count(fd) && "Invalid field for record!");
48+
return fieldInfo.lookup(fd);
4949
}
5050
};
5151

0 commit comments

Comments
 (0)