Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 018d3b9

Browse files
Add at::Tensor -> TensorInfo
This commit adds a conversion from ATen tensor directly to TensorInfo. Previously one needed to go through multiple conversions. These conversions were expected to be fully memoized by keeping a reference to the executor. However the PyTorch autograd function is stateless and we cannot rely on this assumption. Therefore we reduce the overhead by removing all the unnecessary conversions. Notice that a we need to retrieve the dlpack type from the ATen type and this function is not exposed on the ATen side so we currently copy the ATen code. Once the function is exposed and shipped in the proper conda package we can get rid of our copy.
1 parent 7a11fca commit 018d3b9

File tree

4 files changed

+57
-1
lines changed

4 files changed

+57
-1
lines changed

tc/aten/aten-inl.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,57 @@
2525

2626
namespace tc {
2727
namespace aten {
28+
29+
// Stolen from ATen, get rid of our copy when ATen exposes the functionality
30+
// Unfortunately we need to wait for updated conda packages so we just copy
31+
// for now.
32+
inline DLDataType getDLDataType(const at::Type& type) {
33+
using at::ScalarType;
34+
35+
DLDataType dtype;
36+
dtype.lanes = 1;
37+
dtype.bits = type.elementSizeInBytes() * 8;
38+
switch (type.scalarType()) {
39+
case ScalarType::Byte:
40+
dtype.code = DLDataTypeCode::kDLUInt;
41+
break;
42+
case ScalarType::Char:
43+
dtype.code = DLDataTypeCode::kDLInt;
44+
break;
45+
case ScalarType::Double:
46+
dtype.code = DLDataTypeCode::kDLFloat;
47+
break;
48+
case ScalarType::Float:
49+
dtype.code = DLDataTypeCode::kDLFloat;
50+
break;
51+
case ScalarType::Int:
52+
dtype.code = DLDataTypeCode::kDLInt;
53+
break;
54+
case ScalarType::Long:
55+
dtype.code = DLDataTypeCode::kDLInt;
56+
break;
57+
case ScalarType::Short:
58+
dtype.code = DLDataTypeCode::kDLInt;
59+
break;
60+
case ScalarType::Half:
61+
dtype.code = DLDataTypeCode::kDLFloat;
62+
break;
63+
case ScalarType::Undefined:
64+
throw std::logic_error("Undefined is not a valid ScalarType");
65+
case ScalarType::NumOptions:
66+
throw std::logic_error("NumOptions is not a valid ScalarType");
67+
}
68+
return dtype;
69+
}
70+
71+
inline TensorInfo toTensorInfo(const at::Tensor& t) {
72+
return TensorInfo(
73+
getDLDataType(t.type()),
74+
reinterpret_cast<std::uintptr_t>(t.data_ptr()) % TensorInfo::kAlignment,
75+
t.sizes(),
76+
t.strides());
77+
}
78+
2879
inline std::vector<DLTensorUPtr> makeDLTensors(
2980
const std::vector<at::Tensor>& tensors) {
3081
std::vector<DLTensorUPtr> dlTensors;

tc/aten/aten.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
namespace tc {
2626
namespace aten {
2727

28+
inline TensorInfo toTensorInfo(const at::Tensor&);
29+
2830
inline std::vector<DLTensorUPtr> makeDLTensors(
2931
const std::vector<at::Tensor>& tensors);
3032

tc/core/tensor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ namespace tc {
2828
namespace detail {
2929
template <typename DLTensorType>
3030
uint64_t getDLTensorAlignment(const DLTensorType* t) {
31-
return (reinterpret_cast<std::uintptr_t>(t->data) + t->byte_offset) % 256;
31+
return (reinterpret_cast<std::uintptr_t>(t->data) + t->byte_offset) %
32+
TensorInfo::kAlignment;
3233
}
3334

3435
std::vector<int64_t> toIntVector(const int64_t* ptr, size_t ndim) {

tc/core/tensor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ std::vector<const DLTensor*> extractRawPtrs(
7777
* It is serializable to protobuf and stored directly in the cache.
7878
*/
7979
struct TensorInfo {
80+
static constexpr int kAlignment = 256;
81+
8082
DLDataType dtype;
8183
uint64_t alignment;
8284
std::vector<int64_t> shape;

0 commit comments

Comments
 (0)