Skip to content

Commit b7c521b

Browse files
committed
[OpenACC][CIR] Lowering for 'vector_length' on compute constructs
This is the same as 'num_workers', with slightly different names in places, so we use exactly the same implementation. This also extracts that implementation into a shared helper, which should make it easier to reuse.
1 parent 99e4b39 commit b7c521b

File tree

3 files changed

+171
-45
lines changed

3 files changed

+171
-45
lines changed

clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp

Lines changed: 67 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,56 @@ class OpenACCClauseCIREmitter final
8282
return conversionOp.getResult(0);
8383
}
8484

85+
// Translates the identifier of one 'device_type' argument into the matching
// mlir::acc::DeviceType enumerator. A null identifier (the '*' argument) maps
// to 'Star'; every other name is matched via the StringSwitch
// 'CaseLower'/'CasesLower' helpers.
// NOTE(review): the switch has no '.Default(...)' -- presumably unknown
// architecture names were already rejected before this point; confirm.
mlir::acc::DeviceType decodeDeviceType(const IdentifierInfo *ii) {
86+
// '*' case leaves no identifier-info, just a nullptr.
87+
if (!ii)
88+
return mlir::acc::DeviceType::Star;
89+
return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName())
90+
.CaseLower("default", mlir::acc::DeviceType::Default)
91+
.CaseLower("host", mlir::acc::DeviceType::Host)
92+
.CaseLower("multicore", mlir::acc::DeviceType::Multicore)
93+
.CasesLower("nvidia", "acc_device_nvidia",
94+
mlir::acc::DeviceType::Nvidia)
95+
.CaseLower("radeon", mlir::acc::DeviceType::Radeon);
96+
}
97+
98+
// Handle a clause affected by the 'device-type' to the point that they need
99+
// to have the attributes added in the correct/corresponding order, such as
100+
// 'num_workers' or 'vector_length' on a compute construct.
//
// One device-type attribute is recorded for every operand appended, so the
// returned array and 'argCollection' grow in lockstep.
//   existingDeviceTypes - attribute array to start from; may be null, in
//                         which case nothing is carried over.
//   argument            - value appended to 'argCollection', once per device
//                         type handled by this call.
//   argCollection       - mutable operand range that receives 'argument'.
// Returns a freshly built ArrayAttr containing the existing entries followed
// by the entries added here.
101+
mlir::ArrayAttr
102+
handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
103+
mlir::Value argument,
104+
mlir::MutableOperandRange &argCollection) {
105+
llvm::SmallVector<mlir::Attribute> deviceTypes;
106+
107+
// Collect the 'existing' device-type attributes so we can re-create them
108+
// and insert them.
109+
if (existingDeviceTypes) {
110+
for (const mlir::Attribute &Attr : existingDeviceTypes)
111+
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
112+
builder.getContext(),
113+
cast<mlir::acc::DeviceTypeAttr>(Attr).getValue()));
114+
}
115+
116+
// Append one copy of 'argument' to 'argCollection' for each architecture
117+
// listed on 'lastDeviceTypeClause', recording its device type alongside.
118+
if (lastDeviceTypeClause) {
119+
for (const DeviceTypeArgument &arch :
120+
lastDeviceTypeClause->getArchitectures()) {
121+
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
122+
builder.getContext(), decodeDeviceType(arch.getIdentifierInfo())));
123+
argCollection.append(argument);
124+
}
125+
} else {
126+
// 'lastDeviceTypeClause' is null: add a single entry tagged 'None'.
127+
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
128+
builder.getContext(), mlir::acc::DeviceType::None));
129+
argCollection.append(argument);
130+
}
131+
132+
return mlir::ArrayAttr::get(builder.getContext(), deviceTypes);
133+
}
134+
85135
public:
86136
OpenACCClauseCIREmitter(OpTy &operation, CIRGenFunction &cgf,
87137
CIRGenBuilderTy &builder,
@@ -112,19 +162,6 @@ class OpenACCClauseCIREmitter final
112162
}
113163
}
114164

