Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 031cf48

Browse files
authored
Merge pull request #579 from facebookresearch/compiler-options
Compiler options
2 parents 21ed914 + 7a909c3 commit 031cf48

18 files changed

+164
-59
lines changed

tc/aten/aten_compiler-inl.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
3232
const std::string& tc,
3333
const std::string& entryPoint,
3434
const std::vector<at::Tensor>& inputs,
35-
const typename Backend::MappingOptionsType& options) {
35+
const typename Backend::MappingOptionsType& options,
36+
const CompilerOptions& compilerOptions) {
3637
auto inputDLTensors = makeDLConstTensors(inputs);
3738
return tc::compile<Backend>(
38-
tc, entryPoint, extractRawPtrs(inputDLTensors), options);
39+
tc, entryPoint, extractRawPtrs(inputDLTensors), options, compilerOptions);
3940
}
4041

4142
template <typename Executor>

tc/aten/aten_compiler.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "tc/aten/aten.h"
2323
#include "tc/core/tensor.h"
2424
#include "tc/core/utils/time.h"
25+
#include "tc/utils/compiler_options.h"
2526

2627
namespace tc {
2728
namespace aten {
@@ -57,7 +58,8 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
5758
const std::string& tc,
5859
const std::string& entryPoint,
5960
const std::vector<at::Tensor>& inputs,
60-
const typename Backend::MappingOptionsType& options);
61+
const typename Backend::MappingOptionsType& options,
62+
const CompilerOptions& compilerOptions = CompilerOptions());
6163

6264
/// Given an executor resulting from compiling a TC, run the TC and fill the
6365
/// outputs vector with the results. The output vector must have as many

tc/autotuner/autotuner-inl.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "tc/core/tensor.h"
3030
#include "tc/core/utils/math.h"
3131
#include "tc/lang/canonicalize.h"
32+
#include "tc/utils/compiler_options.h"
3233

