|
1 |
| -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. |
| 1 | +// Copyright (c) 2021-2024, NVIDIA CORPORATION. All rights reserved. |
2 | 2 | //
|
3 | 3 | // Redistribution and use in source and binary forms, with or without
|
4 | 4 | // modification, are permitted provided that the following conditions
|
@@ -847,9 +847,15 @@ TRITONBACKEND_ModelInstanceExecute(
|
847 | 847 | if (input_memory_type == TRITONSERVER_MEMORY_GPU) {
|
848 | 848 | ipbuffer_vec.resize(input_element_cnt);
|
849 | 849 | ipbuffer_int = ipbuffer_vec.data();
|
850 |
| - cudaMemcpy( |
851 |
| - const_cast<int32_t*>(ipbuffer_int), input_buffer, input_byte_size, |
852 |
| - cudaMemcpyDeviceToHost); |
| 850 | + LOG_IF_CUDA_ERROR( |
| 851 | + cudaMemcpyAsync( |
| 852 | + const_cast<int32_t*>(ipbuffer_int), input_buffer, input_byte_size, |
| 853 | + cudaMemcpyDeviceToHost, instance_state->CudaStream()), |
| 854 | + "failed to copy buffer from Device to Host"); |
| 855 | + |
| 856 | + LOG_IF_CUDA_ERROR( |
| 857 | + cudaStreamSynchronize(instance_state->CudaStream()), |
| 858 | + "failed to perform synchronization on cuda stream"); |
853 | 859 | } else {
|
854 | 860 | ipbuffer_int = reinterpret_cast<const int32_t*>(input_buffer);
|
855 | 861 | }
|
@@ -939,9 +945,15 @@ TRITONBACKEND_ModelInstanceExecute(
|
939 | 945 | }
|
940 | 946 |
|
941 | 947 | if (output_memory_type == TRITONSERVER_MEMORY_GPU) {
|
942 |
| - cudaMemcpy( |
943 |
| - output_buffer, const_cast<int32_t*>(obuffer_int), |
944 |
| - buffer_byte_size, cudaMemcpyHostToDevice); |
| 948 | + LOG_IF_CUDA_ERROR( |
| 949 | + cudaMemcpyAsync( |
| 950 | + output_buffer, const_cast<int32_t*>(obuffer_int), |
| 951 | + buffer_byte_size, cudaMemcpyHostToDevice, |
| 952 | + instance_state->CudaStream()), |
| 953 | + "failed to copy buffer from Device to Host"); |
| 954 | + LOG_IF_CUDA_ERROR( |
| 955 | + cudaStreamSynchronize(instance_state->CudaStream()), |
| 956 | + "failed to perform synchronization on cuda stream"); |
945 | 957 | }
|
946 | 958 | }
|
947 | 959 | }
|
|
0 commit comments