
Commit 004d31a

yiming0416 authored and facebook-github-bot committed

[nativert] Move TensorMeta to pytorch core (pytorch#152475)

Summary:
Pull Request resolved: pytorch#152475

Torch Native Runtime RFC: pytorch/rfcs#72

This diff moves `TensorMeta.cpp` and `TensorMeta.h` to PyTorch core.

Test Plan: Internal CI. GitHub CI with the newly added test under `test/cpp/nativert/test_tensor_meta.cpp`.

Reviewed By: zhxchen17

Differential Revision: D73820548

1 parent 99dac70
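For context, a minimal sketch (not part of this diff) of how the conversion helpers being moved are called; the torch::_export setter calls mirror the ones used in the new test below:

#include <torch/nativert/graph/TensorMeta.h>

// Sketch only: exercises the free conversion functions this diff moves
// into PyTorch core.
void convertExample() {
  torch::_export::Device dev;
  dev.set_type("cuda");
  dev.set_index(1);
  c10::Device d = torch::nativert::convertJsonDevice(dev); // cuda:1

  c10::ScalarType st = torch::nativert::convertJsonScalarType(
      torch::_export::ScalarType::BFLOAT16); // c10::ScalarType::BFloat16
  c10::Layout lay = torch::nativert::convertJsonLayout(
      torch::_export::Layout::Strided); // c10::Layout::Strided
  (void)d;
  (void)st;
  (void)lay;
}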

6 files changed: 344 additions, 2 deletions

build_variables.bzl

Lines changed: 7 additions & 2 deletions

@@ -587,6 +587,11 @@ jit_sources_full = [
 
 libtorch_core_jit_sources = sorted(jit_sources_full)
 
+libtorch_nativert_sources = [
+    "torch/nativert/graph/TensorMeta.cpp",
+]
+
 torch_mobile_tracer_sources = [
     "torch/csrc/jit/mobile/model_tracer/tracer.cpp",
     "torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp",
@@ -619,7 +624,7 @@ libtorch_lite_cmake_sources = sorted(
     torch_mobile_core,
 )
 
-libtorch_cmake_sources = libtorch_core_sources + libtorch_core_jit_sources
+libtorch_cmake_sources = libtorch_core_sources + libtorch_core_jit_sources + libtorch_nativert_sources
 
 libtorch_extra_sources = libtorch_core_jit_sources + [
     "torch/csrc/autograd/TraceTypeManual.cpp",
@@ -659,7 +664,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
 
 def libtorch_sources(gencode_pattern = ":generate-code[{}]"):
     return (
-        libtorch_generated_sources(gencode_pattern) + libtorch_core_sources + libtorch_distributed_sources + libtorch_extra_sources
+        libtorch_generated_sources(gencode_pattern) + libtorch_core_sources + libtorch_distributed_sources + libtorch_extra_sources + libtorch_nativert_sources
     )
 
 libtorch_cuda_core_sources = [

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -1319,6 +1319,7 @@ if(BUILD_TEST)
     )
   else()
     add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit)
+    add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert)
     add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor)
     add_subdirectory(
       ${TORCH_ROOT}/test/cpp/tensorexpr

test/cpp/nativert/CMakeLists.txt

Lines changed: 40 additions & 0 deletions (new file)

@@ -0,0 +1,40 @@
+set(NATIVERT_TEST_ROOT ${TORCH_ROOT}/test/cpp/nativert)
+
+# Build the cpp gtest binary containing the cpp-only tests.
+set(NATIVERT_TEST_SRCS
+  ${NATIVERT_TEST_ROOT}/test_tensor_meta.cpp
+  ${TORCH_ROOT}/torch/nativert/graph/TensorMeta.cpp
+)
+
+add_executable(test_nativert
+  ${TORCH_ROOT}/test/cpp/common/main.cpp
+  ${NATIVERT_TEST_SRCS}
+)
+
+# TODO: temporary until we can delete the old gtest polyfills.
+target_compile_definitions(test_nativert PRIVATE USE_GTEST)
+
+set(NATIVERT_TEST_DEPENDENCIES torch gtest)
+
+target_link_libraries(test_nativert PRIVATE ${NATIVERT_TEST_DEPENDENCIES})
+target_include_directories(test_nativert PRIVATE ${ATen_CPU_INCLUDE})
+
+if(USE_CUDA)
+  target_compile_definitions(test_nativert PRIVATE USE_CUDA)
+elseif(USE_ROCM)
+  target_link_libraries(test_nativert PRIVATE
+    hiprtc::hiprtc
+    hip::amdhip64
+    ${TORCH_CUDA_LIBRARIES})
+
+  target_compile_definitions(test_nativert PRIVATE USE_ROCM)
+endif()
+
+if(INSTALL_TEST)
+  set_target_properties(test_nativert PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib")
+  install(TARGETS test_nativert DESTINATION bin)
+  # Install PDB files for MSVC builds.
+  if(MSVC AND BUILD_SHARED_LIBS)
+    install(FILES $<TARGET_PDB_FILE:test_nativert> DESTINATION bin OPTIONAL)
+  endif()
+endif()
test/cpp/nativert/test_tensor_meta.cpp

Lines changed: 62 additions & 0 deletions (new file)

@@ -0,0 +1,62 @@
+#include <gtest/gtest.h>
+#include <torch/nativert/graph/TensorMeta.h>
+
+namespace torch::nativert {
+TEST(TensorMetaTest, ScalarTypeConversion) {
+  EXPECT_EQ(
+      convertJsonScalarType(torch::_export::ScalarType::FLOAT),
+      c10::ScalarType::Float);
+  EXPECT_EQ(
+      convertJsonScalarType(torch::_export::ScalarType::INT),
+      c10::ScalarType::Int);
+  EXPECT_EQ(
+      convertJsonScalarType(torch::_export::ScalarType::HALF),
+      c10::ScalarType::Half);
+  EXPECT_EQ(
+      convertJsonScalarType(torch::_export::ScalarType::COMPLEXHALF),
+      c10::ScalarType::ComplexHalf);
+  EXPECT_EQ(
+      convertJsonScalarType(torch::_export::ScalarType::BFLOAT16),
+      c10::ScalarType::BFloat16);
+  EXPECT_THROW(
+      convertJsonScalarType(static_cast<torch::_export::ScalarType>(100)),
+      c10::Error);
+}
+TEST(TensorMetaTest, MemoryFormatConversion) {
+  EXPECT_EQ(
+      convertJsonMemoryFormat(torch::_export::MemoryFormat::ContiguousFormat),
+      c10::MemoryFormat::Contiguous);
+  EXPECT_EQ(
+      convertJsonMemoryFormat(torch::_export::MemoryFormat::ChannelsLast),
+      c10::MemoryFormat::ChannelsLast);
+  EXPECT_EQ(
+      convertJsonMemoryFormat(torch::_export::MemoryFormat::PreserveFormat),
+      c10::MemoryFormat::Preserve);
+  EXPECT_THROW(
+      convertJsonMemoryFormat(static_cast<torch::_export::MemoryFormat>(100)),
+      c10::Error);
+}
+
+TEST(TensorMetaTest, LayoutConversion) {
+  EXPECT_EQ(
+      convertJsonLayout(torch::_export::Layout::Strided), c10::Layout::Strided);
+  EXPECT_EQ(
+      convertJsonLayout(torch::_export::Layout::SparseCsr),
+      c10::Layout::SparseCsr);
+  EXPECT_EQ(
+      convertJsonLayout(torch::_export::Layout::_mkldnn), c10::Layout::Mkldnn);
+  EXPECT_THROW(
+      convertJsonLayout(static_cast<torch::_export::Layout>(100)), c10::Error);
+}
+TEST(TensorMetaTest, DeviceConversion) {
+  torch::_export::Device cpu_device;
+  cpu_device.set_type("cpu");
+  EXPECT_EQ(convertJsonDevice(cpu_device), c10::Device(c10::DeviceType::CPU));
+  torch::_export::Device cuda_device;
+  cuda_device.set_type("cuda");
+  cuda_device.set_index(0);
+  EXPECT_EQ(
+      convertJsonDevice(cuda_device), c10::Device(c10::DeviceType::CUDA, 0));
+}
+
+} // namespace torch::nativert
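One case the new test leaves uncovered is a CPU device that carries an explicit index. A possible extra check, sketched here rather than taken from the commit (it would sit inside the same torch::nativert namespace as the tests above):

TEST(TensorMetaTest, CpuDeviceWithIndex) {
  // Hypothetical case, not part of this diff: an explicit index on a CPU
  // device should survive the conversion, mirroring the CUDA case above.
  torch::_export::Device cpu_device;
  cpu_device.set_type("cpu");
  cpu_device.set_index(0);
  EXPECT_EQ(
      convertJsonDevice(cpu_device), c10::Device(c10::DeviceType::CPU, 0));
}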

torch/nativert/graph/TensorMeta.cpp

Lines changed: 142 additions & 0 deletions (new file)

@@ -0,0 +1,142 @@
+#include <torch/nativert/graph/TensorMeta.h>
+
+#include <c10/util/Logging.h>
+
+namespace torch::nativert {
+
+c10::ScalarType convertJsonScalarType(
+    const torch::_export::ScalarType& scalarType) {
+  switch (scalarType) {
+    case torch::_export::ScalarType::UNKNOWN:
+      TORCH_CHECK(false, "scalar type is not properly set");
+    case torch::_export::ScalarType::BYTE:
+      return c10::ScalarType::Byte;
+    case torch::_export::ScalarType::CHAR:
+      return c10::ScalarType::Char;
+    case torch::_export::ScalarType::SHORT:
+      return c10::ScalarType::Short;
+    case torch::_export::ScalarType::INT:
+      return c10::ScalarType::Int;
+    case torch::_export::ScalarType::LONG:
+      return c10::ScalarType::Long;
+    case torch::_export::ScalarType::HALF:
+      return c10::ScalarType::Half;
+    case torch::_export::ScalarType::FLOAT:
+      return c10::ScalarType::Float;
+    case torch::_export::ScalarType::DOUBLE:
+      return c10::ScalarType::Double;
+    case torch::_export::ScalarType::COMPLEXHALF:
+      return c10::ScalarType::ComplexHalf;
+    case torch::_export::ScalarType::COMPLEXFLOAT:
+      return c10::ScalarType::ComplexFloat;
+    case torch::_export::ScalarType::COMPLEXDOUBLE:
+      return c10::ScalarType::ComplexDouble;
+    case torch::_export::ScalarType::BOOL:
+      return c10::ScalarType::Bool;
+    case torch::_export::ScalarType::BFLOAT16:
+      return c10::ScalarType::BFloat16;
+    case torch::_export::ScalarType::UINT16:
+      return c10::ScalarType::UInt16;
+    case torch::_export::ScalarType::FLOAT8E4M3FN:
+      return c10::ScalarType::Float8_e4m3fn;
+    case torch::_export::ScalarType::FLOAT8E5M2:
+      return c10::ScalarType::Float8_e5m2;
+    default:
+      TORCH_CHECK(false, "unknown scalar type ", static_cast<int>(scalarType));
+  }
+}
+
+c10::MemoryFormat convertJsonMemoryFormat(
+    const torch::_export::MemoryFormat& memoryFormat) {
+  switch (memoryFormat) {
+    case torch::_export::MemoryFormat::Unknown:
+      TORCH_CHECK(false, "memory format is not properly set");
+    case torch::_export::MemoryFormat::ContiguousFormat:
+      return c10::MemoryFormat::Contiguous;
+    case torch::_export::MemoryFormat::ChannelsLast:
+      return c10::MemoryFormat::ChannelsLast;
+    case torch::_export::MemoryFormat::ChannelsLast3d:
+      return c10::MemoryFormat::ChannelsLast3d;
+    case torch::_export::MemoryFormat::PreserveFormat:
+      return c10::MemoryFormat::Preserve;
+    default:
+      TORCH_CHECK(
+          false, "unknown memory format ", static_cast<int>(memoryFormat));
+  }
+}
+
+c10::Layout convertJsonLayout(const torch::_export::Layout& layout) {
+  switch (layout) {
+    case torch::_export::Layout::Unknown:
+      TORCH_CHECK(false, "layout is not properly set");
+    case torch::_export::Layout::SparseCoo:
+      // TODO: is this the right translation?
+      return c10::Layout::Sparse;
+    case torch::_export::Layout::SparseCsr:
+      return c10::Layout::SparseCsr;
+    case torch::_export::Layout::SparseCsc:
+      return c10::Layout::SparseCsc;
+    case torch::_export::Layout::SparseBsr:
+      return c10::Layout::SparseBsr;
+    case torch::_export::Layout::SparseBsc:
+      return c10::Layout::SparseBsc;
+    case torch::_export::Layout::_mkldnn:
+      return c10::Layout::Mkldnn;
+    case torch::_export::Layout::Strided:
+      return c10::Layout::Strided;
+    default:
+      TORCH_CHECK(false, "unknown layout ", static_cast<int>(layout));
+  }
+}
+
+c10::Device convertJsonDevice(const torch::_export::Device& device) {
+  c10::Device d(device.get_type());
+  if (auto index = device.get_index()) {
+    d.set_index(*index);
+  }
+  return d;
+}
+
+TensorMeta::TensorMeta(const torch::_export::TensorMeta& tensorMeta)
+    : device_(convertJsonDevice(tensorMeta.get_device())) {
+  dtype_ = convertJsonScalarType(tensorMeta.get_dtype());
+  layout_ = convertJsonLayout(tensorMeta.get_layout());
+  requiresGrad_ = tensorMeta.get_requires_grad();
+
+  if (tensorMeta.get_storage_offset().tag() ==
+      torch::_export::SymInt::Tag::AS_INT) {
+    storage_offset_ = tensorMeta.get_storage_offset().get_as_int();
+  } else {
+    CHECK(false) << "SymInt not supported yet";
+  }
+
+  for (const auto& size : tensorMeta.get_sizes()) {
+    if (size.tag() == torch::_export::SymInt::Tag::AS_INT) {
+      int64_t val = size.get_as_int();
+      sizes_.emplace_back(val);
+      numel_ *= val;
+    } else if (size.tag() == torch::_export::SymInt::Tag::AS_EXPR) {
+      // TODO: it's still unclear how SymInt shapes should be used at runtime.
+      // One potential use case is verifying that input shapes match the
+      // serialized constraints, which would require unpacking them (NYI).
+      //
+      // For the time being, we just set each symbolic dim to -1.
+      hasSymbolicShape_ = true;
+      sizes_.emplace_back(-1);
+      numel_ = -1;
+    }
+  }
+
+  for (const auto& stride : tensorMeta.get_strides()) {
+    if (stride.tag() == torch::_export::SymInt::Tag::AS_INT) {
+      strides_.emplace_back(stride.get_as_int());
+    } else if (stride.tag() == torch::_export::SymInt::Tag::AS_EXPR) {
+      // TODO: it's still unclear how SymInt strides should be used at runtime.
+      // Setting symbolic strides to -1 for now.
+      hasSymbolicShape_ = true;
+      strides_.emplace_back(-1);
+    }
+  }
+}
+
+} // namespace torch::nativert

torch/nativert/graph/TensorMeta.h

Lines changed: 92 additions & 0 deletions (new file)

@@ -0,0 +1,92 @@
+#pragma once
+
+#include <c10/core/Device.h>
+#include <c10/util/Logging.h>
+
+#include <c10/core/Layout.h>
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/ScalarType.h>
+#include <c10/core/TensorOptions.h>
+#include <c10/util/ArrayRef.h>
+
+#include <torch/csrc/utils/generated_serialization_types.h>
+
+namespace torch::nativert {
+
+c10::ScalarType convertJsonScalarType(
+    const torch::_export::ScalarType& scalarType);
+c10::MemoryFormat convertJsonMemoryFormat(
+    const torch::_export::MemoryFormat& memoryFormat);
+c10::Layout convertJsonLayout(const torch::_export::Layout& layout);
+c10::Device convertJsonDevice(const torch::_export::Device& device);
+
+class TensorMeta {
+ public:
+  explicit TensorMeta(const torch::_export::TensorMeta& tensorMeta);
+
+  c10::IntArrayRef sizes() const {
+    CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape";
+    return sizes_;
+  }
+
+  c10::IntArrayRef strides() const {
+    CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape";
+    return strides_;
+  }
+
+  c10::Layout layout() const {
+    return layout_;
+  }
+
+  c10::ScalarType dtype() const {
+    return dtype_;
+  }
+
+  bool requires_grad() const {
+    return requiresGrad_;
+  }
+
+  int64_t storage_offset() const {
+    return storage_offset_;
+  }
+
+  int64_t dim() const {
+    return sizes_.size();
+  }
+
+  int64_t numel() const {
+    CHECK(!hasSymbolicShape_) << "TensorMeta has symbolic shape";
+    return numel_;
+  }
+
+  c10::Device device() const {
+    return device_;
+  }
+
+  c10::TensorOptions asTensorOptions() const {
+    return c10::TensorOptions().dtype(dtype_).layout(layout_).requires_grad(
+        requiresGrad_);
+  }
+
+  // NYI
+  // c10::SymIntArrayRef sym_sizes() const {}
+  // c10::SymIntArrayRef sym_strides() const {}
+  // c10::SymInt sym_storage_offset() const {}
+  // c10::SymInt sym_numel() const {}
+
+ private:
+  bool hasSymbolicShape_ = false;
+
+  std::vector<int64_t> sizes_;
+  std::vector<int64_t> strides_;
+  int64_t storage_offset_ = 0;
+  int64_t numel_ = 1;
+
+  c10::ScalarType dtype_;
+  c10::Layout layout_;
+  bool requiresGrad_;
+
+  c10::Device device_;
+};
+
+} // namespace torch::nativert
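Taken together, a consumer of TensorMeta could allocate a tensor that matches the serialized metadata. A hedged sketch, assuming a torch::_export::TensorMeta already deserialized from an exported model (the helper name allocateLike is hypothetical, not part of this diff):

#include <ATen/ATen.h>
#include <torch/nativert/graph/TensorMeta.h>

// Sketch only: allocate an uninitialized tensor whose metadata matches
// `jsonMeta`. sizes() CHECK-fails if any dimension was symbolic (AS_EXPR),
// so this only works for fully static shapes.
at::Tensor allocateLike(const torch::_export::TensorMeta& jsonMeta) {
  torch::nativert::TensorMeta meta(jsonMeta);
  return at::empty(
      meta.sizes(), meta.asTensorOptions().device(meta.device()));
}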
