Skip to content

Commit f543122

Browse files
[mlir][Linalg] Drop function attribute from generic ops.
The function attribute in generic ops is not paying for itself. A region is the more standardized way of specifying a custom computation. If needed this region can call a function directly. This is deemed more natural than managing a dedicated function attribute. This also simplifies named ops generation by trimming unnecessary complexity. Differential Revision: https://reviews.llvm.org/D78266
1 parent 2ec5520 commit f543122

File tree

11 files changed

+206
-532
lines changed

11 files changed

+206
-532
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 26 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -523,15 +523,14 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
523523
AffineMapArrayAttr:$indexing_maps,
524524
ArrayAttr:$iterator_types,
525525
OptionalAttr<StrAttr>:$doc,
526-
OptionalAttr<FlatSymbolRefAttr>:$fun,
527526
OptionalAttr<StrAttr>:$library_call);
528527
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
529528
let regions = (region AnyRegion:$region);
530529
let extraClassDeclaration = [{
531530
SmallVector<StringRef, 8> linalgTraitAttrNames() {
532531
return SmallVector<StringRef, 8>{
533532
getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
534-
getFunAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(),
533+
getIndexingMapsAttrName(), getLibraryCallAttrName(),
535534
getIteratorTypesAttrName()
536535
};
537536
}
@@ -540,12 +539,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
540539

541540
unsigned getNumOutputs() { return args_out().getSExtValue(); }
542541

543-
FuncOp getFunction() {
544-
auto moduleOp = getParentOfType<ModuleOp>();
545-
return fun().hasValue() ?
546-
moduleOp.lookupSymbol<FuncOp>(fun().getValue()) : FuncOp();
547-
}
548-
549542
StringRef getLibraryCallName() {
550543
return library_call().hasValue() ? library_call().getValue() : "";
551544
}
@@ -581,13 +574,6 @@ def GenericOp : GenericOpBase<"generic"> {
581574
- args_in: an I64Attr representing the number of input (readonly) views
582575
- args_out: an I64Attr representing the number of output (readwrite) views
583576
- doc [optional]: a documentation string
584-
- fun: a FlatSymbolRefAttr that must resolve to an existing function
585-
symbol. To support inplace updates in a generic fashion, the signature
586-
of the function must be:
587-
```
588-
fun([input views element types], [output views element types])
589-
-> ([output views element types])
590-
```
591577
- indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
592578
and output view. Such AffineMapAttr specifies the mapping between the
593579
loops and the indexing within each view.
@@ -604,19 +590,13 @@ def GenericOp : GenericOpBase<"generic"> {
604590
Example:
605591
Defining a #matmul_trait attribute in MLIR can be done as follows:
606592
```mlir
607-
func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
608-
%d = mulf %a, %b: f32
609-
%e = addf %c, %d: f32
610-
return %e: f32
611-
}
612593
#matmul_accesses = [
613594
(m, n, k) -> (m, k),
614595
(m, n, k) -> (k, n),
615596
(m, n, k) -> (m, n)
616597
]
617598
#matmul_trait = {
618599
doc = "C(m, n) += A(m, k) * B(k, n)",
619-
fun = @fma,
620600
indexing_maps = #matmul_accesses,
621601
library_call = "linalg_matmul",
622602
n_views = [2, 1],
@@ -626,10 +606,14 @@ def GenericOp : GenericOpBase<"generic"> {
626606

627607
And can be reused in multiple places as:
628608
```mlir
629-
linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
630-
memref<?x?xf32, stride_specification>,
631-
memref<?x?xf32, stride_specification>,
632-
memref<?x?xf32, stride_specification>
609+
linalg.generic #matmul_trait %A, %B, %C [other-attributes] {
610+
(%a: f32, %b: f32, %c: f32) :
611+
%d = mulf %a, %b: f32
612+
%e = addf %c, %d: f32
613+
linalg_yield %e : f32
614+
} : memref<?x?xf32, stride_specification>,
615+
memref<?x?xf32, stride_specification>,
616+
memref<?x?xf32, stride_specification>
633617
```
634618

635619
This may lower to either:
@@ -649,9 +633,9 @@ def GenericOp : GenericOpBase<"generic"> {
649633
%a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
650634
%b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
651635
%c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
652-
%d = call @func_of_elements(%a, %b, %c)
653-
: (f32, f32, f32) -> (f32)
654-
store %d, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
636+
%d = mulf %a, %b: f32
637+
%e = addf %c, %d: f32
638+
store %e, %C[%m, %n] : memref<?x?xf32, stride_specification>
655639
}
656640
}
657641
}
@@ -662,7 +646,7 @@ def GenericOp : GenericOpBase<"generic"> {
662646
mixing input and output ranked tensor values with input and output memrefs.
663647

664648
```mlir
665-
%C = linalg.generic #trait_attribute %A, %B {other-attributes} :
649+
%C = linalg.generic #trait_attribute %A, %B {other-attributes} {region} :
666650
tensor<?x?xf32>,
667651
memref<?x?xf32, stride_specification>
668652
-> (tensor<?x?xf32>)
@@ -708,13 +692,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
708692
- args_in: an I64Attr representing the number of input (readonly) views
709693
- args_out: an I64Attr representing the number of output (readwrite) views
710694
- doc [optional]: a documentation string
711-
- fun: a FlatSymbolRefAttr that must resolve to an existing function
712-
symbol. To support inplace updates in a generic fashion, the signature
713-
of the function must be:
714-
```
715-
fun([index types of induction variables], [input views element types],
716-
[output views element types]) -> ([output views element types])
717-
```
718695
- indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
719696
and output view. Such AffineMapAttr specifies the mapping between the
720697
loops and the indexing within each view.
@@ -732,23 +709,13 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
732709
Defining a #matmul_trait attribute in MLIR can be done as follows:
733710

734711
```mlir
735-
func @fma(%offset_m: index, %offset_n: index, %offset_k: index,
736-
%a: f32, %b: f32, %c: f32)
737-
-> f32
738-
{
739-
"some_optional_condition"(%offset_m, %offset_n, %offset_k)
740-
%d = mulf %a, %b: f32
741-
%e = addf %c, %d: f32
742-
return %e: f32
743-
}
744712
#matmul_accesses = [
745713
(m, n, k) -> (m, k),
746714
(m, n, k) -> (k, n),
747715
(m, n, k) -> (m, n)
748716
]
749717
#matmul_trait = {
750718
doc = "C(m, n) += A(m, k) * B(k, n)",
751-
fun = @fma,
752719
indexing_maps = #matmul_accesses,
753720
library_call = "linalg_matmul",
754721
n_views = [2, 1],
@@ -759,10 +726,16 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
759726
And can be reused in multiple places as:
760727

761728
```mlir
762-
linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] :
763-
memref<?x?xf32, stride_specification>,
764-
memref<?x?xf32, stride_specification>,
765-
memref<?x?xf32, stride_specification>
729+
linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] {
730+
(%offset_m: index, %offset_n: index, %offset_k: index,
731+
%a: f32, %b: f32, %c: f32) :
732+
"some_optional_computation"(%offset_m, %offset_n, %offset_k)
733+
%d = mulf %a, %b: f32
734+
%e = addf %c, %d: f32
735+
linalg_yield %e : f32
736+
} : memref<?x?xf32, stride_specification>,
737+
memref<?x?xf32, stride_specification>,
738+
memref<?x?xf32, stride_specification>
766739
```
767740

768741
This may lower to either:
@@ -784,8 +757,9 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
784757
%a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
785758
%b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
786759
%c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
787-
%d = call @func_of_elements_and_indices(%m, %n, %k, %a, %b, %c)
788-
: (index, index, index, f32, f32, f32) -> (f32)
760+
"some_optional_computation"(%m, %n, %k)
761+
%d = mulf %a, %b: f32
762+
%e = addf %c, %d: f32
789763
store %e, %C[%m, %n] : memref<?x?xf32, stride_specification>
790764
}
791765
}

mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,6 @@ constexpr StringRef getArgsOutAttrName() { return "args_out"; }
6666
/// string of the structured op.
6767
constexpr StringRef getDocAttrName() { return "doc"; }
6868

69-
/// Attribute name for the StrArrayAttr which encodes the SymbolAttr for the
70-
/// MLIR function that implements the body of the structured op.
71-
constexpr StringRef getFunAttrName() { return "fun"; }
72-
7369
/// Attribute name for the StrArrayAttr which encodes the external library
7470
/// function that implements the structured op.
7571
constexpr StringRef getLibraryCallAttrName() { return "library_call"; }

mlir/lib/Dialect/Linalg/EDSC/Builders.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ Operation *mlir::edsc::makeGenericLinalgOp(
177177
builder.getAffineMapArrayAttr(maps),
178178
builder.getStrArrayAttr(iteratorStrTypes),
179179
StringAttr() /*doc*/,
180-
FlatSymbolRefAttr() /*fun*/,
181180
StringAttr() /*library_call*/
182181
/* TODO: other attributes in op */
183182
)

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 17 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,11 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
133133
attrs.push_back(attr);
134134

135135
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
136-
p << op.getOperationName() << " " << dictAttr << " " << op.getOperands();
136+
p << op.getOperationName() << " " << dictAttr;
137+
p.printOptionalAttrDict(op.getAttrs(), attrNames);
138+
p << " " << op.getOperands();
137139
if (!op.region().empty())
138140
p.printRegion(op.region());
139-
p.printOptionalAttrDict(op.getAttrs(), attrNames);
140141
p << ": " << op.getOperandTypes();
141142
auto outputTensorTypes = op.getResultTypes();
142143
if (!outputTensorTypes.empty())
@@ -156,21 +157,21 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
156157
// The name is unimportant as we will overwrite result.attributes.
157158
// The core linalg traits must contain the information necessary to pass the
158159
// verifier.
159-
if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
160-
parser.parseOperandList(operandsInfo))
160+
if (parser.parseAttribute(dictAttr, "_", result.attributes))
161161
return failure();
162162
result.attributes.assign(dictAttr.getValue().begin(),
163163
dictAttr.getValue().end());
164164

165+
// Optional attributes may be added.
166+
if (parser.parseOptionalAttrDict(result.attributes) ||
167+
parser.parseOperandList(operandsInfo))
168+
return failure();
169+
165170
Region &region = *result.addRegion();
166171
SmallVector<Type, 8> operandTypes, regionTypes;
167-
// Optional attributes may be added.
168-
// Either Optional getFunAttrName() attribute or region must be specified.
169-
if (!dictAttr.get(getFunAttrName()) &&
170-
parser.parseOptionalRegion(region, regionOperandsInfo, regionTypes))
172+
if (parser.parseRegion(region, regionOperandsInfo, regionTypes))
171173
return failure();
172-
if (parser.parseOptionalAttrDict(result.attributes) ||
173-
parser.parseColonTypeList(operandTypes))
174+
if (parser.parseColonTypeList(operandTypes))
174175
return failure();
175176
// Generic ops may specify that a subset of its outputs are tensors. Such
176177
// outputs are specified in the result type.
@@ -183,10 +184,7 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
183184
parser.getCurrentLocation(), result.operands);
184185
}
185186

186-
template <typename GenericOpType>
187-
static LogicalResult verifyBlockArgs(GenericOpType op, Block &block);
188-
189-
template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
187+
LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
190188
auto nOperands = op.getNumOperands();
191189
if (block.getNumArguments() != nOperands)
192190
return op.emitOpError("expected number of block arguments to match number "
@@ -205,7 +203,7 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
205203
return success();
206204
}
207205

208-
template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
206+
LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
209207
auto nInputViews = op.getNumInputs();
210208
auto nLoops = op.getNumLoops();
211209
auto nOperands = op.getNumOperands();
@@ -234,81 +232,6 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
234232
return success();
235233
}
236234

237-
template <typename GenericOpType>
238-
static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);
239-
240-
template <typename GenericOpType>
241-
static LogicalResult verifyFuncArgsGeneric(GenericOpType op,
242-
FunctionType funType) {
243-
auto res = verifyFuncArgs(op, funType);
244-
if (failed(res))
245-
return res;
246-
247-
auto nInputs = op.getNumInputs();
248-
auto nOutputs = op.getNumOutputs();
249-
// linalg.generic output element types are exactly the function results.
250-
for (unsigned idx = 0; idx < nOutputs; ++idx) {
251-
ShapedType shapedType = op.getShapedType(nInputs + idx);
252-
if (funType.getResult(idx) != shapedType.getElementType())
253-
return op.emitOpError("expected function result ")
254-
<< (idx + 1) << " of the same type as elemental type "
255-
<< shapedType.getElementType() << " of output " << (idx + 1);
256-
}
257-
return success();
258-
}
259-
260-
template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
261-
auto nOperands = op.getNumOperands();
262-
if (funType.getNumInputs() != nOperands)
263-
return op.emitOpError(
264-
"expected function arguments to match number of operands");
265-
if (funType.getNumResults() != op.getNumOutputs())
266-
return op.emitOpError("expected function results(")
267-
<< funType.getNumResults() << ") to match number of outputs("
268-
<< op.getNumOutputs() << ")";
269-
270-
// linalg.generic operands element types are exactly the first function
271-
// arguments.
272-
for (unsigned idx = 0; idx < nOperands; ++idx) {
273-
ShapedType shapedType = op.getShapedType(idx);
274-
if (funType.getInput(idx) != shapedType.getElementType())
275-
return op.emitOpError("expected function argument ")
276-
<< (idx + 1) << " of the same type as elemental type "
277-
<< shapedType.getElementType() << " of operand " << (idx + 1);
278-
}
279-
280-
return success();
281-
}
282-
283-
template <>
284-
LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
285-
auto nLoops = op.getNumLoops();
286-
auto nOutputs = op.getNumOutputs();
287-
auto nOperands = op.getNumOperands();
288-
if (funType.getNumInputs() != nOperands + nLoops)
289-
return op.emitOpError("expected function arguments to match number of "
290-
"loops + number of operands");
291-
if (funType.getNumResults() != nOutputs)
292-
return op.emitOpError(
293-
"expected function results to match number of outputs");
294-
for (unsigned i = 0; i < nLoops; ++i)
295-
if (!funType.getInput(i).isIndex())
296-
return op.emitOpError("expected function argument ")
297-
<< (i + 1) << " to be an index";
298-
299-
// linalg.generic operands element types are exactly the first function
300-
// arguments.
301-
for (unsigned idx = 0; idx < nOperands; ++idx) {
302-
ShapedType shapedType = op.getShapedType(idx);
303-
if (funType.getInput(idx + nLoops) != shapedType.getElementType())
304-
return op.emitOpError("expected function argument ")
305-
<< (idx + nLoops + 1) << " of the same type as elemental type "
306-
<< shapedType.getElementType() << " of input " << (idx + 1);
307-
}
308-
309-
return success();
310-
}
311-
312235
template <typename GenericOpType>
313236
static LogicalResult verifyGenericOp(GenericOpType op) {
314237
auto nInputViews = op.getNumInputs();
@@ -320,20 +243,10 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
320243
<< " inputs (tensor or buffer) and output buffer operands";
321244

322245
auto &region = op.region();
323-
auto funOp = op.getFunction();
324-
auto funType = funOp ? funOp.getType() : FunctionType();
325-
if (!region.empty()) {
326-
if (region.getBlocks().size() != 1)
327-
return op.emitOpError("expected region with 1 block");
328-
if (failed(verifyBlockArgs(op, region.getBlocks().front())))
329-
return failure();
330-
} else {
331-
if (!funOp || !funOp.getType())
332-
return op.emitOpError(
333-
"expected function attribute to refer to a defined symbol");
334-
if (failed(verifyFuncArgsGeneric(op, funType)))
335-
return failure();
336-
}
246+
if (region.getBlocks().size() != 1)
247+
return op.emitOpError("expected region with 1 block");
248+
if (failed(verifyBlockArgs(op, region.getBlocks().front())))
249+
return failure();
337250

338251
SmallVector<AffineMap, 4> indexingMaps;
339252
indexingMaps.reserve(op.indexing_maps().size());

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,7 @@ static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
382382
// - only handle ops that use regions for specifying the scalar operations.
383383
if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
384384
producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
385-
producerOp.getNumParallelLoops() != producerOp.getNumLoops() ||
386-
producerOp.fun() || consumerOp.fun())
385+
producerOp.getNumParallelLoops() != producerOp.getNumLoops())
387386
return false;
388387

389388
// Get the consumer index map. The number of results of the consumer index map
@@ -472,7 +471,6 @@ Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
472471
b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
473472
b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
474473
/*doc=*/nullptr,
475-
/*fun=*/nullptr,
476474
/*library_call=*/nullptr);
477475

478476
// Build the region of the fused op.

0 commit comments

Comments (0)