Skip to content

Commit 4c1539c

Browse files
committed
[MLIR] [OpenMP] Initial support for OMP ALLOCATE directive op.
1 parent 36cbd43 commit 4c1539c

File tree

6 files changed

+179
-0
lines changed

6 files changed

+179
-0
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,31 @@
2222
include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
2323
include "mlir/IR/SymbolInterfaces.td"
2424

25+
//===----------------------------------------------------------------------===//
26+
// V5.2: [6.3] `align` clause
27+
//===----------------------------------------------------------------------===//
28+
29+
class OpenMP_AlignClauseSkip<
30+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
31+
bit description = false, bit extraClassDeclaration = false
32+
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
33+
extraClassDeclaration> {
34+
let arguments = (ins
35+
ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$align
36+
);
37+
38+
let optAssemblyFormat = [{
39+
`align` `(` $align `)`
40+
}];
41+
42+
let description = [{
43+
The `align` clause is used to specify the byte alignment to use for
44+
allocations associated with the construct on which the clause appears.
45+
}];
46+
}
47+
48+
def OpenMP_AlignClause : OpenMP_AlignClauseSkip<>;
49+
2550
//===----------------------------------------------------------------------===//
2651
// V5.2: [5.11] `aligned` clause
2752
//===----------------------------------------------------------------------===//
@@ -84,6 +109,32 @@ class OpenMP_AllocateClauseSkip<
84109

85110
def OpenMP_AllocateClause : OpenMP_AllocateClauseSkip<>;
86111

112+
//===----------------------------------------------------------------------===//
113+
// V5.2: [6.4] `allocator` clause
114+
//===----------------------------------------------------------------------===//
115+
116+
class OpenMP_AllocatorClauseSkip<
117+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
118+
bit description = false, bit extraClassDeclaration = false
119+
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
120+
extraClassDeclaration> {
121+
122+
let arguments = (ins
123+
OptionalAttr<AllocatorHandleAttr>:$allocator
124+
);
125+
126+
let optAssemblyFormat = [{
127+
`allocator` `(` custom<ClauseAttr>($allocator) `)`
128+
}];
129+
130+
let description = [{
131+
`allocator` specifies the memory allocator to be used for allocations
132+
associated with the construct on which the clause appears.
133+
}];
134+
}
135+
136+
def OpenMP_AllocatorClause : OpenMP_AllocatorClauseSkip<>;
137+
87138
//===----------------------------------------------------------------------===//
88139
// LLVM OpenMP extension `ompx_bare` clause
89140
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,4 +263,34 @@ def VariableCaptureKindAttr : OpenMP_EnumAttr<VariableCaptureKind,
263263
let assemblyFormat = "`(` $value `)`";
264264
}
265265

266+
267+
//===----------------------------------------------------------------------===//
268+
// allocator_handle enum.
269+
//===----------------------------------------------------------------------===//
270+
271+
def OpenMP_AllocatorHandleNullAllocator : I32EnumAttrCase<"omp_null_allocator", 0>;
272+
def OpenMP_AllocatorHandleDefaultMemAlloc : I32EnumAttrCase<"omp_default_mem_alloc", 1>;
273+
def OpenMP_AllocatorHandleLargeCapMemAlloc : I32EnumAttrCase<"omp_large_cap_mem_alloc", 2>;
274+
def OpenMP_AllocatorHandleConstMemAlloc : I32EnumAttrCase<"omp_const_mem_alloc", 3>;
275+
def OpenMP_AllocatorHandleHighBwMemAlloc : I32EnumAttrCase<"omp_high_bw_mem_alloc", 4>;
276+
def OpenMP_AllocatorHandleLowLatMemAlloc : I32EnumAttrCase<"omp_low_lat_mem_alloc", 5>;
277+
def OpenMP_AllocatorHandleCgroupMemAlloc : I32EnumAttrCase<"omp_cgroup_mem_alloc", 6>;
278+
def OpenMP_AllocatorHandlePteamMemAlloc : I32EnumAttrCase<"omp_pteam_mem_alloc", 7>;
279+
def OpenMP_AllocatorHandlethreadMemAlloc : I32EnumAttrCase<"omp_thread_mem_alloc", 8>;
280+
281+
def AllocatorHandle : OpenMP_I32EnumAttr<
282+
"AllocatorHandle",
283+
"OpenMP allocator_handle", [
284+
OpenMP_AllocatorHandleNullAllocator,
285+
OpenMP_AllocatorHandleDefaultMemAlloc,
286+
OpenMP_AllocatorHandleLargeCapMemAlloc,
287+
OpenMP_AllocatorHandleConstMemAlloc,
288+
OpenMP_AllocatorHandleHighBwMemAlloc,
289+
OpenMP_AllocatorHandleLowLatMemAlloc,
290+
OpenMP_AllocatorHandleCgroupMemAlloc,
291+
OpenMP_AllocatorHandlePteamMemAlloc,
292+
OpenMP_AllocatorHandlethreadMemAlloc
293+
]>;
294+
295+
def AllocatorHandleAttr : OpenMP_EnumAttr<AllocatorHandle, "allocator_handle">;
266296
#endif // OPENMP_ENUMS

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1883,4 +1883,31 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
18831883
];
18841884
}
18851885

