|
9 | 9 | #include "core/providers/shared_library/provider_api.h"
|
10 | 10 | namespace vaip {
|
11 | 11 | using namespace onnxruntime;
|
| 12 | + |
| 13 | +static gsl::span<const char> process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) { |
| 14 | + auto tensor_proto = const_cast<ONNX_NAMESPACE::TensorProto*>(&tensor); |
| 15 | + auto file = std::string(); |
| 16 | + uintptr_t offset = 0; |
| 17 | + size_t size = 0; |
| 18 | + if (tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL) { |
| 19 | + auto external_data = tensor_proto->mutable_external_data(); |
| 20 | + auto external_data_size = external_data->size(); |
| 21 | + for (auto i = 0; i < external_data_size; ++i) { |
| 22 | + auto& data = external_data->at(i); |
| 23 | + char* end = nullptr; |
| 24 | + if (*data.mutable_key() == "location") { |
| 25 | + file = *data.mutable_value(); |
| 26 | + } else if (*data.mutable_key() == "offset") { |
| 27 | + offset = (uintptr_t)std::strtoull(data.mutable_value()->data(), &end, 10); |
| 28 | + } else if (*data.mutable_key() == "length") { |
| 29 | + size = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); |
| 30 | + } else if (*data.mutable_key() == "checksum") { |
| 31 | + // checksum = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); |
| 32 | + } |
| 33 | + } |
| 34 | + if (file == "*/_ORT_MEM_ADDR_/*") { |
| 35 | + auto addr = reinterpret_cast<const char*>(offset); |
| 36 | + return {addr, size}; |
| 37 | + } |
| 38 | + } |
| 39 | + return {}; |
| 40 | +} |
| 41 | + |
12 | 42 | gsl::span<const char> tensor_proto_as_raw(const onnxruntime::Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor) {
|
13 | 43 | auto& mut_tensor = const_cast<ONNX_NAMESPACE::TensorProto&>(tensor);
|
14 | 44 | if (!tensor.has_raw_data()) {
|
| 45 | + auto maybe_external_memory_address = process_ext_address(tensor); |
| 46 | + if (!maybe_external_memory_address.empty()) { |
| 47 | + return maybe_external_memory_address; |
| 48 | + } |
| 49 | + |
15 | 50 | std::vector<uint8_t> unpacked_tensor;
|
16 | 51 | auto path = graph.ModelPath();
|
17 | 52 | auto s = onnxruntime::utils::UnpackInitializerData(tensor, path, unpacked_tensor);
|
|
0 commit comments