@@ -44,6 +44,28 @@ bool is_aligned(const void* data) {
44
44
return addr % kMinimumAlignment == 0 ;
45
45
}
46
46
47
+ Result<const flat_tensor_flatbuffer::NamedData*> get_named_data (
48
+ const char * key,
49
+ const flatbuffers::Vector<
50
+ flatbuffers::Offset<flat_tensor_flatbuffer::NamedData>>* named_data) {
51
+ // Linear search by name.
52
+ if (named_data == nullptr ) {
53
+ return Error::NotFound;
54
+ }
55
+ for (int i = 0 ; i < named_data->size (); i++) {
56
+ if (std::strcmp (named_data->Get (i)->key ()->c_str (), key) == 0 ) {
57
+ const auto * metadata = named_data->Get (i);
58
+ ET_CHECK_OR_RETURN_ERROR (
59
+ metadata->segment_index () >= 0 ,
60
+ InvalidExternalData,
61
+ " Invalid segment_index %d; malformed PTD file." ,
62
+ metadata->segment_index ());
63
+ return metadata;
64
+ }
65
+ }
66
+ return Error::NotFound;
67
+ }
68
+
47
69
Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata (
48
70
const char * key,
49
71
const flatbuffers::Vector<
@@ -109,6 +131,39 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
109
131
110
132
ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data (
111
133
const char * key) const {
134
+ // TODO(lfq): consolidate named_data and tensors.
135
+ // Check named data.
136
+ Result<const flat_tensor_flatbuffer::NamedData*> named_data =
137
+ get_named_data (key, flat_tensor_->named_data ());
138
+ if (named_data.ok ()) {
139
+ size_t segment_index = named_data.get ()->segment_index ();
140
+ ET_CHECK_OR_RETURN_ERROR (
141
+ segment_index < flat_tensor_->segments ()->size (),
142
+ InvalidExternalData,
143
+ " Invalid segment_index %zu; malformed PTD file." ,
144
+ segment_index);
145
+
146
+ size_t segment_offset =
147
+ flat_tensor_->segments ()->Get (segment_index)->offset ();
148
+ size_t segment_size = flat_tensor_->segments ()->Get (segment_index)->size ();
149
+ ET_CHECK_OR_RETURN_ERROR (
150
+ segment_offset <
151
+ header_.segment_base_offset + header_.segment_data_size ,
152
+ InvalidExternalData,
153
+ " Invalid segment offset %zu is larger than the segment_base_offset + segment_data_size %" PRIu64
154
+ " ; malformed PTD file." ,
155
+ segment_offset,
156
+ header_.segment_base_offset + header_.segment_data_size );
157
+ return loader_->load (
158
+ /* offset=*/ header_.segment_base_offset + segment_offset,
159
+ segment_size,
160
+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::External));
161
+ }
162
+ if (named_data.error () != Error::NotFound) {
163
+ return named_data.error ();
164
+ }
165
+
166
+ // Check tensors, if named data is not found.
112
167
Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
113
168
get_flat_tensor_metadata (key, flat_tensor_->tensors ());
114
169
if (!metadata.ok ()) {
@@ -179,16 +234,34 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into(
179
234
}
180
235
181
236
ET_NODISCARD Result<size_t > FlatTensorDataMap::get_num_keys () const {
182
- return flat_tensor_->tensors ()->size ();
237
+ // TODO(lfq): consolidate named_data and tensors.
238
+ if (flat_tensor_->named_data () == nullptr ) {
239
+ return flat_tensor_->tensors ()->size ();
240
+ }
241
+ return flat_tensor_->named_data ()->size () + flat_tensor_->tensors ()->size ();
183
242
}
184
243
185
244
ET_NODISCARD Result<const char *> FlatTensorDataMap::get_key (
186
245
size_t index) const {
187
- if (index < 0 || index >= flat_tensor_->tensors ()->size ()) {
188
- return Error::InvalidArgument;
189
- }
246
+ // TODO(lfq): consolidate named_data and tensors.
247
+ // For now, iterate over named_data and then flat_tensor.
248
+ size_t num_keys = get_num_keys ().get ();
249
+ ET_CHECK_OR_RETURN_ERROR (
250
+ index >= 0 && index < num_keys,
251
+ InvalidArgument,
252
+ " Index %zu out of range of size %zu" ,
253
+ index ,
254
+ num_keys);
190
255
191
- return flat_tensor_->tensors ()->Get (index )->fully_qualified_name ()->c_str ();
256
+ if (flat_tensor_->named_data () != nullptr &&
257
+ index < flat_tensor_->named_data ()->size ()) {
258
+ return flat_tensor_->named_data ()->Get (index )->key ()->c_str ();
259
+ } else {
260
+ if (flat_tensor_->named_data () != nullptr ) {
261
+ index = index - flat_tensor_->named_data ()->size ();
262
+ }
263
+ return flat_tensor_->tensors ()->Get (index )->fully_qualified_name ()->c_str ();
264
+ }
192
265
}
193
266
194
267
/* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load (
0 commit comments