Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit f6265e0

Browse files
Merge pull request #495 from nicolasvasilache/pr/redo-pytorch
More generic and lightweight PyTorch integration
2 parents 7e2df7c + 11b0c31 commit f6265e0

File tree

13 files changed

+1305
-10
lines changed

13 files changed

+1305
-10
lines changed

.jenkins/build.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ python setup.py install
7070
./test_python/run_test.sh
7171

7272
for f in $(find ./python/examples -name "*.py"); do
73-
python $f
73+
python $f -v
7474
done
7575

7676
FILTER_OUT="benchmark_MLP_model benchmark_kronecker" ./test.sh

python/examples/tc_pybind_example.py

Lines changed: 576 additions & 0 deletions
Large diffs are not rendered by default.

tc/aten/aten-inl.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,57 @@
2525

2626
namespace tc {
2727
namespace aten {
28+
29+
// Stolen from ATen, get rid of our copy when ATen exposes the functionality
30+
// Unfortunately we need to wait for updated conda packages so we just copy
31+
// for now.
32+
inline DLDataType getDLDataType(const at::Type& type) {
33+
using at::ScalarType;
34+
35+
DLDataType dtype;
36+
dtype.lanes = 1;
37+
dtype.bits = type.elementSizeInBytes() * 8;
38+
switch (type.scalarType()) {
39+
case ScalarType::Byte:
40+
dtype.code = DLDataTypeCode::kDLUInt;
41+
break;
42+
case ScalarType::Char:
43+
dtype.code = DLDataTypeCode::kDLInt;
44+
break;
45+
case ScalarType::Double:
46+
dtype.code = DLDataTypeCode::kDLFloat;
47+
break;
48+
case ScalarType::Float:
49+
dtype.code = DLDataTypeCode::kDLFloat;
50+
break;
51+
case ScalarType::Int:
52+
dtype.code = DLDataTypeCode::kDLInt;
53+
break;
54+
case ScalarType::Long:
55+
dtype.code = DLDataTypeCode::kDLInt;
56+
break;
57+
case ScalarType::Short:
58+
dtype.code = DLDataTypeCode::kDLInt;
59+
break;
60+
case ScalarType::Half:
61+
dtype.code = DLDataTypeCode::kDLFloat;
62+
break;
63+
case ScalarType::Undefined:
64+
throw std::logic_error("Undefined is not a valid ScalarType");
65+
case ScalarType::NumOptions:
66+
throw std::logic_error("NumOptions is not a valid ScalarType");
67+
}
68+
return dtype;
69+
}
70+
71+
// Captures the metadata of an ATen tensor as a TensorInfo: DLPack
// dtype, the data pointer's residue modulo TensorInfo::kAlignment,
// and the tensor's sizes and strides.
inline TensorInfo toTensorInfo(const at::Tensor& t) {
  const auto dtype = getDLDataType(t.type());
  const auto address = reinterpret_cast<std::uintptr_t>(t.data_ptr());
  return TensorInfo(
      dtype, address % TensorInfo::kAlignment, t.sizes(), t.strides());
}
78+
2879
inline std::vector<DLTensorUPtr> makeDLTensors(
2980
const std::vector<at::Tensor>& tensors) {
3081
std::vector<DLTensorUPtr> dlTensors;

tc/aten/aten.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
namespace tc {
2626
namespace aten {
2727

28+
inline TensorInfo toTensorInfo(const at::Tensor&);
29+
2830
inline std::vector<DLTensorUPtr> makeDLTensors(
2931
const std::vector<at::Tensor>& tensors);
3032

tc/autotuner/autotuner-inl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ Autotuner<Backend, SearchStrategy>::tune(
458458
std::this_thread::sleep_for(std::chrono::milliseconds(100));
459459
if (sigint_) {
460460
tuningHarness.stopAfterCurrentIteration();
461+
break;
461462
}
462463
if (sigterm_) {
463464
std::cerr << "Autotuning aborted." << std::endl;

tc/autotuner/options_cache-inl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,10 +336,10 @@ std::vector<typename Backend::MappingOptionsType> loadTopKFromCacheFile(
336336

337337
template <typename Backend>
338338
void appendTopKToCacheFile(
339-
const std::shared_ptr<OptionsCache<Backend>>& cache,
339+
const OptionsCache<Backend>& cache,
340340
const std::string& cacheFilename,
341341
uint32_t count) {
342-
OptionsCache<Backend> copy(*cache);
342+
OptionsCache<Backend> copy(cache);
343343
copy.pruneKeepTopK(count);
344344
auto proto = copy.toProtobuf();
345345
OptionsCache<Backend> optionsCache;

tc/autotuner/options_cache.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,10 @@ struct OptionsCache {
166166
OptionsCacheKeyHash>
167167
store_;
168168

169+
// Make friend to access toProtobuf/fromProtobuf
169170
template <typename BackendType>
170171
friend void appendTopKToCacheFile(
171-
const std::shared_ptr<OptionsCache<BackendType>>& cache,
172+
const OptionsCache<BackendType>& cache,
172173
const std::string& cacheFilename,
173174
uint32_t count);
174175
};
@@ -199,7 +200,7 @@ std::vector<typename Backend::MappingOptionsType> loadTopKFromCacheFile(
199200
/// needed.
200201
template <typename Backend>
201202
void appendTopKToCacheFile(
202-
const std::shared_ptr<OptionsCache<Backend>>& cache,
203+
const OptionsCache<Backend>& cache,
203204
const std::string& cacheFilename,
204205
uint32_t count);
205206

tc/core/tensor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ namespace tc {
2828
namespace detail {
2929
template <typename DLTensorType>
3030
uint64_t getDLTensorAlignment(const DLTensorType* t) {
31-
return (reinterpret_cast<std::uintptr_t>(t->data) + t->byte_offset) % 256;
31+
return (reinterpret_cast<std::uintptr_t>(t->data) + t->byte_offset) %
32+
TensorInfo::kAlignment;
3233
}
3334

3435
std::vector<int64_t> toIntVector(const int64_t* ptr, size_t ndim) {

tc/core/tensor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ std::vector<const DLTensor*> extractRawPtrs(
7777
* It is serializable to protobuf and stored directly in the cache.
7878
*/
7979
struct TensorInfo {
80+
static constexpr int kAlignment = 256;
81+
8082
DLDataType dtype;
8183
uint64_t alignment;
8284
std::vector<int64_t> shape;

tc/examples/blockdiagperm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def blockdiagperm2dfissioned_2(float(B, N) I, int32(N) Idx) -> (O) {
148148
// later if required
149149
if (not FLAGS_options_cache.empty()) {
150150
tc::autotune::appendTopKToCacheFile(
151-
geneticAutotuneATen.optionsCache,
151+
*geneticAutotuneATen.optionsCache,
152152
FLAGS_options_cache,
153153
tc::FLAGS_tuner_save_best_candidates_count);
154154
}

0 commit comments

Comments
 (0)