@@ -538,6 +538,126 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
538
538
let hasCanonicalizer = 1;
539
539
}
540
540
541
+ //===----------------------------------------------------------------------===//
542
+ // Op definition for ElementwiseOp
543
+ //===----------------------------------------------------------------------===//
544
+ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
545
+ AttrSizedOperandSegments]> {
546
+ let summary = [{ Performs element-wise operation }];
547
+ let description = [{
548
+ The attribute `kind` describes arithmetic operation to perform. The
549
+ operation kind can be unary (e.g. max), binary (e.g. add) or ternary
550
+ (e.g. select).
551
+
552
+ By default, all indexing maps are identities. In the case of default
553
+ indexing map, all input and output shapes must match. The number of dims in
554
+ each of the identity maps is equal to the rank of the output type.
555
+
556
+ Affine-maps for operands and result are required to be provided by the user
557
+ when a transpose and/or broadcast is needed on any operand. When a map is not
558
+ provided, default identity maps are inferred for each operand.
559
+
560
+ Iterator-types are always all `parallel`.
561
+ Iterator-types are needed for constructing the underlying structured op.
562
+
563
+ The number of dims of the iterator-types are inferred from the rank of
564
+ the result type.
565
+
566
+ Example:
567
+
568
+ Defining a unary linalg.elemwise with default indexing-map:
569
+ ```mlir
570
+ %exp = linalg.elemwise
571
+ kind=#linalg.elemwise_kind<exp>
572
+ ins(%x : tensor<4x16x8xf32>)
573
+ outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
574
+ ```
575
+
576
+ Defining a binary linalg.elemwise with user-defined indexing-map:
577
+ ```mlir
578
+ %add = linalg.elemwise
579
+ kind=#linalg.elemwise_kind<add>
580
+ indexing_maps = [#transpose, #broadcast, #identity]
581
+ ins(%exp, %arg1 : tensor<4x16x8xf32>, tensor<4x16xf32>)
582
+ outs(%arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
583
+ ```
584
+ }];
585
+
586
+ let arguments = (ins
587
+ Variadic<AnyType>:$inputs,
588
+ Variadic<AnyShaped>:$outputs,
589
+ ElementwiseKindAttr:$kind,
590
+ DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
591
+ );
592
+
593
+ let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
594
+ let regions = (region AnyRegion:$region);
595
+ let skipDefaultBuilders = 1;
596
+
597
+ let builders = [
598
+ OpBuilder<
599
+ (ins "ValueRange":$inputs, "ValueRange":$outputs,
600
+ CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
601
+ [{
602
+ buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
603
+ attributes, ElementwiseOp::getRegionBuilder());
604
+ }]>
605
+ ];
606
+
607
+ let hasCustomAssemblyFormat = 1;
608
+ let hasFolder = 1;
609
+ let hasVerifier = 1;
610
+
611
+ let extraClassDeclaration = structuredOpsBaseDecls # [{
612
+ /// Get the arity enum corresponding to the kind of op, e.g. if arg is
613
+ /// `ElementwiseKind::add`, return `ElementwiseArityGroup::Binary`.
614
+ static ElementwiseArityGroup getArityGroup(ElementwiseKind n);
615
+
616
+ /// Both user-specified and default indexing map will always depend on
617
+ /// the current Op instance.
618
+ static bool hasDynamicIndexingMaps() { return true; }
619
+
620
+ /// Implements the block region builder for the elementwiseOp. This is
621
+ /// called by the 'fillStructuredOpRegion'.
622
+ static void regionBuilder(ImplicitLocOpBuilder &b,
623
+ Block &block, ArrayRef<NamedAttribute> attrs);
624
+
625
+ static std::function<void(ImplicitLocOpBuilder &,
626
+ Block &, ArrayRef<NamedAttribute>)>
627
+ getRegionBuilder() {
628
+ return regionBuilder;
629
+ }
630
+
631
+ /// Returns rank of the result tensor/memref. Useful for knowing
632
+ /// the dimensionality of the iteration space when others means
633
+ /// are not possible e.g. absence of user-provided indexing map.
634
+ unsigned getResultRank() {
635
+ Value output = getDpsInitOperand(0)->get();
636
+ ShapedType shapedType = llvm::cast<ShapedType>(output.getType());
637
+ return shapedType.getRank();
638
+ }
639
+
640
+ /// Returns N 'parallel' iterator types where N is rank of result.
641
+ SmallVector<utils::IteratorType> getIteratorTypesArray();
642
+
643
+ /// The default indexing maps are identities.
644
+ /// There will be N+1 such maps, where N is the arity of the Op.
645
+ static SmallVector<AffineMap>
646
+ getDefaultIndexingMaps(unsigned NumMaps, unsigned numDims,
647
+ MLIRContext *context);
648
+
649
+ /// Destination passing style interface method.
650
+ ::mlir::MutableOperandRange getDpsInitsMutable() {
651
+ return getOutputsMutable();
652
+ }
653
+
654
+ // Generic methods.
655
+ std::string getLibraryCallName() {
656
+ return generateLibraryCallName(getOperation());
657
+ }
658
+ }];
659
+ }
660
+
541
661
//===----------------------------------------------------------------------===//
542
662
// Op definition for MatmulOp
543
663
//===----------------------------------------------------------------------===//
0 commit comments