Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 0edfb22

Browse files
Template ATenCompilationUnit
1 parent a9989f8 commit 0edfb22

26 files changed

+97 -100 lines changed

examples/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,6 @@ foreach(i ${EXAMPLES_FILES})
3333
add_test(${i} ${i})
3434
target_link_libraries(
3535
${i}
36-
tc_aten
3736
tc_autotuner
3837
tc_core
3938
tc_c2
@@ -43,5 +42,7 @@ foreach(i ${EXAMPLES_FILES})
4342
${GTEST_LIBS}
4443
${GFLAGS_LIBRARIES}
4544
${GLOG_LIBRARIES}
45+
46+
${ATEN_LIBRARIES}
4647
)
4748
endforeach()

examples/example_fixture.h

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@
3232
#include "tc/core/cuda/cuda.h"
3333
#include "tc/core/cuda/cuda_compilation_cache.h"
3434
#include "tc/core/cuda/cuda_rtc.h"
35+
#include "tc/core/cuda/cuda_tc_executor.h"
3536
#include "tc/core/flags.h"
3637
#include "tc/core/mapping_options.h"
3738
#include "tc/core/scope_guard.h"
@@ -63,7 +64,7 @@ std::vector<const DLTensor*> inferOutputTensorInfo(
6364
const std::string& tc,
6465
const std::string& name,
6566
const std::vector<at::Tensor>& inputs) {
66-
tc::ATenCompilationUnit atCompl;
67+
tc::ATenCompilationUnit<tc::CudaTcExecutor> atCompl;
6768
atCompl.define(tc);
6869
return atCompl.inferOutputTensorInfo(name, inputs);
6970
}
@@ -133,7 +134,7 @@ struct Benchmark : public ::testing::Test {
133134
std::vector<at::Tensor>& outputs) {
134135
return true;
135136
}) {
136-
tc::ATenCompilationUnit atCompl;
137+
tc::ATenCompilationUnit<tc::CudaTcExecutor> atCompl;
137138
atCompl.define(tc);
138139
auto handle = atCompl.compile(name, inputs, mappingOptions);
139140
atCompl.run(name, inputs, outputs, handle);
@@ -281,7 +282,7 @@ struct Benchmark : public ::testing::Test {
281282
tc::CudaCache::loadCacheFromProtobuf(tc::makeCudaFilename(cacheFilename));
282283
tc::FLAGS_tuner_gen_restore_number = 1;
283284

284-
tc::ATenCompilationUnit atCompl;
285+
tc::ATenCompilationUnit<tc::CudaTcExecutor> atCompl;
285286
atCompl.define(tc);
286287

287288
auto mappingOptions = [&]() {
@@ -399,7 +400,7 @@ struct Benchmark : public ::testing::Test {
399400
return *options;
400401
}();
401402

402-
tc::ATenCompilationUnit atCompl;
403+
tc::ATenCompilationUnit<tc::CudaTcExecutor> atCompl;
403404
atCompl.define(TC);
404405
auto handle = atCompl.compile(kernelName, inputs, bestOptions);
405406
std::vector<at::Tensor> outputs;

src/aten/aten_compiler.cc renamed to include/tc/aten/aten_compiler-inl.h

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -23,12 +23,14 @@
2323

2424
namespace tc {
2525

26-
ATenCompilationUnit::ATenCompilationUnit() {
27-
executionEngine_ = std::unique_ptr<ExecutionEngine<CudaTcExecutor>>(
28-
new ExecutionEngine<CudaTcExecutor>());
26+
template <typename ExecutorType>
27+
ATenCompilationUnit<ExecutorType>::ATenCompilationUnit() {
28+
executionEngine_ = std::unique_ptr<ExecutionEngine<ExecutorType>>(
29+
new ExecutionEngine<ExecutorType>());
2930
}
3031

31-
void ATenCompilationUnit::define(const std::string& language) {
32+
template <typename ExecutorType>
33+
void ATenCompilationUnit<ExecutorType>::define(const std::string& language) {
3234
executionEngine_->define(language);
3335
}
3436

@@ -67,7 +69,8 @@ void prepareOutputs(
6769

6870
} // namespace
6971

70-
size_t ATenCompilationUnit::compile(
72+
template <typename ExecutorType>
73+
size_t ATenCompilationUnit<ExecutorType>::compile(
7174
const std::string& name,
7275
const std::vector<at::Tensor>& inputs,
7376
const MappingOptions& options) {
@@ -77,7 +80,9 @@ size_t ATenCompilationUnit::compile(
7780
name, inputDLTensorsPair.first, options.toProtobufSerializedString());
7881
}
7982

80-
std::vector<const DLTensor*> ATenCompilationUnit::inferOutputTensorInfo(
83+
template <typename ExecutorType>
84+
std::vector<const DLTensor*>
85+
ATenCompilationUnit<ExecutorType>::inferOutputTensorInfo(
8186
const std::string& name,
8287
const std::vector<at::Tensor>& inputs) {
8388
auto inputDLTensorsPair = toConstDlpackTensors(inputs);
@@ -86,7 +91,8 @@ std::vector<const DLTensor*> ATenCompilationUnit::inferOutputTensorInfo(
8691
name, inputDLTensorsPair.first);
8792
}
8893

89-
Duration ATenCompilationUnit::run(
94+
template <typename ExecutorType>
95+
Duration ATenCompilationUnit<ExecutorType>::run(
9096
const std::string& name,
9197
const std::vector<at::Tensor>& inputs,
9298
std::vector<at::Tensor>& outputs,
@@ -105,7 +111,8 @@ Duration ATenCompilationUnit::run(
105111
handle, inputDLTensorsPair.first, outputDLTensorsPair.first, profile);
106112
}
107113

108-
void ATenCompilationUnit::uncheckedRun(
114+
template <typename ExecutorType>
115+
void ATenCompilationUnit<ExecutorType>::uncheckedRun(
109116
const std::vector<at::Tensor>& inputs,
110117
std::vector<at::Tensor>& outputs,
111118
size_t handle) {

include/tc/aten/aten_compiler.h

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -23,15 +23,14 @@
2323
#include <ATen/DLConvertor.h>
2424

2525
#include "tc/aten/utils.h"
26-
#include "tc/core/cuda/cuda.h"
27-
#include "tc/core/cuda/cuda_tc_executor.h"
2826
#include "tc/core/execution_engine.h"
2927
#include "tc/lang/parser.h"
3028

3129
namespace tc {
3230
/// This provides the basic interface for writing ATen style tensor operations
3331
/// based on Tensor Comprehensions.
3432

33+
template <typename ExecutorType>
3534
class ATenCompilationUnit {
3635
public:
3736
explicit ATenCompilationUnit();
@@ -72,6 +71,8 @@ class ATenCompilationUnit {
7271
size_t handle);
7372

7473
private:
75-
std::unique_ptr<ExecutionEngine<CudaTcExecutor>> executionEngine_;
74+
std::unique_ptr<ExecutionEngine<ExecutorType>> executionEngine_;
7675
};
7776
} // namespace tc
77+
78+
#include "tc/aten/aten_compiler-inl.h"

include/tc/aten/utils-inl.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
#include <ATen/ATen.h>
2222
#include <ATen/DLConvertor.h>
2323
namespace tc {
24+
namespace {
2425
std::pair<std::vector<DLTensor*>, std::vector<DLManagedTensor*>>
2526
toDlpackTensors(const std::vector<at::Tensor>& tensors) {
2627
std::vector<DLTensor*> dlTensors;
@@ -50,4 +51,5 @@ void deleteDlmTensors(std::vector<DLManagedTensor*>& tensors) {
5051
tensor->deleter(tensor);
5152
}
5253
}
54+
} // namespace
5355
} // namespace tc

include/tc/aten/utils.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,13 +22,15 @@
2222
#include <ATen/DLConvertor.h>
2323

2424
namespace tc {
25+
namespace {
2526
std::pair<std::vector<DLTensor*>, std::vector<DLManagedTensor*>>
2627
toDlpackTensors(const std::vector<at::Tensor>& tensors);
2728

2829
std::pair<std::vector<const DLTensor*>, std::vector<DLManagedTensor*>>
2930
toConstDlpackTensors(const std::vector<at::Tensor>& tensors);
3031

3132
void deleteDlmTensors(std::vector<DLManagedTensor*>& tensors);
33+
} // namespace
3234
} // namespace tc
3335

3436
#include "tc/aten/utils-inl.h"

include/tc/core/execution_engine-inl.h

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -67,11 +67,11 @@ ExecutionEngine<ExecutorType>::inferOutputTensorInfo(
6767
CHECK_EQ(1, tcNameMap_.count(name))
6868
<< "attempting to access undefined function " << name;
6969
// If we have already compiled for the given inputs, regardless of
70-
// the options, we can get sizes from a corresponding TcExecutor.
70+
// the options, we can get sizes from a corresponding ExecutorType.
7171
auto e = std::find_if(
7272
executors_.begin(),
7373
executors_.end(),
74-
[&](const std::unique_ptr<TcExecutor>& e) {
74+
[&](const std::unique_ptr<ExecutorType>& e) {
7575
return e && name == e->identifier &&
7676
compareDLTensorVectorMetadata(
7777
extractRawPtrs(e->inputsInfo), inputs);
@@ -85,7 +85,7 @@ ExecutionEngine<ExecutorType>::inferOutputTensorInfo(
8585
// null options. It will be used for further size queries but
8686
// will fail if somebody attempts to run it.
8787
auto executor =
88-
tc::make_unique<TcExecutor>(name, inputs, "", tcNameMap_.at(name));
88+
tc::make_unique<ExecutorType>(name, inputs, "", tcNameMap_.at(name));
8989
auto outputsInfo = executor->inferOutputTensorInfo();
9090
emplaceExecutor(std::move(executor));
9191
return outputsInfo;
@@ -114,16 +114,16 @@ size_t ExecutionEngine<ExecutorType>::compile(
114114
return handle;
115115
}
116116

117-
// Steal TcExecutor and give it back under lock
118-
// Run outside of lock on owning TcExecutor.
117+
// Steal ExecutorType and give it back under lock
118+
// Run outside of lock on owning ExecutorType.
119119
template <typename ExecutorType>
120120
Duration ExecutionEngine<ExecutorType>::run(
121121
size_t handle,
122122
const std::vector<const DLTensor*>& inputs,
123123
const std::vector<DLTensor*>& outputs,
124124
bool profile,
125125
std::function<bool(const ExecutorType*)> pruningFunction) {
126-
std::unique_ptr<TcExecutor> executorUPtr(nullptr);
126+
std::unique_ptr<ExecutorType> executorUPtr(nullptr);
127127
{
128128
std::lock_guard<std::mutex> lg(tcExecutorMutex_);
129129
std::swap(executorUPtr, executors_[handle]);
@@ -155,14 +155,14 @@ Duration ExecutionEngine<ExecutorType>::run(
155155
return res;
156156
}
157157

158-
// Steal TcExecutor and give it back under lock
159-
// Run outside of lock on owning TcExecutor.
158+
// Steal ExecutorType and give it back under lock
159+
// Run outside of lock on owning ExecutorType.
160160
template <typename ExecutorType>
161161
void ExecutionEngine<ExecutorType>::uncheckedRun(
162162
size_t handle,
163163
const std::vector<const void*>& inputs,
164164
const std::vector<void*>& outputs) {
165-
std::unique_ptr<TcExecutor> executorUPtr(nullptr);
165+
std::unique_ptr<ExecutorType> executorUPtr(nullptr);
166166
{
167167
std::lock_guard<std::mutex> lg(tcExecutorMutex_);
168168
std::swap(executorUPtr, executors_[handle]);
@@ -193,12 +193,12 @@ template <typename ExecutorType>
193193
void ExecutionEngine<ExecutorType>::clear(size_t handle) {
194194
std::lock_guard<std::mutex> lg(tcExecutorMutex_);
195195
executors_[handle]->clearRuntimeCompiledFunction();
196-
executors_[handle] = std::unique_ptr<TcExecutor>(nullptr);
196+
executors_[handle] = std::unique_ptr<ExecutorType>(nullptr);
197197
}
198198

199199
template <typename ExecutorType>
200200
size_t ExecutionEngine<ExecutorType>::emplaceExecutor(
201-
std::unique_ptr<TcExecutor> executorUPtr) {
201+
std::unique_ptr<ExecutorType> executorUPtr) {
202202
// Insert in vector under lock
203203
std::lock_guard<std::mutex> lg(tcExecutorMutex_);
204204
size_t handle = uidCounter++;
@@ -219,7 +219,7 @@ size_t ExecutionEngine<ExecutorType>::getHandle(
219219
auto it = std::find_if(
220220
executors_.begin(),
221221
executors_.end(),
222-
[&](const std::unique_ptr<TcExecutor>& e) {
222+
[&](const std::unique_ptr<ExecutorType>& e) {
223223
return e && // UPtrs get stolen by run to avoid underlying vector
224224
// realloc issues, guard against that
225225
name == e->identifier &&

include/tc/core/execution_engine.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ class ExecutionEngine {
8585
void clear(size_t handle);
8686

8787
protected:
88-
size_t emplaceExecutor(std::unique_ptr<TcExecutor> p);
88+
size_t emplaceExecutor(std::unique_ptr<ExecutorType> p);
8989

9090
size_t getHandle(
9191
const std::string& name,
@@ -100,7 +100,7 @@ class ExecutionEngine {
100100

101101
/// List of executors, indexed by handle. Derived ExecutionEngines can also
102102
/// derive TcExecutor.
103-
std::vector<std::unique_ptr<TcExecutor>> executors_;
103+
std::vector<std::unique_ptr<ExecutorType>> executors_;
104104

105105
size_t uidCounter = 0;
106106
};

src/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,6 @@ add_subdirectory(version)
77
add_subdirectory(core)
88
add_subdirectory(autotuner)
99

10-
if (WITH_CUDA)
11-
add_subdirectory(aten)
12-
endif()
13-
1410
if (WITH_CAFFE2 AND WITH_CUDA)
1511
add_subdirectory(c2)
1612
endif()

src/aten/CMakeLists.txt

Lines changed: 0 additions & 28 deletions
This file was deleted.

0 commit comments

Comments
 (0)