3334
namespace tc {
3435
namespace autotune {
@@ -74,6 +75,9 @@ void TuningHarness<Backend>::stopAfterCurrentIteration() {
7475
template <typename Backend>
7576
template <typename SearchStrategy>
7677
void TuningHarness<Backend>::doCompile(SearchStrategy& searchStrategy) {
78+
CompilerOptions supressWarningsOptions;
79+
supressWarningsOptions.emitWarnings = false;
80+
7781
// Atomically fetch and add the next job until there are no jobs left
7882
while (true) {
7983
auto current = currentCompilationJob_.fetch_add(1);
@@ -92,8 +96,8 @@ void TuningHarness<Backend>::doCompile(SearchStrategy& searchStrategy) {
9296
LOG(INFO) << "[COMPILE] Start compilation @:" << current;
9397
LOG_LINE_BY_LINE(INFO, ssInfo);
9498
}
95-
pExecutor =
96-
tc::compile<Backend>(tcTree_, inputs_.begin()->second, options);
99+
pExecutor = tc::detail::compile<Backend>(
100+
tcTree_, inputs_.begin()->second, options, supressWarningsOptions);
97101
LOG_IF(INFO, FLAGS_debug_tuner) << "[COMPILE] Done compilation";
98102
} catch (const std::exception& e) {
99103
LOG(WARNING) << "[TUNER][COMPILE] failed compilation: " << e.what();

tc/core/compiler-inl.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,30 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
3737
const std::string& entryPoint,
3838
const std::vector<const DLConstTensor*>& inputs,
3939
/* TODO: in the future also pass outputs for stride and alignment info */
40-
const typename Backend::MappingOptionsType& options) {
40+
const typename Backend::MappingOptionsType& options,
41+
const CompilerOptions& compilerOptions) {
4142
auto parsedTcs = detail::parse(tc);
4243
TC_CHECK_EQ(parsedTcs.count(entryPoint), 1u)
4344
<< "attempting to access undefined function " << entryPoint;
44-
return compile<Backend>(parsedTcs[entryPoint], inputs, options);
45+
return detail::compile<Backend>(
46+
parsedTcs[entryPoint], inputs, options, compilerOptions);
4547
}
4648

49+
namespace detail {
4750
template <typename Backend>
4851
std::unique_ptr<typename Backend::ExecutorType> compile(
4952
lang::TreeRef tcDefinition,
5053
const std::vector<const DLConstTensor*>& inputs,
5154
/* TODO: in the future also pass outputs for stride and alignment info */
52-
const typename Backend::MappingOptionsType& options) {
55+
const typename Backend::MappingOptionsType& options,
56+
const CompilerOptions& compilerOptions) {
5357
using CompilationResultType = typename Backend::CompilationResultType;
5458

5559
auto inputsInfo = makeTensorInfoVector(inputs);
56-
auto outputsInfo = detail::inferOutputTensorInfo(tcDefinition, inputs);
57-
auto halideComponents =
58-
tc2halide::translate(isl::with_exceptions::globalIslCtx(), tcDefinition);
60+
auto outputsInfo =
61+
detail::inferOutputTensorInfo(tcDefinition, inputs, compilerOptions);
62+
auto halideComponents = tc2halide::translate(
63+
isl::with_exceptions::globalIslCtx(), tcDefinition, compilerOptions);
5964
detail::checkInputsCompliant(halideComponents, inputs);
6065

6166
auto tcName = lang::Def(tcDefinition).name().name();
@@ -69,4 +74,5 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
6974
new typename Backend::ExecutorType(
7075
inputsInfo, outputsInfo, halideComponents, compilationResult));
7176
}
77+
} // namespace detail
7278
} // namespace tc

tc/core/compiler.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,19 @@
2424
#include "tc/core/halide_utils.h"
2525
#include "tc/core/tensor.h"
2626
#include "tc/lang/canonicalize.h"
27+
#include "tc/utils/compiler_options.h"
2728

2829
namespace tc {
2930
std::vector<TensorInfo> inferOutputTensorInfo(
3031
const std::string& tc,
3132
const std::string& entryPoint,
32-
const std::vector<const DLConstTensor*> inputs) {
33+
const std::vector<const DLConstTensor*> inputs,
34+
const CompilerOptions& compilerOptions) {
3335
auto parsedTcs = detail::parse(tc);
3436
TC_CHECK_EQ(parsedTcs.count(entryPoint), 1u)
3537
<< "attempting to access undefined function " << entryPoint;
36-
return tc::detail::inferOutputTensorInfo(parsedTcs[entryPoint], inputs);
38+
return tc::detail::inferOutputTensorInfo(
39+
parsedTcs[entryPoint], inputs, compilerOptions);
3740
}
3841

3942
namespace detail {
@@ -101,9 +104,11 @@ void checkInputsCompliant(
101104

102105
std::vector<TensorInfo> inferOutputTensorInfo(
103106
lang::TreeRef tcDefinition,
104-
const std::vector<const DLConstTensor*> inputs) {
107+
const std::vector<const DLConstTensor*> inputs,
108+
const CompilerOptions& compilerOptions) {
105109
return tc::inferOutputTensorInfo(
106-
tc2halide::translate(isl::with_exceptions::globalIslCtx(), tcDefinition),
110+
tc2halide::translate(
111+
isl::with_exceptions::globalIslCtx(), tcDefinition, compilerOptions),
107112
inputs);
108113
}
109114

tc/core/compiler.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "tc/core/mapping_options.h"
2323
#include "tc/core/tensor.h"
2424
#include "tc/lang/tree.h"
25+
#include "tc/utils/compiler_options.h"
2526

2627
/**
2728
* This provides a simple functional-style C++ API with multi-backend
@@ -62,8 +63,9 @@ namespace tc {
6263
/// "entryPoint", this function compiles a new TcExecutor for the specified
6364
/// Backend. For now, contiguous output sizes are inferred given input sizes.
6465
/// If you need another kernel for another entryPoint or other inputs or
65-
// other options then just compile another TcExecutor; because atm we fully
66-
/// JIT specialize on all sizes.
66+
/// other options then just compile another TcExecutor; because atm we fully
67+
/// JIT specialize on all sizes. General compilation options (warnings, debug
68+
/// info) are provided in "compilerOptions".
6769
/// \returns a new TcExecutor on which the run method can be called to run
6870
/// entryPoint
6971
template <typename Backend>
@@ -72,7 +74,8 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
7274
const std::string& entryPoint,
7375
const std::vector<const DLConstTensor*>& inputs,
7476
/* TODO: in the future also pass outputs for stride and alignment info */
75-
const typename Backend::MappingOptionsType& options);
77+
const typename Backend::MappingOptionsType& options,
78+
const CompilerOptions& compilerOptions = CompilerOptions());
7679

7780
/// Given a TC representation as a TC + TC function name entryPoint and a list
7881
/// of input tensors that match the definition in the TC function definition
@@ -85,7 +88,8 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
8588
std::vector<TensorInfo> inferOutputTensorInfo(
8689
const std::string& tc,
8790
const std::string& entryPoint,
88-
const std::vector<const DLConstTensor*> inputs);
91+
const std::vector<const DLConstTensor*> inputs,
92+
const CompilerOptions& compilerOptions = CompilerOptions());
8993

9094
namespace detail {
9195
/// Given a TC representation, this parses the TC functions into a map of
@@ -105,7 +109,8 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
105109
lang::TreeRef tcDefinition,
106110
const std::vector<const DLConstTensor*>& inputs,
107111
/* TODO: in the future also pass outputs for stride and alignment info */
108-
const typename Backend::MappingOptionsType& options);
112+
const typename Backend::MappingOptionsType& options,
113+
const CompilerOptions& compilerOptions = CompilerOptions());
109114

110115
/// Given a TC representation as a TreeRef and a list of input tensors that
111116
/// match the definition in the TC function definition (in positional order),
@@ -116,7 +121,8 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
116121
/// performing output shape validation.
117122
std::vector<TensorInfo> inferOutputTensorInfo(
118123
lang::TreeRef tcDefinition,
119-
const std::vector<const DLConstTensor*> inputs);
124+
const std::vector<const DLConstTensor*> inputs,
125+
const CompilerOptions& compilerOptions = CompilerOptions());
120126

