Skip to content

Commit 1c583c1

Browse files
[acc][mlir] Add functionality for categorizing OpenACC variable types (#126167)
OpenACC specification describes the following type categories: scalar, array, composite, and aggregate (which includes arrays, composites, and others such as Fortran pointer/allocatable). Decision for how to do implicit mapping is dependent on a variable's category. Since acc dialect's only means of distinguishing between types is through the interfaces attached, add API to be able to get the type category. In addition to defining the new API, attempt to provide a base implementation for memref which matches what OpenACC spec describes.
1 parent 308d286 commit 1c583c1

File tree

3 files changed

+135
-2
lines changed

3 files changed

+135
-2
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,74 @@ def OpenACC_ReductionOperatorAttr : EnumAttr<OpenACC_Dialect,
6565
let assemblyFormat = [{ ```<` $value `>` }];
6666
}
6767

68+
// OpenACC variable type categorization. This is needed because OpenACC
69+
// dialect is used with other dialects, and each dialect defines its own
70+
// types. Thus, in order to be able to classify types and apply right semantics,
71+
// it is needed to ensure the types can be categorized.
72+
def OpenACC_VariableTypeUncategorized : I32BitEnumAttrCaseNone<"uncategorized">;
73+
74+
// The OpenACC spec definition of scalar type is as follows (from 3.3 spec,
75+
// line 5454):
76+
// Scalar datatype - an intrinsic or built-in datatype that is not an array or
77+
// aggregate datatype. In Fortran, scalar datatypes are integer, real, double
78+
// precision, complex, or logical. In C, scalar datatypes are char (signed or
79+
// unsigned), int (signed or unsigned, with optional short, long or long long
80+
// attribute), enum, float, double, long double, Complex (with optional float
81+
// or long attribute), or any pointer datatype. In C++, scalar datatypes are
82+
// char (signed or unsigned), wchar t, int (signed or unsigned, with optional
83+
// short, long or long long attribute), enum, bool, float, double, long double,
84+
// or any pointer datatype. Not all implementations or targets will support all
85+
// of these datatypes.
86+
// From an MLIR type perspective, the types that those language types map to
87+
// will be categorized as scalar.
88+
def OpenACC_VariableTypeScalar : I32BitEnumAttrCaseBit<"scalar", 0>;
89+
90+
// Not in OpenACC spec glossary as its own definition but used throughout the
91+
// spec. One definition of array that can be assumed for purposes of type
92+
// categorization is that it is a collection of elements of same type.
93+
def OpenACC_VariableTypeArray : I32BitEnumAttrCaseBit<"array", 1>;
94+
95+
// The OpenACC spec definition of composite type is as follows (from 3.3 spec,
96+
// line 5354):
97+
// Composite datatype - a derived type in Fortran, or a struct or union type in
98+
// C, or a class, struct, or union type in C++. (This is different from the use
99+
// of the term composite data type in the C and C++ languages.)
100+
def OpenACC_VariableTypeComposite : I32BitEnumAttrCaseBit<"composite", 2>;
101+
102+
// The OpenACC spec uses the type category "aggregate" to capture both arrays
103+
// and composite types. However, it includes types which do not fall in either
104+
// of those categories. Thus create a case for the others.
105+
// For example, reading the definition of "Aggregate Variables" in the 3.3
106+
// spec line 5346 shows this distinction:
107+
// Aggregate variables - a variable of any non-scalar datatype, including array
108+
// or composite variables. In Fortran, this includes any variable with
109+
// allocatable or pointer attribute and character variables
110+
def OpenACC_VariableTypeOtherNonScalar : I32BitEnumAttrCaseBit<"nonscalar", 3>;
111+
112+
// The OpenACC spec definition of aggregate type is as follows (from 3.3 spec,
113+
// line 5342):
114+
// Aggregate datatype - any non-scalar datatype such as array and composite
115+
// datatypes. In Fortran, aggregate datatypes include arrays, derived types,
116+
// character types. In C, aggregate datatypes include arrays, targets of
117+
// pointers, structs, and unions. In C++, aggregate datatypes include arrays,
118+
// targets of pointers, classes, structs, and unions.
119+
def OpenACC_VariableTypeAggregate : I32BitEnumAttrCaseGroup<"aggregate",
120+
[OpenACC_VariableTypeArray, OpenACC_VariableTypeComposite,
121+
OpenACC_VariableTypeOtherNonScalar]>;
122+
123+
def OpenACC_VariableTypeCategory : I32BitEnumAttr<
124+
"VariableTypeCategory",
125+
"Captures different type categories described in OpenACC spec",
126+
[
127+
OpenACC_VariableTypeUncategorized, OpenACC_VariableTypeScalar,
128+
OpenACC_VariableTypeArray, OpenACC_VariableTypeComposite,
129+
OpenACC_VariableTypeOtherNonScalar, OpenACC_VariableTypeAggregate]> {
130+
let separator = ",";
131+
let cppNamespace = "::mlir::acc";
132+
let genSpecializedAttr = 0;
133+
let printBitEnumPrimaryGroups = 1;
134+
}
135+
68136
// Type used in operation below.
69137
def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>;
70138

