Skip to content

Commit 6de5d1e

Browse files
authored
[mlir][linalg] Extend elementwise (llvm#124661)
Implements Linalg elemwise named-op following the proposal and discussions in RFC: https://discourse.llvm.org/t/rfc-extend-linalg-elemwise-named-ops-semantics/83927/1
1 parent b9622e8 commit 6de5d1e

File tree

7 files changed

+721
-0
lines changed

7 files changed

+721
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ def Linalg_Dialect : Dialect {
6161
}];
6262
}
6363

64+
// Define the attribute wrapping the `ElementwiseKind` enum, i.e. the kind of
// computation an elementwise op performs (e.g., add). It is declared on the
// Linalg dialect with the name "elementwise_kind", and its assembly format
// places the enum value between `<` and `>`.
65+
def ElementwiseKindAttr : EnumAttr<Linalg_Dialect,
66+
ElementwiseKind, "elementwise_kind"> {
67+
let assemblyFormat = "`<` $value `>`";
68+
}
69+
6470
// Define the function attribute enums matching the OpDSL functions.
6571
def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
6672
let assemblyFormat = "`<` $value `>`";

mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,65 @@ def TernaryFn : I32EnumAttr<"TernaryFn", "", [
5555
let genSpecializedAttr = 0;
5656
let cppNamespace = "::mlir::linalg";
5757
}
58+
59+
// Join two I32EnumAttrCase lists. This joining takes care that the
60+
// 'int enum values' in the combined list do not overlap. It does this
61+
// by adding to each element of the second list the offset '!size(a)'.
// NOTE(review): the no-overlap property relies on every value in `a` being
// smaller than '!size(a)' (e.g. cases numbered 0..N-1) - confirm for new users.
62+
class JoinTwoI32EnumAttrCaseList< list<I32EnumAttrCase> a,
63+
list<I32EnumAttrCase> b> {
64+
// Number of cases in `a`; this is the offset added to every value of `b`.
int aSize = !size(a);
65+
// Fold over `b` with `a` as the initial accumulator: each case of `b` is
// re-created with the same symbol and a shifted value, then appended.
list<I32EnumAttrCase> result =
66+
!foldl(a, b, acc, var,
67+
acc # [I32EnumAttrCase<var.symbol,
68+
!add(var.value, aSize)
69+
>]);
70+
}
71+
72+
// Flatten 'list of list of I32EnumAttrCase' to 'list of I32EnumAttrCase'.
73+
// The flattening (via 'JoinTwoI32EnumAttrCaseList') ensures no overlap in
// enum values.
74+
class ConcatI32EnumAtrCaseList< list<list<I32EnumAttrCase>> l> {
75+
// Start from an empty list and join the sub-lists of `l` one at a time, in
// order; each sub-list is offset by the size of everything joined so far.
list<I32EnumAttrCase> result =
76+
!foldl([]<I32EnumAttrCase>, l, acc, var,
77+
JoinTwoI32EnumAttrCaseList<acc, var>.result);
78+
}
79+
80+
// Define a unified `enum class : i32` for all element-wise op functions.
// Its cases are the UnaryFn cases, followed by the BinaryFn cases, followed
// by the TernaryFn cases; ConcatI32EnumAtrCaseList shifts the values of each
// later group so that they do not collide with the earlier ones.
81+
def ElementwiseKind :
82+
I32EnumAttr<"ElementwiseKind",
83+
"",
84+
ConcatI32EnumAtrCaseList<[UnaryFn.enumerants,
85+
BinaryFn.enumerants,
86+
TernaryFn.enumerants]>.result
87+
> {
88+
let genSpecializedAttr = 0;
89+
let cppNamespace = "::mlir::linalg";
90+
}
91+
92+
// Define an `enum class : i32` that marks where each individual enum class
93+
// e.g. UnaryFn, BinaryFn, etc. end in the unified enum class ElementwiseKind.
// Each limit is a cumulative case count: LastUnary is the number of UnaryFn
// cases, LastBinary adds the number of BinaryFn cases, and LastTernary adds
// the number of TernaryFn cases.
// NOTE(review): if the source enums number their cases from 0, each limit is
// one past the largest ElementwiseKind value of its group - confirm at the
// use sites before comparing with `<` versus `<=`.
94+
def ElementwiseCaseLimits : I32EnumAttr<"ElementwiseCaseLimits", "", []> {
95+
int last_unary = !size(UnaryFn.enumerants);
96+
int last_binary = !add(last_unary, !size(BinaryFn.enumerants));
97+
int last_ternary = !add(last_binary, !size(TernaryFn.enumerants));
98+
99+
// The def is declared with an empty case list above; the real cases are
// supplied here because they depend on the fields computed in this body.
let enumerants = [
100+
I32EnumAttrCase<"LastUnary", last_unary>,
101+
I32EnumAttrCase<"LastBinary", last_binary>,
102+
I32EnumAttrCase<"LastTernary", last_ternary>];
103+
let genSpecializedAttr = 0;
104+
let cppNamespace = "::mlir::linalg";
105+
}
106+
107+
// Define an `enum class : i32` that categorises elementwise ops by arity.
// The case values 1, 2 and 3 mirror the arity named by each case.
108+
def ElementwiseArityGroup : I32EnumAttr<"ElementwiseArityGroup", "", [
109+
I32EnumAttrCase<"Unary", 1>,
110+
I32EnumAttrCase<"Binary", 2>,
111+
I32EnumAttrCase<"Ternary", 3>
112+
]> {
113+
let genSpecializedAttr = 0;
114+
let cppNamespace = "::mlir::linalg";
115+
}
116+
58117
def TypeFn : I32EnumAttr<"TypeFn", "", [
59118
I32EnumAttrCase<"cast_signed", 0>,
60119
I32EnumAttrCase<"cast_unsigned", 1>

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,126 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
538538
let hasCanonicalizer = 1;
539539
}
540540