121127
} // namespace detail
122128
} // namespace tc

tc/core/polyhedral/scop.cc

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "tc/core/polyhedral/schedule_utils.h"
3535
#include "tc/core/scope_guard.h"
3636
#include "tc/core/tc2halide.h"
37+
#include "tc/utils/compiler_options.h"
3738

3839
using namespace std;
3940

@@ -69,12 +70,18 @@ ScopUPtr Scop::makeScop(
6970
return scop;
7071
}
7172

72-
ScopUPtr Scop::makeScop(isl::ctx ctx, const string& tc) {
73-
return makeScop(ctx, tc2halide::translate(ctx, tc));
73+
ScopUPtr Scop::makeScop(
74+
isl::ctx ctx,
75+
const string& tc,
76+
const CompilerOptions& compilerOptions) {
77+
return makeScop(ctx, tc2halide::translate(ctx, tc, compilerOptions));
7478
}
7579

76-
ScopUPtr Scop::makeScop(isl::ctx ctx, const lang::TreeRef& treeRef) {
77-
return makeScop(ctx, tc2halide::translate(ctx, treeRef));
80+
ScopUPtr Scop::makeScop(
81+
isl::ctx ctx,
82+
const lang::TreeRef& treeRef,
83+
const CompilerOptions& compilerOptions) {
84+
return makeScop(ctx, tc2halide::translate(ctx, treeRef, compilerOptions));
7885
}
7986

8087
isl::union_set& Scop::domainRef() {

tc/core/polyhedral/scop.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "tc/core/tc2halide.h"
3333
#include "tc/core/tensor.h"
3434
#include "tc/external/isl.h"
35+
#include "tc/utils/compiler_options.h"
3536

3637
namespace tc {
3738
namespace polyhedral {
@@ -56,11 +57,15 @@ struct Scop {
5657
// Halide IR is constructed and made a member by setting halideComponents.
5758
// These operations are grouped and scheduled in a halide::Stmt which becomes
5859
// the unit from which the scop is constructed.
59-
static std::unique_ptr<Scop> makeScop(isl::ctx ctx, const std::string& tc);
60+
static std::unique_ptr<Scop> makeScop(
61+
isl::ctx ctx,
62+
const std::string& tc,
63+
const CompilerOptions& compilerOptions = CompilerOptions());
6064

6165
static std::unique_ptr<Scop> makeScop(
6266
isl::ctx ctx,
63-
const lang::TreeRef& treeRef);
67+
const lang::TreeRef& treeRef,
68+
const CompilerOptions& compilerOptions = CompilerOptions());
6469

6570
// Clone a Scop
6671
static std::unique_ptr<Scop> makeScop(const Scop& scop) {

tc/core/tc2halide.cc

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "tc/core/tc2halide.h"
2121
#include "tc/lang/parser.h"
2222
#include "tc/lang/sema.h"
23+
#include "tc/utils/compiler_options.h"
2324

2425
namespace tc2halide {
2526

@@ -270,7 +271,7 @@ void forwardBoundsInference(
270271
const std::vector<Expr>& exprs,
271272
const FunctionBounds& bounds,
272273
const lang::TreeRef& comprehension,
273-
bool throwWarnings,
274+
const tc::CompilerOptions& compilerOptions,
274275
Scope<Interval>* solution) {
275276
class CreateConstraints : public IRVisitor {
276277
using IRVisitor::visit;
@@ -488,10 +489,10 @@ void forwardBoundsInference(
488489
lang::ErrorReport err(comprehension);
489490
err << "Required precondition will not be checked at runtime: "
490491
<< remaining;
491-
if (throwWarnings) {
492+
if (compilerOptions.throwWarnings) {
492493
throw err;
493494
} else {
494-
warn(err);
495+
warn(err, compilerOptions);
495496
}
496497
}
497498
}
@@ -509,7 +510,7 @@ Expr reductionUpdate(Expr e) {
509510
void translateComprehension(
510511
const lang::Comprehension& comprehension,
511512
const map<string, Parameter>& params,
512-
bool throwWarnings,
513+
const tc::CompilerOptions& compilerOptions,
513514
map<string, Function>* funcs,
514515
FunctionBounds* bounds) {
515516
Function f;
@@ -670,7 +671,7 @@ void translateComprehension(
670671
// Infer the rest
671672
all_exprs.push_back(rhs);
672673
forwardBoundsInference(
673-
all_exprs, *bounds, comprehension, throwWarnings, &solution);
674+
all_exprs, *bounds, comprehension, compilerOptions, &solution);
674675

675676
// TODO: What if subsequent updates have incompatible bounds
676677
// (e.g. an in-place stencil)?. The .bound directive will use the
@@ -754,7 +755,9 @@ void translateComprehension(
754755
}
755756

756757
// Translate a semantically checked TC def to HalideComponents struct.
757-
HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
758+
HalideComponents translateDef(
759+
const lang::Def& def,
760+
const tc::CompilerOptions& compilerOptions) {
758761
map<string, Function> funcs;
759762
HalideComponents components;
760763
components.def = def;
@@ -765,7 +768,7 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
765768
}
766769
for (auto c : def.statements()) {
767770
translateComprehension(
768-
c, components.params, throwWarnings, &funcs, &bounds);
771+
c, components.params, compilerOptions, &funcs, &bounds);
769772
}
770773
vector<Function> outputs;
771774
for (auto p : def.returns()) {
@@ -906,19 +909,24 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
906909
}
907910
} // namespace
908911

909-
HalideComponents
910-
translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) {
912+
HalideComponents translate(
913+
isl::ctx ctx,
914+
const lang::TreeRef& treeRef,
915+
const tc::CompilerOptions& compilerOptions = tc::CompilerOptions()) {
911916
LOG_IF(INFO, tc::FLAGS_debug_halide) << treeRef;
912917
return translateDef(
913-
lang::Def(lang::Sema().checkFunction(treeRef)), throwWarnings);
918+
lang::Def(lang::Sema(compilerOptions).checkFunction(treeRef)),
919+
compilerOptions);
914920
}
915921

916922
// NOTE: there is no guarantee here that the tc string has only one def. It
917923
// could have many defs. Only first def will be converted in that case.
918-
HalideComponents
919-
translate(isl::ctx ctx, const std::string& tc, bool throwWarnings) {
924+
HalideComponents translate(
925+
isl::ctx ctx,
926+
const std::string& tc,
927+
const tc::CompilerOptions& compilerOptions = tc::CompilerOptions()) {
920928
LOG_IF(INFO, tc::FLAGS_debug_halide) << tc;
921-
return translate(ctx, lang::Parser(tc).parseFunction(), throwWarnings);
929+
return translate(ctx, lang::Parser(tc).parseFunction(), compilerOptions);
922930
}
923931

924932
} // namespace tc2halide

tc/core/tc2halide.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "tc/external/isl.h"
2121
#include "tc/lang/tree.h"
2222
#include "tc/lang/tree_views.h"
23+
#include "tc/utils/compiler_options.h"
2324

2425
namespace tc2halide {
2526

@@ -44,15 +45,19 @@ struct HalideComponents {
4445
Halide::Internal::Call::ConstString kReductionUpdate = "ReductionUpdate";
4546

4647
// Translate a TC parse tree into equivalent Halide imperative IR with
47-
// a naive schedule.
48+
// a naive schedule. Additional options, such as how to treat warnings, are
49+
// passed in as "compilerOptions".
4850
HalideComponents translate(
4951
isl::ctx ctx,
5052
const lang::TreeRef& treeRef,
51-
bool throwWarnings = false);
53+
const tc::CompilerOptions& compilerOptions);
5254

5355
// Translate TC source into equivalent Halide imperative IR with a
54-
// naive schedule.
55-
HalideComponents
56-
translate(isl::ctx ctx, const std::string& tc, bool throwWarnings = false);
56+
// naive schedule. Additional options, such as how to treat warnings, are
57+
// passed in as "compilerOptions".
58+
HalideComponents translate(
59+
isl::ctx ctx,
60+
const std::string& tc,
61+
const tc::CompilerOptions& compilerOptions);
5762

5863
} // namespace tc2halide

0 commit comments

Comments
 (0)