115-
// Translates the identifier of one 'device_type' argument into the matching
// mlir::acc::DeviceType enumerator. A null identifier (the '*' argument) maps
// to 'Star'; every other name is matched via the StringSwitch
// 'CaseLower'/'CasesLower' helpers.
mlir::acc::DeviceType decodeDeviceType(const IdentifierInfo *ii) {
116-
// '*' case leaves no identifier-info, just a nullptr.
117-
if (!ii)
118-
return mlir::acc::DeviceType::Star;
119-
return llvm::StringSwitch<mlir::acc::DeviceType>(ii->getName())
120-
.CaseLower("default", mlir::acc::DeviceType::Default)
121-
.CaseLower("host", mlir::acc::DeviceType::Host)
122-
.CaseLower("multicore", mlir::acc::DeviceType::Multicore)
123-
.CasesLower("nvidia", "acc_device_nvidia",
124-
mlir::acc::DeviceType::Nvidia)
125-
.CaseLower("radeon", mlir::acc::DeviceType::Radeon);
126-
}
127-
128165
void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
129166
lastDeviceTypeClause = &clause;
130167
if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
@@ -165,45 +202,30 @@ class OpenACCClauseCIREmitter final
165202

166203
// Lowers a 'num_workers' clause onto 'operation'.
//  - ParallelOp / KernelsOp: the clause's int-expr is appended to the
//    operation's num_workers operands and the num_workers device-type
//    attribute array is replaced with an updated one.
//  - SerialOp: marked unreachable.
//  - Any other operation type: reported through clauseNotImplemented().
void VisitNumWorkersClause(const OpenACCNumWorkersClause &clause) {
167204
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
168-
// Collect the 'existing' device-type attributes so we can re-create them
169-
// and insert them.
170-
llvm::SmallVector<mlir::Attribute> deviceTypes;
171-
mlir::ArrayAttr existingDeviceTypes =
172-
operation.getNumWorkersDeviceTypeAttr();
173-
174-
if (existingDeviceTypes) {
175-
for (mlir::Attribute Attr : existingDeviceTypes)
176-
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
177-
builder.getContext(),
178-
cast<mlir::acc::DeviceTypeAttr>(Attr).getValue()));
179-
}
180-
181-
// Insert 1 version of the 'int-expr' to the NumWorkers list per-current
182-
// device type.
183-
mlir::Value intExpr = createIntExpr(clause.getIntExpr());
184-
if (lastDeviceTypeClause) {
185-
for (const DeviceTypeArgument &arg :
186-
lastDeviceTypeClause->getArchitectures()) {
187-
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
188-
builder.getContext(), decodeDeviceType(arg.getIdentifierInfo())));
189-
operation.getNumWorkersMutable().append(intExpr);
190-
}
191-
} else {
192-
// Else, we just add a single for 'none'.
193-
deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
194-
builder.getContext(), mlir::acc::DeviceType::None));
195-
operation.getNumWorkersMutable().append(intExpr);
196-
}
197-
198-
operation.setNumWorkersDeviceTypeAttr(
199-
mlir::ArrayAttr::get(builder.getContext(), deviceTypes));
205+
mlir::MutableOperandRange range = operation.getNumWorkersMutable();
206+
operation.setNumWorkersDeviceTypeAttr(handleDeviceTypeAffectedClause(
207+
operation.getNumWorkersDeviceTypeAttr(),
208+
createIntExpr(clause.getIntExpr()), range));
200209
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
201210
llvm_unreachable("num_workers not valid on serial");
202211
} else {
203212
return clauseNotImplemented(clause);
204213
}
205214
}
206215