mlir/include/mlir/Dialect/OpenACC/OpenACCTypeInterfaces.td

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,28 @@ def OpenACC_PointerLikeTypeInterface : TypeInterface<"PointerLikeType"> {
2828
/*retTy=*/"::mlir::Type",
2929
/*methodName=*/"getElementType"
3030
>,
31+
InterfaceMethod<
32+
/*description=*/[{
33+
Returns the type category of the pointee. The `var` is provided because
34+
a dialect's type system may be incomplete. For example, consider a
35+
dialect which computes interior pointers - so a float array element
36+
may be represented as `ptr<f32>`. The type system says the pointee
37+
is `f32` but this is not a scalar from the point-of-view of OpenACC.
38+
It is an array element and thus the appropriate type category is
39+
"array" - therefore being able to look up how a variable is computed
40+
is important for a complete type determination.
41+
The `varType` is provided in cases where a dialect's type system
42+
erased the target type.
43+
}],
44+
/*retTy=*/"::mlir::acc::VariableTypeCategory",
45+
/*methodName=*/"getPointeeTypeCategory",
46+
/*args=*/(ins "::mlir::TypedValue<::mlir::acc::PointerLikeType>":$varPtr,
47+
"::mlir::Type":$varType),
48+
/*methodBody=*/"",
49+
/*defaultImplementation=*/[{
50+
return ::mlir::acc::VariableTypeCategory::uncategorized;
51+
}]
52+
>,
3153
];
3254
}
3355

@@ -106,7 +128,7 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> {
106128
return {};
107129
}]
108130
>,
109-
InterfaceMethod<
131+
InterfaceMethod<
110132
/*description=*/[{
111133
Returns explicit `acc.bounds` operations that envelop the whole
112134
data structure. These operations are inserted using the provided builder
@@ -121,6 +143,18 @@ def OpenACC_MappableTypeInterface : TypeInterface<"MappableType"> {
121143
return {};
122144
}]
123145
>,
146+
InterfaceMethod<
147+
/*description=*/[{
148+
Returns the OpenACC type category.
149+
}],
150+
/*retTy=*/"::mlir::acc::VariableTypeCategory",
151+
/*methodName=*/"getTypeCategory",
152+
/*args=*/(ins "::mlir::Value":$var),
153+
/*methodBody=*/"",
154+
/*defaultImplementation=*/[{
155+
return ::mlir::acc::VariableTypeCategory::uncategorized;
156+
}]
157+
>,
124158
];
125159
}
126160

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,42 @@ using namespace acc;
3232
#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
3333

3434
namespace {
35+
36+
static bool isScalarLikeType(Type type) {
37+
return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
38+
}
39+
3540
struct MemRefPointerLikeModel
3641
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
3742
MemRefType> {
3843
Type getElementType(Type pointer) const {
39-
return llvm::cast<MemRefType>(pointer).getElementType();
44+
return cast<MemRefType>(pointer).getElementType();
45+
}
46+
mlir::acc::VariableTypeCategory
47+
getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
48+
Type varType) const {
49+
if (auto mappableTy = dyn_cast<MappableType>(varType)) {
50+
return mappableTy.getTypeCategory(varPtr);
51+
}
52+
auto memrefTy = cast<MemRefType>(pointer);
53+
if (!memrefTy.hasRank()) {
54+
// This memref is unranked - aka it could have any rank, including a
55+
// rank of 0 which could mean scalar. For now, return uncategorized.
56+
return mlir::acc::VariableTypeCategory::uncategorized;
57+
}
58+
59+
if (memrefTy.getRank() == 0) {
60+
if (isScalarLikeType(memrefTy.getElementType())) {
61+
return mlir::acc::VariableTypeCategory::scalar;
62+
}
63+
// Zero-rank non-scalar - need further analysis to determine the type
64+
// category. For now, return uncategorized.
65+
return mlir::acc::VariableTypeCategory::uncategorized;
66+
}
67+
68+
// It has a rank - must be an array.
69+
assert(memrefTy.getRank() > 0 && "rank expected to be positive");
70+
return mlir::acc::VariableTypeCategory::array;
4071
}
4172
};
4273

0 commit comments

Comments
 (0)