Skip to content

Commit 55ab13e

Browse files
mingyueliuh and mingyue authored
[VitisAI] support memory buffer contains the TensorProto external data (microsoft#22042)
### Description

Extend the VitisAI EP `tensor_proto_as_raw` API to support a memory buffer that contains the TensorProto external data.

### Motivation and Context

To reduce peak memory usage, the VitisAI EP needs to support ORT format models and the session option `session.use_ort_model_bytes_for_initializers`, which enables using the model bytes directly for initializers.

Co-authored-by: mingyue <mingyue@xilinx.com>
1 parent 5c36110 commit 55ab13e

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

onnxruntime/core/providers/vitisai/imp/tensor_proto.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,44 @@
99
#include "core/providers/shared_library/provider_api.h"
1010
namespace vaip {
1111
using namespace onnxruntime;
12+
13+
// Returns a view over the bytes of `tensor` when its external data refers to
// a buffer that is already in memory, and an empty span otherwise.
//
// A tensor qualifies when its data location is EXTERNAL and its "location"
// entry equals the tag "*/_ORT_MEM_ADDR_/*". In that case the "offset" entry
// is interpreted as the address of the buffer and the "length" entry as its
// size in bytes. Both are parsed as base-10 integers.
//
// An empty span is returned when the tensor is not external, when the
// location is not the in-memory tag, or when the parsed address or length is
// zero (e.g. the entry is missing or not a number). The caller treats an
// empty span as "not handled here".
//
// The returned span does not own the memory it points to.
static gsl::span<const char> process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) {
  // Only the mutable_* accessors are used below, hence the const_cast.
  // Nothing in the tensor is modified.
  auto* tensor_proto = const_cast<ONNX_NAMESPACE::TensorProto*>(&tensor);
  if (tensor_proto->data_location() != ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL) {
    return {};
  }

  bool is_memory_address = false;
  uintptr_t address = 0;
  size_t length = 0;

  auto* external_data = tensor_proto->mutable_external_data();
  const auto entry_count = external_data->size();
  for (auto i = 0; i < entry_count; ++i) {
    auto& entry = external_data->at(i);
    const auto& key = *entry.mutable_key();
    const auto& value = *entry.mutable_value();
    if (key == "location") {
      is_memory_address = (value == "*/_ORT_MEM_ADDR_/*");
    } else if (key == "offset") {
      address = static_cast<uintptr_t>(std::strtoull(value.data(), nullptr, 10));
    } else if (key == "length") {
      length = static_cast<size_t>(std::strtoull(value.data(), nullptr, 10));
    }
    // Any other key (such as "checksum") is not needed to locate the buffer.
  }

  // A null address with a non-zero length would violate the span's
  // precondition, and a zero length is an empty view anyway; report both as
  // "no in-memory buffer" so the caller can take its regular path.
  if (!is_memory_address || address == 0 || length == 0) {
    return {};
  }
  return {reinterpret_cast<const char*>(address), length};
}
41+
1242
gsl::span<const char> tensor_proto_as_raw(const onnxruntime::Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor) {
1343
auto& mut_tensor = const_cast<ONNX_NAMESPACE::TensorProto&>(tensor);
1444
if (!tensor.has_raw_data()) {
45+
auto maybe_external_memory_address = process_ext_address(tensor);
46+
if (!maybe_external_memory_address.empty()) {
47+
return maybe_external_memory_address;
48+
}
49+
1550
std::vector<uint8_t> unpacked_tensor;
1651
auto path = graph.ModelPath();
1752
auto s = onnxruntime::utils::UnpackInitializerData(tensor, path, unpacked_tensor);

0 commit comments

Comments
 (0)