From 2444ee0b8cbaebeb82d469be9548f18d044fbf9f Mon Sep 17 00:00:00 2001 From: lucylq Date: Mon, 28 Apr 2025 13:09:42 -0700 Subject: [PATCH 1/2] Support named_data in flat_tensor Currently flat_tensor ndm only accounts for tensors in get_data, get_num_keys, get_key functions. Add support to return named_data values as well. TODO:consolidate tensors and named_data into one structure in the flatbuffer. This will simplify all the serialization and runtime code. Differential Revision: [D73679683](https://our.internmc.facebook.com/intern/diff/D73679683/) [ghstack-poisoned] --- .../flat_tensor/flat_tensor_data_map.cpp | 82 +++++++++++++++++-- 1 file changed, 77 insertions(+), 5 deletions(-) diff --git a/extension/flat_tensor/flat_tensor_data_map.cpp b/extension/flat_tensor/flat_tensor_data_map.cpp index 8aa0af13928..5f8f8e9d6e3 100644 --- a/extension/flat_tensor/flat_tensor_data_map.cpp +++ b/extension/flat_tensor/flat_tensor_data_map.cpp @@ -44,6 +44,28 @@ bool is_aligned(const void* data) { return addr % kMinimumAlignment == 0; } +Result get_named_data( + const char* key, + const flatbuffers::Vector< + flatbuffers::Offset>* named_data) { + // Linear search by name. + if (named_data == nullptr) { + return Error::NotFound; + } + for (int i = 0; i < named_data->size(); i++) { + if (std::strcmp(named_data->Get(i)->key()->c_str(), key) == 0) { + const auto* metadata = named_data->Get(i); + ET_CHECK_OR_RETURN_ERROR( + metadata->segment_index() >= 0, + InvalidExternalData, + "Invalid segment_index %d; malformed PTD file.", + metadata->segment_index()); + return metadata; + } + } + return Error::NotFound; +} + Result get_flat_tensor_metadata( const char* key, const flatbuffers::Vector< @@ -109,6 +131,39 @@ ET_NODISCARD Result FlatTensorDataMap::get_metadata( ET_NODISCARD Result FlatTensorDataMap::get_data( const char* key) const { + // TODO(lfq): consolidate named_data and tensors. + // Check named data. + Result named_data = + get_named_data(key, flat_tensor_->named_data()); + if (named_data.ok()) { + size_t segment_index = named_data.get()->segment_index(); + ET_CHECK_OR_RETURN_ERROR( + segment_index < flat_tensor_->segments()->size(), + InvalidExternalData, + "Invalid segment_index %zu; malformed PTD file.", + segment_index); + + size_t segment_offset = + flat_tensor_->segments()->Get(segment_index)->offset(); + size_t segment_size = flat_tensor_->segments()->Get(segment_index)->size(); + ET_CHECK_OR_RETURN_ERROR( + segment_offset < + header_.segment_base_offset + header_.segment_data_size, + InvalidExternalData, + "Invalid segment offset %zu is larger than the segment_base_offset + segment_data_size %" PRIu64 + "; malformed PTD file.", + segment_offset, + header_.segment_base_offset + header_.segment_data_size); + return loader_->load( + /*offset=*/header_.segment_base_offset + segment_offset, + segment_size, + DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External)); + } + if (named_data.error() != Error::NotFound) { + return named_data.error(); + } + + // Check tensors, if named data is not found. Result metadata = get_flat_tensor_metadata(key, flat_tensor_->tensors()); if (!metadata.ok()) { @@ -179,16 +234,33 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into( } ET_NODISCARD Result FlatTensorDataMap::get_num_keys() const { - return flat_tensor_->tensors()->size(); + // TODO(lfq): consolidate named_data and tensors. + if (flat_tensor_->named_data() == nullptr) { + return flat_tensor_->tensors()->size(); + } + return flat_tensor_->named_data()->size() + flat_tensor_->tensors()->size(); } ET_NODISCARD Result FlatTensorDataMap::get_key( size_t index) const { - if (index < 0 || index >= flat_tensor_->tensors()->size()) { - return Error::InvalidArgument; - } + // TODO(lfq): consolidate named_data and tensors. + // For now, iterate over named_data and then flat_tensor. + size_t num_keys = get_num_keys().get(); + ET_CHECK_OR_RETURN_ERROR( + index >= 0 && index < num_keys, + InvalidArgument, + "Index %zu out of range of size %zu", + index, + num_keys); - return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str(); + if (flat_tensor_->named_data() != nullptr && index < flat_tensor_->named_data()->size()) { + return flat_tensor_->named_data()->Get(index)->key()->c_str(); + } else { + if (flat_tensor_->named_data() != nullptr) { + index = index - flat_tensor_->named_data()->size(); + } + return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str(); + } } /* static */ Result FlatTensorDataMap::load( From 794d0daabcea1b7caa903179e582ebbd188abc47 Mon Sep 17 00:00:00 2001 From: lucylq Date: Mon, 28 Apr 2025 16:40:47 -0700 Subject: [PATCH 2/2] Update on "Support named_data in flat_tensor" Currently flat_tensor ndm only accounts for tensors in get_data, get_num_keys, get_key functions. Add support to return named_data values as well. TODO:consolidate tensors and named_data into one structure in the flatbuffer. This will simplify all the serialization and runtime code. Differential Revision: [D73679683](https://our.internmc.facebook.com/intern/diff/D73679683/) [ghstack-poisoned] --- extension/flat_tensor/flat_tensor_data_map.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extension/flat_tensor/flat_tensor_data_map.cpp b/extension/flat_tensor/flat_tensor_data_map.cpp index 5f8f8e9d6e3..c5590cb61b1 100644 --- a/extension/flat_tensor/flat_tensor_data_map.cpp +++ b/extension/flat_tensor/flat_tensor_data_map.cpp @@ -253,7 +253,8 @@ ET_NODISCARD Result FlatTensorDataMap::get_key( index, num_keys); - if (flat_tensor_->named_data() != nullptr && index < flat_tensor_->named_data()->size()) { + if (flat_tensor_->named_data() != nullptr && + index < flat_tensor_->named_data()->size()) { return flat_tensor_->named_data()->Get(index)->key()->c_str(); } else { if (flat_tensor_->named_data() != nullptr) {