@@ -181,8 +181,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
181181 return true ;
182182}
183183
184- void OrtBackend::CopyToCpu (const Ort::Value& value, FDTensor* tensor,
185- const std::string& name) {
184+ void OrtBackend::OrtValueToFDTensor (const Ort::Value& value, FDTensor* tensor,
185+ const std::string& name, bool copy_to_fd ) {
186186 const auto info = value.GetTensorTypeAndShapeInfo ();
187187 const auto data_type = info.GetElementType ();
188188 size_t numel = info.GetElementCount ();
@@ -210,12 +210,21 @@ void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
210210 " Unrecognized data type of %d while calling OrtBackend::CopyToCpu()." ,
211211 data_type);
212212 }
213- tensor->Resize (shape, dtype, name);
214- memcpy (tensor->MutableData (), value.GetTensorData <void *>(), numel);
213+ const void * value_ptr = value.GetTensorData <void *>();
214+ if (copy_to_fd) {
215+ tensor->Resize (shape, dtype, name);
216+ memcpy (tensor->MutableData (), value_ptr, numel);
217+ } else {
218+ tensor->name = name;
219+ tensor->SetExternalData (
220+ shape, dtype,
221+ const_cast <void *>(value_ptr), Device::CPU);
222+ }
215223}
216224
217225bool OrtBackend::Infer (std::vector<FDTensor>& inputs,
218- std::vector<FDTensor>* outputs) {
226+ std::vector<FDTensor>* outputs,
227+ bool copy_to_fd) {
219228 if (inputs.size () != inputs_desc_.size ()) {
220229 FDERROR << " [OrtBackend] Size of the inputs(" << inputs.size ()
221230 << " ) should keep same with the inputs of this model("
@@ -243,11 +252,12 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
243252 return false ;
244253 }
245254
246- // Copy result after inference
255+ // Convert result after inference
247256 std::vector<Ort::Value> ort_outputs = binding_->GetOutputValues ();
248257 outputs->resize (ort_outputs.size ());
249258 for (size_t i = 0 ; i < ort_outputs.size (); ++i) {
250- CopyToCpu (ort_outputs[i], &((*outputs)[i]), outputs_desc_[i].name );
259+ OrtValueToFDTensor (ort_outputs[i], &((*outputs)[i]),
260+ outputs_desc_[i].name , copy_to_fd);
251261 }
252262
253263 return true ;
0 commit comments