Skip to content

Commit db7888c

Browse files
authored
[MLIR][Transform] Introduce transform.tune.knob op (#146732)
A new transform op to represent that an attribute is to be chosen from a set of alternatives and that this choice is made available as a `!transform.param`. When a `selected` argument is provided, the op's `apply()` semantics is that of just making this selected attribute available as the result. When `selected` is not provided, `apply()` complains that nothing has resolved the non-determinism that the op is representing.
1 parent d045cc9 commit db7888c

File tree

16 files changed

+474
-0
lines changed

16 files changed

+474
-0
lines changed

mlir/include/mlir/Dialect/Transform/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ add_subdirectory(IRDLExtension)
55
add_subdirectory(LoopExtension)
66
add_subdirectory(PDLExtension)
77
add_subdirectory(Transforms)
8+
add_subdirectory(TuneExtension)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS TuneExtensionOps.td)
2+
mlir_tablegen(TuneExtensionOps.h.inc -gen-op-decls)
3+
mlir_tablegen(TuneExtensionOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRTransformDialectTuneExtensionOpsIncGen)
5+
6+
add_mlir_doc(TuneExtensionOps TuneExtensionOps Dialects/ -gen-op-doc)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- TuneExtension.h - Tune extension for Transform dialect ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
10+
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace transform {
16+
/// Registers the tune extension of the Transform dialect in the given registry.
17+
void registerTuneExtension(DialectRegistry &dialectRegistry);
18+
} // namespace transform
19+
} // namespace mlir
20+
21+
#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===- TuneExtensionOps.h - Tune ext. for Transform dialect -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
10+
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
11+
12+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
13+
#include "mlir/IR/BuiltinAttributes.h"
14+
#include "mlir/IR/OpDefinition.h"
15+
16+
#define GET_OP_CLASSES
17+
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h.inc"
18+
19+
#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===- TuneExtensionOps.td - Transform dialect operations --*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
10+
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
14+
include "mlir/Interfaces/SideEffectInterfaces.td"
15+
include "mlir/IR/BuiltinAttributes.td"
16+
include "mlir/IR/CommonAttrConstraints.td"
17+
18+
def KnobOp : Op<Transform_Dialect, "tune.knob", [
19+
DeclareOpInterfaceMethods<TransformOpInterface>,
20+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
21+
]> {
22+
let summary = "Represents a tunable parameter with a set of options";
23+
24+
let description = [{
25+
Provides a representation for "tunables" within schedules.
26+
27+
Each op represents a single tunable, which has a `name` and a set
28+
of valid `options` described by an attribute. Without a specified
29+
`selected` option, this op represents a non-deterministic choice
30+
that has yet to be resolved -- as such, the interpreter runtime
31+
semantics is to raise a failure.
32+
33+
The non-deterministic choice is resolved through providing a
34+
`selected` attribute. When provided, the interpreter runtime
35+
semantics are to return the `selected` attribute as a param through
36+
the op's result.
37+
38+
-----
39+
40+
In case the `options` attribute is an `ArrayAttr`, the verifier
41+
checks that the provided `selected` attribute occurs in `options`.
42+
}];
43+
let cppNamespace = [{ mlir::transform::tune }];
44+
let hasVerifier = 1;
45+
46+
let arguments = (ins Builtin_StringAttr:$name,
47+
AnyAttr:$options,
48+
OptionalAttr<AnyAttr>:$selected);
49+
let results = (outs TransformParamTypeInterface:$result);
50+
51+
let assemblyFormat =
52+
"`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
53+
}
54+
55+
#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
5353
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
5454
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
55+
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
5556
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
5657
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
5758
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -107,6 +108,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
107108
transform::registerIRDLExtension(registry);
108109
transform::registerLoopExtension(registry);
109110
transform::registerPDLExtension(registry);
111+
transform::registerTuneExtension(registry);
110112
vector::registerTransformDialectExtension(registry);
111113
arm_neon::registerTransformDialectExtension(registry);
112114
arm_sve::registerTransformDialectExtension(registry);

mlir/lib/Dialect/Transform/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ add_subdirectory(IRDLExtension)
55
add_subdirectory(LoopExtension)
66
add_subdirectory(PDLExtension)
77
add_subdirectory(Transforms)
8+
add_subdirectory(TuneExtension)
89
add_subdirectory(Utils)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
add_mlir_dialect_library(MLIRTransformTuneExtension
2+
TuneExtension.cpp
3+
TuneExtensionOps.cpp
4+
5+
DEPENDS
6+
MLIRTransformDialectTuneExtensionOpsIncGen
7+
8+
LINK_LIBS PUBLIC
9+
MLIRIR
10+
MLIRTransformDialect
11+
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===- TuneExtension.cpp - Tune extension for the Transform dialect -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
10+
11+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
12+
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
13+
#include "mlir/IR/DialectRegistry.h"
14+
15+
using namespace mlir;
16+
17+
class TuneExtension
18+
: public transform::TransformDialectExtension<TuneExtension> {
19+
public:
20+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TuneExtension)
21+
22+
void init() {
23+
registerTransformOps<
24+
#define GET_OP_LIST
25+
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
26+
>();
27+
}
28+
};
29+
30+
void mlir::transform::registerTuneExtension(DialectRegistry &dialectRegistry) {
31+
dialectRegistry.addExtensions<TuneExtension>();
32+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//===- TuneExtensionOps.cpp - Tune extension for the Transform dialect ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
10+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
11+
#include "mlir/IR/OpImplementation.h"
12+
#include "mlir/IR/PatternMatch.h"
13+
#include "llvm/Support/Debug.h"
14+
15+
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
16+
17+
using namespace mlir;
18+
19+
#define GET_OP_CLASSES
20+
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
21+
22+
#define DEBUG_TYPE "transform-tune"
23+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
24+
25+
//===----------------------------------------------------------------------===//
26+
// KnobOp
27+
//===----------------------------------------------------------------------===//
28+
29+
void transform::tune::KnobOp::getEffects(
30+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
31+
producesHandle(getOperation()->getOpResults(), effects);
32+
onlyReadsPayload(effects);
33+
}
34+
35+
DiagnosedSilenceableFailure
36+
transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter,
37+
transform::TransformResults &results,
38+
transform::TransformState &state) {
39+
if (getSelected()) {
40+
results.setParams(llvm::cast<OpResult>(getResult()), *getSelected());
41+
return DiagnosedSilenceableFailure::success();
42+
}
43+
44+
return emitDefiniteFailure()
45+
<< "non-deterministic choice " << getName()
46+
<< " is only resolved through providing a `selected` attr";
47+
}
48+
49+
LogicalResult transform::tune::KnobOp::verify() {
50+
if (auto selected = getSelected()) {
51+
if (auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) {
52+
if (!llvm::is_contained(optionsArray, selected))
53+
return emitOpError("provided `selected` attribute is not an element of "
54+
"`options` array of attributes");
55+
} else
56+
LLVM_DEBUG(DBGS() << "cannot verify `selected` attribute " << selected
57+
<< " is an element of `options` attribute "
58+
<< getOptions());
59+
}
60+
61+
return success();
62+
}

0 commit comments

Comments
 (0)