@@ -181,8 +181,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
181
181
return true ;
182
182
}
183
183
184
- void OrtBackend::CopyToCpu (const Ort::Value& value, FDTensor* tensor,
185
- const std::string& name) {
184
+ void OrtBackend::OrtValueToFDTensor (const Ort::Value& value, FDTensor* tensor,
185
+ const std::string& name, bool copy_to_fd ) {
186
186
const auto info = value.GetTensorTypeAndShapeInfo ();
187
187
const auto data_type = info.GetElementType ();
188
188
size_t numel = info.GetElementCount ();
@@ -210,12 +210,21 @@ void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
210
210
" Unrecognized data type of %d while calling OrtBackend::CopyToCpu()." ,
211
211
data_type);
212
212
}
213
- tensor->Resize (shape, dtype, name);
214
- memcpy (tensor->MutableData (), value.GetTensorData <void *>(), numel);
213
+ const void * value_ptr = value.GetTensorData <void *>();
214
+ if (copy_to_fd) {
215
+ tensor->Resize (shape, dtype, name);
216
+ memcpy (tensor->MutableData (), value_ptr, numel);
217
+ } else {
218
+ tensor->name = name;
219
+ tensor->SetExternalData (
220
+ shape, dtype,
221
+ const_cast <void *>(value_ptr), Device::CPU);
222
+ }
215
223
}
216
224
217
225
bool OrtBackend::Infer (std::vector<FDTensor>& inputs,
218
- std::vector<FDTensor>* outputs) {
226
+ std::vector<FDTensor>* outputs,
227
+ bool copy_to_fd) {
219
228
if (inputs.size () != inputs_desc_.size ()) {
220
229
FDERROR << " [OrtBackend] Size of the inputs(" << inputs.size ()
221
230
<< " ) should keep same with the inputs of this model("
@@ -243,11 +252,12 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
243
252
return false ;
244
253
}
245
254
246
- // Copy result after inference
255
+ // Convert result after inference
247
256
std::vector<Ort::Value> ort_outputs = binding_->GetOutputValues ();
248
257
outputs->resize (ort_outputs.size ());
249
258
for (size_t i = 0 ; i < ort_outputs.size (); ++i) {
250
- CopyToCpu (ort_outputs[i], &((*outputs)[i]), outputs_desc_[i].name );
259
+ OrtValueToFDTensor (ort_outputs[i], &((*outputs)[i]),
260
+ outputs_desc_[i].name , copy_to_fd);
251
261
}
252
262
253
263
return true ;
0 commit comments