@@ -523,15 +523,14 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
523
523
AffineMapArrayAttr:$indexing_maps,
524
524
ArrayAttr:$iterator_types,
525
525
OptionalAttr<StrAttr>:$doc,
526
- OptionalAttr<FlatSymbolRefAttr>:$fun,
527
526
OptionalAttr<StrAttr>:$library_call);
528
527
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
529
528
let regions = (region AnyRegion:$region);
530
529
let extraClassDeclaration = [{
531
530
SmallVector<StringRef, 8> linalgTraitAttrNames() {
532
531
return SmallVector<StringRef, 8>{
533
532
getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
534
- getFunAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(),
533
+ getIndexingMapsAttrName(), getLibraryCallAttrName(),
535
534
getIteratorTypesAttrName()
536
535
};
537
536
}
@@ -540,12 +539,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
540
539
541
540
unsigned getNumOutputs() { return args_out().getSExtValue(); }
542
541
543
- FuncOp getFunction() {
544
- auto moduleOp = getParentOfType<ModuleOp>();
545
- return fun().hasValue() ?
546
- moduleOp.lookupSymbol<FuncOp>(fun().getValue()) : FuncOp();
547
- }
548
-
549
542
StringRef getLibraryCallName() {
550
543
return library_call().hasValue() ? library_call().getValue() : "";
551
544
}
@@ -581,13 +574,6 @@ def GenericOp : GenericOpBase<"generic"> {
581
574
- args_in: an I64Attr representing the number of input (readonly) views
582
575
- args_out: an I64Attr representing the number of output (readwrite) views
583
576
- doc [optional]: a documentation string
584
- - fun: a FlatSymbolRefAttr that must resolve to an existing function
585
- symbol. To support inplace updates in a generic fashion, the signature
586
- of the function must be:
587
- ```
588
- fun([input views element types], [output views element types])
589
- -> ([output views element types])
590
- ```
591
577
- indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
592
578
and output view. Such AffineMapAttr specifies the mapping between the
593
579
loops and the indexing within each view.
@@ -604,19 +590,13 @@ def GenericOp : GenericOpBase<"generic"> {
604
590
Example:
605
591
Defining a #matmul_trait attribute in MLIR can be done as follows:
606
592
```mlir
607
- func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
608
- %d = mulf %a, %b: f32
609
- %e = addf %c, %d: f32
610
- return %e: f32
611
- }
612
593
#matmul_accesses = [
613
594
(m, n, k) -> (m, k),
614
595
(m, n, k) -> (k, n),
615
596
(m, n, k) -> (m, n)
616
597
]
617
598
#matmul_trait = {
618
599
doc = "C(m, n) += A(m, k) * B(k, n)",
619
- fun = @fma,
620
600
indexing_maps = #matmul_accesses,
621
601
library_call = "linalg_matmul",
622
602
n_views = [2, 1],
@@ -626,10 +606,14 @@ def GenericOp : GenericOpBase<"generic"> {
626
606
627
607
And can be reused in multiple places as:
628
608
```mlir
629
- linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
630
- memref<?x?xf32, stride_specification>,
631
- memref<?x?xf32, stride_specification>,
632
- memref<?x?xf32, stride_specification>
609
+ linalg.generic #matmul_trait %A, %B, %C [other-attributes] {
610
+ (%a: f32, %b: f32, %c: f32) :
611
+ %d = mulf %a, %b: f32
612
+ %e = addf %c, %d: f32
613
+ linalg_yield %e : f32
614
+ } : memref<?x?xf32, stride_specification>,
615
+ memref<?x?xf32, stride_specification>,
616
+ memref<?x?xf32, stride_specification>
633
617
```
634
618
635
619
This may lower to either:
@@ -649,9 +633,9 @@ def GenericOp : GenericOpBase<"generic"> {
649
633
%a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
650
634
%b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
651
635
%c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
652
- %d = call @func_of_elements( %a, %b, %c)
653
- : ( f32, f32, f32) -> (f32)
654
- store %d , %C[%m, %n] : memref<?x?x?xf32, stride_specification>
636
+ %d = mulf %a, %b: f32
637
+ %e = addf %c, %d: f32
638
+ store %e , %C[%m, %n] : memref<?x?x?xf32, stride_specification>
655
639
}
656
640
}
657
641
}
@@ -662,7 +646,7 @@ def GenericOp : GenericOpBase<"generic"> {
662
646
mixing input and output ranked tensor values with input and output memrefs.
663
647
664
648
```mlir
665
- %C = linalg.generic #trait_attribute %A, %B {other-attributes} :
649
+ %C = linalg.generic #trait_attribute %A, %B {other-attributes} {region} :
666
650
tensor<?x?xf32>,
667
651
memref<?x?xf32, stride_specification>
668
652
-> (tensor<?x?xf32>)
@@ -708,13 +692,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
708
692
- args_in: an I64Attr representing the number of input (readonly) views
709
693
- args_out: an I64Attr representing the number of output (readwrite) views
710
694
- doc [optional]: a documentation string
711
- - fun: a FlatSymbolRefAttr that must resolve to an existing function
712
- symbol. To support inplace updates in a generic fashion, the signature
713
- of the function must be:
714
- ```
715
- fun([index types of induction variables], [input views element types],
716
- [output views element types]) -> ([output views element types])
717
- ```
718
695
- indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
719
696
and output view. Such AffineMapAttr specifies the mapping between the
720
697
loops and the indexing within each view.
@@ -732,23 +709,13 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
732
709
Defining a #matmul_trait attribute in MLIR can be done as follows:
733
710
734
711
```mlir
735
- func @fma(%offset_m: index, %offset_n: index, %offset_k: index,
736
- %a: f32, %b: f32, %c: f32)
737
- -> f32
738
- {
739
- "some_optional_condition"(%offset_m, %offset_n, %offset_k)
740
- %d = mulf %a, %b: f32
741
- %e = addf %c, %d: f32
742
- return %e: f32
743
- }
744
712
#matmul_accesses = [
745
713
(m, n, k) -> (m, k),
746
714
(m, n, k) -> (k, n),
747
715
(m, n, k) -> (m, n)
748
716
]
749
717
#matmul_trait = {
750
718
doc = "C(m, n) += A(m, k) * B(k, n)",
751
- fun = @fma,
752
719
indexing_maps = #matmul_accesses,
753
720
library_call = "linalg_matmul",
754
721
n_views = [2, 1],
@@ -759,10 +726,16 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
759
726
And can be reused in multiple places as:
760
727
761
728
```mlir
762
- linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] :
763
- memref<?x?xf32, stride_specification>,
764
- memref<?x?xf32, stride_specification>,
765
- memref<?x?xf32, stride_specification>
729
+ linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] {
730
+ (%offset_m: index, %offset_n: index, %offset_k: index,
731
+ %a: f32, %b: f32, %c: f32) :
732
+ "some_optional_computation"(%offset_m, %offset_n, %offset_k)
733
+ %d = mulf %a, %b: f32
734
+ %e = addf %c, %d: f32
735
+ linalg_yield %e : f32
736
+ } : memref<?x?xf32, stride_specification>,
737
+ memref<?x?xf32, stride_specification>,
738
+ memref<?x?xf32, stride_specification>
766
739
```
767
740
768
741
This may lower to either:
@@ -784,8 +757,9 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
784
757
%a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
785
758
%b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
786
759
%c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
787
- %d = call @func_of_elements_and_indices(%m, %n, %k, %a, %b, %c)
788
- : (index, index, index, f32, f32, f32) -> (f32)
760
+ "some_optional_computation"(%m, %n, %k)
761
+ %d = mulf %a, %b: f32
762
+ %e = addf %c, %d: f32
789
763
store %d, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
790
764
}
791
765
}
0 commit comments