This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit f0edebe

Proposed PyTorch integration
This commit proposes a set of **measured** low-overhead abstractions that live side by side with the current pybind. The idea is to show by example what a low-overhead abstraction for PyTorch should look like and to shift to it in a subsequent PR.

In particular, the changes to the bindings:
1. implement a compilation cache that performs memoization by hashing the TensorInfo of the inputs on each call; this is required because PyTorch autograd functions are stateless and we need to allocate outputs ourselves;
2. implement an explicit Tuner that takes a configuration object, preset to good default values to avoid making it too easy to call the tuner in single-threaded mode;
3. implement an explicit interface to the compilation cache and allow loading the topK options.

On the Python side, this commit creates:
1. a TcBuilder class which makes a Tc operation usable by:
   a. providing an entry point to the compilation cache on which we can compile, alloc_outputs and run/unchecked_run (we can later implement an idiom that calls directly into the compiled code if needed);
   b. providing an abstraction, with or without reinforcement, over an underlying tuner object backed by a proto file;
   c. being passed as the first argument to a TcFunction which implements torch.autograd.Function;
2. a TcFunction which extends torch.autograd.Function and can take multiple TC names to run in sequence for forward and backward.

In the current implementation there is a potential synchronization problem that I want to be able to reproduce in the future. This issue goes away in the next commit, once output allocation is merged at the C++ level. This commit is a first functional implementation; the following commit will remove extra overhead by merging the output allocation (which also seemed to remove the synchronization issue that I do not know how to solve yet).
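To make point 1 above concrete, here is a minimal, self-contained Python sketch of the memoization idea: compiled artifacts are cached under a key derived from each input's metadata (dtype, device, sizes, strides), so a stateless autograd function can recover the executor and allocate outputs on every call. This only illustrates the concept; the actual cache added by this commit lives in the C++ bindings, and the helper names below (tensor_info_key, CompilationCache, get_or_compile) are made up for the example.

import torch

def tensor_info_key(*tensors):
    # Hashable summary of the metadata that determines a compiled kernel:
    # dtype, device, sizes and strides (the values themselves do not matter).
    return tuple(
        (t.dtype, t.device.type, tuple(t.size()), tuple(t.stride()))
        for t in tensors
    )

class CompilationCache:
    def __init__(self, compile_fn):
        self._compile_fn = compile_fn  # in the real binding: JIT-compile a TC
        self._store = {}

    def get_or_compile(self, *inputs):
        key = tensor_info_key(*inputs)
        if key not in self._store:
            self._store[key] = self._compile_fn(*inputs)
        return self._store[key]

# Toy "compiler": stands in for compiling a TC kernel and remembering how to
# allocate its outputs.
cache = CompilationCache(lambda a, b: (lambda x, y: x @ y))

A, B = torch.randn(32, 64), torch.randn(64, 16)
matmul = cache.get_or_compile(A, B)        # compiled on first call
matmul_again = cache.get_or_compile(A, B)  # memoized: same executor returned
assert matmul is matmul_again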
1 parent 018d3b9 · commit f0edebe

File tree: 7 files changed, +954 −5 lines


tc/autotuner/options_cache-inl.h

Lines changed: 2 additions & 2 deletions
@@ -336,10 +336,10 @@ std::vector<typename Backend::MappingOptionsType> loadTopKFromCacheFile(
 
 template <typename Backend>
 void appendTopKToCacheFile(
-    const std::shared_ptr<OptionsCache<Backend>>& cache,
+    const OptionsCache<Backend>& cache,
     const std::string& cacheFilename,
     uint32_t count) {
-  OptionsCache<Backend> copy(*cache);
+  OptionsCache<Backend> copy(cache);
   copy.pruneKeepTopK(count);
   auto proto = copy.toProtobuf();
   OptionsCache<Backend> optionsCache;

tc/autotuner/options_cache.h

Lines changed: 3 additions & 2 deletions
@@ -166,9 +166,10 @@ struct OptionsCache {
       OptionsCacheKeyHash>
       store_;
 
+  // Make friend to access toProtobuf/fromProtobuf
   template <typename BackendType>
   friend void appendTopKToCacheFile(
-      const std::shared_ptr<OptionsCache<BackendType>>& cache,
+      const OptionsCache<BackendType>& cache,
       const std::string& cacheFilename,
       uint32_t count);
 };
@@ -199,7 +200,7 @@ std::vector<typename Backend::MappingOptionsType> loadTopKFromCacheFile(
 /// needed.
 template <typename Backend>
 void appendTopKToCacheFile(
-    const std::shared_ptr<OptionsCache<Backend>>& cache,
+    const OptionsCache<Backend>& cache,
     const std::string& cacheFilename,
     uint32_t count);

tc/examples/blockdiagperm.cc

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ def blockdiagperm2dfissioned_2(float(B, N) I, int32(N) Idx) -> (O) {
   // later if required
   if (not FLAGS_options_cache.empty()) {
     tc::autotune::appendTopKToCacheFile(
-        geneticAutotuneATen.optionsCache,
+        *geneticAutotuneATen.optionsCache,
         FLAGS_options_cache,
         tc::FLAGS_tuner_save_best_candidates_count);
   }

tensor_comprehensions/pybinds/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
@@ -26,3 +26,22 @@ set_target_properties(tc PROPERTIES DEBUG_POSTFIX "")
 # "Install" to source_dir/tensor_comprehensions for easy inclusion in python
 # package with setuptools. Without this, setuptools is hell.
 install(TARGETS tc DESTINATION ${PROJECT_SOURCE_DIR}/tensor_comprehensions)
+
+add_library(tclib MODULE tclib.cc)
+target_include_directories(tclib PUBLIC ${PROJECT_SOURCE_DIR}/tc)
+target_link_libraries(
+  tclib
+
+  tc_autotuner
+  tc_aten
+
+  ${PYTHON_LIBRARIES}
+)
+
+set_target_properties(tclib PROPERTIES PREFIX "")
+set_target_properties(tclib PROPERTIES SUFFIX ".so")
+set_target_properties(tclib PROPERTIES DEBUG_POSTFIX "")
+
+# "Install" to source_dir/tensor_comprehensions for easy inclusion in python
+# package with setuptools. Without this, setuptools is hell.
+install(TARGETS tclib DESTINATION ${PROJECT_SOURCE_DIR}/tensor_comprehensions)
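For reference, once the new tclib target is built and installed next to the existing tc module, a quick sanity check from Python might look like the sketch below. What tclib actually exports is not visible in this diff, so the inspection stays deliberately generic, and it assumes the .so lands inside the tensor_comprehensions package directory as the install rule above suggests.

# Assumes the build placed tclib.so inside the tensor_comprehensions
# package directory, next to the existing tc extension.
from tensor_comprehensions import tclib

print(tclib.__file__)  # .../tensor_comprehensions/tclib.so
print([name for name in dir(tclib) if not name.startswith("_")])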
