Skip to content

Commit 2fb26e8

Browse files
committed
Support named_data in flat_tensor
Pull Request resolved: #10527 Currently flat_tensor ndm only accounts for tensors in get_data, get_num_keys, get_key functions. Add support to return named_data values as well. TODO:consolidate tensors and named_data into one structure in the flatbuffer. This will simplify all the serialization and runtime code. ghstack-source-id: 280786211 Differential Revision: [D73679683](https://our.internmc.facebook.com/intern/diff/D73679683/)
1 parent c5dd476 commit 2fb26e8

File tree

1 file changed

+78
-5
lines changed

1 file changed

+78
-5
lines changed

extension/flat_tensor/flat_tensor_data_map.cpp

+78-5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,28 @@ bool is_aligned(const void* data) {
4444
return addr % kMinimumAlignment == 0;
4545
}
4646

47+
Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
48+
const char* key,
49+
const flatbuffers::Vector<
50+
flatbuffers::Offset<flat_tensor_flatbuffer::NamedData>>* named_data) {
51+
// Linear search by name.
52+
if (named_data == nullptr) {
53+
return Error::NotFound;
54+
}
55+
for (int i = 0; i < named_data->size(); i++) {
56+
if (std::strcmp(named_data->Get(i)->key()->c_str(), key) == 0) {
57+
const auto* metadata = named_data->Get(i);
58+
ET_CHECK_OR_RETURN_ERROR(
59+
metadata->segment_index() >= 0,
60+
InvalidExternalData,
61+
"Invalid segment_index %d; malformed PTD file.",
62+
metadata->segment_index());
63+
return metadata;
64+
}
65+
}
66+
return Error::NotFound;
67+
}
68+
4769
Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
4870
const char* key,
4971
const flatbuffers::Vector<
@@ -109,6 +131,39 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
109131

110132
ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
111133
const char* key) const {
134+
// TODO(lfq): consolidate named_data and tensors.
135+
// Check named data.
136+
Result<const flat_tensor_flatbuffer::NamedData*> named_data =
137+
get_named_data(key, flat_tensor_->named_data());
138+
if (named_data.ok()) {
139+
size_t segment_index = named_data.get()->segment_index();
140+
ET_CHECK_OR_RETURN_ERROR(
141+
segment_index < flat_tensor_->segments()->size(),
142+
InvalidExternalData,
143+
"Invalid segment_index %zu; malformed PTD file.",
144+
segment_index);
145+
146+
size_t segment_offset =
147+
flat_tensor_->segments()->Get(segment_index)->offset();
148+
size_t segment_size = flat_tensor_->segments()->Get(segment_index)->size();
149+
ET_CHECK_OR_RETURN_ERROR(
150+
segment_offset <
151+
header_.segment_base_offset + header_.segment_data_size,
152+
InvalidExternalData,
153+
"Invalid segment offset %zu is larger than the segment_base_offset + segment_data_size %" PRIu64
154+
"; malformed PTD file.",
155+
segment_offset,
156+
header_.segment_base_offset + header_.segment_data_size);
157+
return loader_->load(
158+
/*offset=*/header_.segment_base_offset + segment_offset,
159+
segment_size,
160+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
161+
}
162+
if (named_data.error() != Error::NotFound) {
163+
return named_data.error();
164+
}
165+
166+
// Check tensors, if named data is not found.
112167
Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
113168
get_flat_tensor_metadata(key, flat_tensor_->tensors());
114169
if (!metadata.ok()) {
@@ -179,16 +234,34 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into(
179234
}
180235

181236
ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
182-
return flat_tensor_->tensors()->size();
237+
// TODO(lfq): consolidate named_data and tensors.
238+
if (flat_tensor_->named_data() == nullptr) {
239+
return flat_tensor_->tensors()->size();
240+
}
241+
return flat_tensor_->named_data()->size() + flat_tensor_->tensors()->size();
183242
}
184243

185244
ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
186245
size_t index) const {
187-
if (index < 0 || index >= flat_tensor_->tensors()->size()) {
188-
return Error::InvalidArgument;
189-
}
246+
// TODO(lfq): consolidate named_data and tensors.
247+
// For now, iterate over named_data and then flat_tensor.
248+
size_t num_keys = get_num_keys().get();
249+
ET_CHECK_OR_RETURN_ERROR(
250+
index >= 0 && index < num_keys,
251+
InvalidArgument,
252+
"Index %zu out of range of size %zu",
253+
index,
254+
num_keys);
190255

191-
return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str();
256+
if (flat_tensor_->named_data() != nullptr &&
257+
index < flat_tensor_->named_data()->size()) {
258+
return flat_tensor_->named_data()->Get(index)->key()->c_str();
259+
} else {
260+
if (flat_tensor_->named_data() != nullptr) {
261+
index = index - flat_tensor_->named_data()->size();
262+
}
263+
return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str();
264+
}
192265
}
193266

194267
/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(

0 commit comments

Comments
 (0)