541+
//===----------------------------------------------------------------------===//
542+
// Op definition for ElementwiseOp
543+
//===----------------------------------------------------------------------===//
544+
def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
545+
AttrSizedOperandSegments]> {
546+
let summary = [{ Performs element-wise operation }];
547+
let description = [{
548+
The attribute `kind` describes the arithmetic operation to perform. The
549+
operation kind can be unary (e.g. exp), binary (e.g. add) or ternary
550+
(e.g. select).
551+
552+
By default, all indexing maps are identities. In the case of default
553+
indexing map, all input and output shapes must match. The number of dims in
554+
each of the identity maps is equal to the rank of the output type.
555+
556+
Affine-maps for operands and result are required to be provided by the user
557+
when a transpose and/or broadcast is needed on any operand. When a map is not
558+
provided, default identity maps are inferred for each operand.
559+
560+
Iterator-types are always all `parallel`.
561+
Iterator-types are needed for constructing the underlying structured op.
562+
563+
The number of dims of the iterator-types is inferred from the rank of
564+
the result type.
565+
566+
Example:
567+
568+
Defining a unary linalg.elementwise with default indexing-map:
569+
```mlir
570+
%exp = linalg.elementwise
571+
kind=#linalg.elementwise_kind<exp>
572+
ins(%x : tensor<4x16x8xf32>)
573+
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
574+
```
575+
576+
Defining a binary linalg.elementwise with user-defined indexing-map:
577+
```mlir
578+
%add = linalg.elementwise
579+
kind=#linalg.elementwise_kind<add>
580+
indexing_maps = [#transpose, #broadcast, #identity]
581+
ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
582+
outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
583+
```
584+
}];
585+
586+
let arguments = (ins
587+
Variadic<AnyType>:$inputs,
588+
Variadic<AnyShaped>:$outputs,
589+
ElementwiseKindAttr:$kind,
590+
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
591+
);
592+
593+
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
594+
let regions = (region AnyRegion:$region);
595+
let skipDefaultBuilders = 1;
596+
597+
let builders = [
598+
OpBuilder<
599+
(ins "ValueRange":$inputs, "ValueRange":$outputs,
600+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
601+
[{
602+
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
603+
attributes, ElementwiseOp::getRegionBuilder());
604+
}]>
605+
];
606+
607+
let hasCustomAssemblyFormat = 1;
608+
let hasFolder = 1;
609+
let hasVerifier = 1;
610+
611+
let extraClassDeclaration = structuredOpsBaseDecls # [{
612+
/// Get the arity enum corresponding to the kind of op, e.g. if arg is
613+
/// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
614+
static ElementwiseArityGroup getArityGroup(ElementwiseKind n);
615+
616+
/// Both user-specified and default indexing map will always depend on
617+
/// the current Op instance.
618+
static bool hasDynamicIndexingMaps() { return true; }
619+
620+
/// Implements the block region builder for the ElementwiseOp. This is
621+
/// called by the 'fillStructuredOpRegion'.
622+
static void regionBuilder(ImplicitLocOpBuilder &b,
623+
Block &block, ArrayRef<NamedAttribute> attrs);
624+
625+
/// Returns `regionBuilder` wrapped as a `std::function`, the form passed
/// to `buildStructuredOp` by the builder declared above.
static std::function<void(ImplicitLocOpBuilder &,
626+
Block &, ArrayRef<NamedAttribute>)>
627+
getRegionBuilder() {
628+
return regionBuilder;
629+
}
630+
631+
/// Returns rank of the result tensor/memref. Useful for knowing
632+
/// the dimensionality of the iteration space when other means
633+
/// are not possible e.g. absence of user-provided indexing map.
/// The rank is read from the type of the first init (output) operand.
634+
unsigned getResultRank() {
635+
Value output = getDpsInitOperand(0)->get();
636+
ShapedType shapedType = llvm::cast<ShapedType>(output.getType());
637+
return shapedType.getRank();
638+
}
639+
640+
/// Returns N 'parallel' iterator types where N is rank of result.
641+
SmallVector<utils::IteratorType> getIteratorTypesArray();
642+
643+
/// The default indexing maps are identities.
644+
/// There will be N+1 such maps, where N is the arity of the Op.
645+
static SmallVector<AffineMap>
646+
getDefaultIndexingMaps(unsigned NumMaps, unsigned numDims,
647+
MLIRContext *context);
648+
649+
/// Destination passing style interface method.
650+
::mlir::MutableOperandRange getDpsInitsMutable() {
651+
return getOutputsMutable();
652+
}
653+
654+
// Generic methods.
655+
std::string getLibraryCallName() {
656+
return generateLibraryCallName(getOperation());
657+
}
658+
}];
659+
}
660+
541661
//===----------------------------------------------------------------------===//
542662
// Op definition for MatmulOp
543663
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)