1886+
//===----------------------------------------------------------------------===//
1887+
// [Spec 5.2] 6.5 allocate Directive
1888+
//===----------------------------------------------------------------------===//
1889+
def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [
1890+
OpenMP_AlignClause, OpenMP_AllocatorClause
1891+
]> {
1892+
let summary = "allocate directive";
1893+
let description = [{
1894+
The storage for each list item that appears in the allocate directive is
1895+
provided an allocation through the memory allocator.
1896+
}] # clausesDescription;
1897+
1898+
let arguments = !con((ins Variadic<AnyType>:$varList),
1899+
clausesArgs);
1900+
1901+
// Override inherited assembly format to include `varList`.
1902+
let assemblyFormat = " `(` $varList `:` type($varList) `)` oilist(" #
1903+
clausesOptAssemblyFormat #
1904+
") attr-dict ";
1905+
1906+
let builders = [
1907+
OpBuilder<(ins CArg<"const AllocateDirOperands &">:$clauses)>
1908+
];
1909+
1910+
let hasVerifier = 1;
1911+
}
1912+
18861913
#endif // OPENMP_OPS

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3512,6 +3512,20 @@ LogicalResult ScanOp::verify() {
35123512
"reduction modifier");
35133513
}
35143514

3515+
/// Verifies align clause in allocate directive
3516+
3517+
LogicalResult AllocateDirOp::verify() {
3518+
std::optional<u_int64_t> align = this->getAlign();
3519+
3520+
if (align.has_value()) {
3521+
if ((align.value() > 0) && ((align.value() & (align.value() - 1)) != 0))
3522+
return emitError() << "ALIGN value : " << align.value()
3523+
<< " must be power of 2";
3524+
}
3525+
3526+
return success();
3527+
}
3528+
35153529
#define GET_ATTRDEF_CLASSES
35163530
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
35173531

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2993,3 +2993,27 @@ llvm.func @invalid_mapper(%0 : !llvm.ptr) {
29932993
}
29942994
llvm.return
29952995
}
2996+
2997+
// -----
2998+
func.func @invalid_allocate_align_1(%arg0 : memref<i32>) -> () {
2999+
// expected-error @below {{failed to satisfy constraint: 64-bit signless integer attribute whose value is positive}}
3000+
omp.allocate_dir (%arg0 : memref<i32>) align(-1)
3001+
3002+
return
3003+
}
3004+
3005+
// -----
3006+
func.func @invalid_allocate_align_2(%arg0 : memref<i32>) -> () {
3007+
// expected-error @below {{must be power of 2}}
3008+
omp.allocate_dir (%arg0 : memref<i32>) align(3)
3009+
3010+
return
3011+
}
3012+
3013+
// -----
3014+
func.func @invalid_allocate_allocator(%arg0 : memref<i32>) -> () {
3015+
// expected-error @below {{invalid clause value}}
3016+
omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_small_cap_mem_alloc)
3017+
3018+
return
3019+
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3197,3 +3197,36 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
31973197
}
31983198
return
31993199
}
3200+
3201+
// CHECK-LABEL: func.func @omp_allocate_dir(
3202+
// CHECK-SAME: %[[ARG0:.*]]: memref<i32>,
3203+
// CHECK-SAME: %[[ARG1:.*]]: memref<i32>) {
3204+
func.func @omp_allocate_dir(%arg0 : memref<i32>, %arg1 : memref<i32>) -> () {
3205+
3206+
// Test with one data var
3207+
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>)
3208+
omp.allocate_dir (%arg0 : memref<i32>)
3209+
3210+
// Test with two data vars
3211+
// CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>)
3212+
omp.allocate_dir (%arg0, %arg1: memref<i32>, memref<i32>)
3213+
3214+
// Test with one data var and align clause
3215+
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2)
3216+
omp.allocate_dir (%arg0 : memref<i32>) align(2)
3217+
3218+
// Test with one data var and allocator clause
3219+
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) allocator(omp_pteam_mem_alloc)
3220+
omp.allocate_dir (%arg0 : memref<i32>) allocator(omp_pteam_mem_alloc)
3221+
3222+
// Test with one data var, align clause and allocator clause
3223+
// CHECK: omp.allocate_dir(%[[ARG0]] : memref<i32>) align(2) allocator(omp_thread_mem_alloc)
3224+
omp.allocate_dir (%arg0 : memref<i32>) align(2) allocator(omp_thread_mem_alloc)
3225+
3226+
// Test with two data vars, align clause and allocator clause
3227+
// CHECK: omp.allocate_dir(%[[ARG0]], %[[ARG1]] : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
3228+
omp.allocate_dir (%arg0, %arg1 : memref<i32>, memref<i32>) align(2) allocator(omp_cgroup_mem_alloc)
3229+
3230+
return
3231+
}
3232+

0 commit comments

Comments
 (0)