216+
// Lowers a 'vector_length' clause onto 'operation'.
//  - ParallelOp / KernelsOp: the clause's int-expr and the operation's
//    vector_length operand range are handed to
//    handleDeviceTypeAffectedClause(), and the ArrayAttr it returns is stored
//    as the vector_length device-type attribute.
//  - SerialOp: marked unreachable.
//  - Any other operation type: reported through clauseNotImplemented().
void VisitVectorLengthClause(const OpenACCVectorLengthClause &clause) {
217+
if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
218+
mlir::MutableOperandRange range = operation.getVectorLengthMutable();
219+
operation.setVectorLengthDeviceTypeAttr(handleDeviceTypeAffectedClause(
220+
operation.getVectorLengthDeviceTypeAttr(),
221+
createIntExpr(clause.getIntExpr()), range));
222+
} else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
223+
llvm_unreachable("vector_length not valid on serial");
224+
} else {
225+
return clauseNotImplemented(clause);
226+
}
227+
}
228+
207229
void VisitSelfClause(const OpenACCSelfClause &clause) {
208230
if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp>) {
209231
if (clause.isEmptySelfClause()) {

clang/test/CIR/CodeGenOpenACC/kernels.c

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,5 +158,57 @@ void acc_kernels(int cond) {
158158
// CHECK-NEXT: acc.terminator
159159
// CHECK-NEXT: } loc
160160

161+
#pragma acc kernels vector_length(cond)
162+
{}
163+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
164+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
165+
// CHECK-NEXT: acc.kernels vector_length(%[[CONV_CAST]] : si32) {
166+
// CHECK-NEXT: acc.terminator
167+
// CHECK-NEXT: } loc
168+
169+
#pragma acc kernels vector_length(cond) device_type(nvidia) vector_length(2u)
170+
{}
171+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
172+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
173+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !u32i
174+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !u32i to ui32
175+
// CHECK-NEXT: acc.kernels vector_length(%[[CONV_CAST]] : si32, %[[TWO_CAST]] : ui32 [#acc.device_type<nvidia>]) {
176+
// CHECK-NEXT: acc.terminator
177+
// CHECK-NEXT: } loc
178+
179+
#pragma acc kernels vector_length(cond) device_type(nvidia, host) vector_length(2) device_type(radeon) vector_length(3)
180+
{}
181+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
182+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
183+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
184+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
185+
// CHECK-NEXT: %[[THREE_LITERAL:.*]] = cir.const #cir.int<3> : !s32i
186+
// CHECK-NEXT: %[[THREE_CAST:.*]] = builtin.unrealized_conversion_cast %[[THREE_LITERAL]] : !s32i to si32
187+
// CHECK-NEXT: acc.kernels vector_length(%[[CONV_CAST]] : si32, %[[TWO_CAST]] : si32 [#acc.device_type<nvidia>], %[[TWO_CAST]] : si32 [#acc.device_type<host>], %[[THREE_CAST]] : si32 [#acc.device_type<radeon>]) {
188+
// CHECK-NEXT: acc.terminator
189+
// CHECK-NEXT: } loc
190+
191+
#pragma acc kernels vector_length(cond) device_type(nvidia) vector_length(2) device_type(radeon, multicore) vector_length(3)
192+
{}
193+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
194+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
195+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
196+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
197+
// CHECK-NEXT: %[[THREE_LITERAL:.*]] = cir.const #cir.int<3> : !s32i
198+
// CHECK-NEXT: %[[THREE_CAST:.*]] = builtin.unrealized_conversion_cast %[[THREE_LITERAL]] : !s32i to si32
199+
// CHECK-NEXT: acc.kernels vector_length(%[[CONV_CAST]] : si32, %[[TWO_CAST]] : si32 [#acc.device_type<nvidia>], %[[THREE_CAST]] : si32 [#acc.device_type<radeon>], %[[THREE_CAST]] : si32 [#acc.device_type<multicore>]) {
200+
// CHECK-NEXT: acc.terminator
201+
// CHECK-NEXT: } loc
202+
203+
#pragma acc kernels device_type(nvidia) vector_length(2) device_type(radeon) vector_length(3)
204+
{}
205+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
206+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
207+
// CHECK-NEXT: %[[THREE_LITERAL:.*]] = cir.const #cir.int<3> : !s32i
208+
// CHECK-NEXT: %[[THREE_CAST:.*]] = builtin.unrealized_conversion_cast %[[THREE_LITERAL]] : !s32i to si32
209+
// CHECK-NEXT: acc.kernels vector_length(%[[TWO_CAST]] : si32 [#acc.device_type<nvidia>], %[[THREE_CAST]] : si32 [#acc.device_type<radeon>]) {
210+
// CHECK-NEXT: acc.terminator
211+
// CHECK-NEXT: } loc
212+
161213
// CHECK-NEXT: cir.return
162214
}

clang/test/CIR/CodeGenOpenACC/parallel.c

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,5 +157,57 @@ void acc_parallel(int cond) {
157157
// CHECK-NEXT: acc.yield
158158
// CHECK-NEXT: } loc
159159

160+
#pragma acc parallel vector_length(cond)
161+
{}
162+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
163+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
164+
// CHECK-NEXT: acc.parallel vector_length(%[[CONV_CAST]] : si32) {
165+
// CHECK-NEXT: acc.yield
166+
// CHECK-NEXT: } loc
167+
168+
#pragma acc parallel vector_length(cond) device_type(nvidia) vector_length(2u)
169+
{}
170+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
171+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
172+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !u32i
173+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !u32i to ui32
174+
// CHECK-NEXT: acc.parallel vector_length(%[[CONV_CAST]] : si32, %[[TWO_CAST]] : ui32 [#acc.device_type<nvidia>]) {
175+
// CHECK-NEXT: acc.yield
176+
// CHECK-NEXT: } loc
177+
178+
#pragma acc parallel vector_length(cond) device_type(nvidia, host) vector_length(2) device_type(radeon) vector_length(3)
179+
{}
180+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
181+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
182+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
183+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
184+
// CHECK-NEXT: %[[THREE_LITERAL:.*]] = cir.const #cir.int<3> : !s32i
185+
// CHECK-NEXT: %[[THREE_CAST:.*]] = builtin.unrealized_conversion_cast %[[THREE_LITERAL]] : !s32i to si32
186+
// CHECK-NEXT: acc.parallel vector_length(%[[CONV_CAST]] : si32, %[[TWO_CAST]] : si32 [#acc.device_type<nvidia>], %[[TWO_CAST]] : si32 [#acc.device_type<host>], %[[THREE_CAST]] : si32 [#acc.device_type<radeon>]) {
187+
// CHECK-NEXT: acc.yield
188+
// CHECK-NEXT: } loc
189+
190+
#pragma acc parallel vector_length(cond) device_type(nvidia) vector_length(2) device_type(radeon, multicore) vector_length(4)
191+
{}
192+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
193+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
194+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
195+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
196+
// CHECK-NEXT: %[[FOUR_LITERAL:.*]] = cir.const #cir.int<4> : !s32i
197+
// CHECK-NEXT: %[[FOUR_CAST:.*]] = builtin.unrealized_conversion_cast %[[FOUR_LITERAL]] : !s32i to si32
198+
// CHECK-NEXT: acc.parallel vector_length(%[[CONV_CAST]] : si32, %[[TWO_CAST]] : si32 [#acc.device_type<nvidia>], %[[FOUR_CAST]] : si32 [#acc.device_type<radeon>], %[[FOUR_CAST]] : si32 [#acc.device_type<multicore>]) {
199+
// CHECK-NEXT: acc.yield
200+
// CHECK-NEXT: } loc
201+
202+
#pragma acc parallel device_type(nvidia) vector_length(2) device_type(radeon) vector_length(3)
203+
{}
204+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
205+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
206+
// CHECK-NEXT: %[[THREE_LITERAL:.*]] = cir.const #cir.int<3> : !s32i
207+
// CHECK-NEXT: %[[THREE_CAST:.*]] = builtin.unrealized_conversion_cast %[[THREE_LITERAL]] : !s32i to si32
208+
// CHECK-NEXT: acc.parallel vector_length(%[[TWO_CAST]] : si32 [#acc.device_type<nvidia>], %[[THREE_CAST]] : si32 [#acc.device_type<radeon>]) {
209+
// CHECK-NEXT: acc.yield
210+
// CHECK-NEXT: } loc
211+
160212
// CHECK-NEXT: cir.return
161213
}

0 commit comments

Comments
 (0)