2 changes: 2 additions & 0 deletions backends/aoti/aoti_model_container.h
@@ -77,6 +77,8 @@ struct AOTIDelegateHandle {
void* so_handle;
std::string so_path;
AOTInductorModelContainerHandle container_handle;
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
// dependency
};

} // namespace aoti
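
Storing the stream as a type-erased void* keeps this shared AOTI header free of any CUDA include; only CUDA-aware translation units (such as cuda_backend.cpp below) cast the pointer back to its real type. A minimal sketch of that round trip, with made-up helper names that are not part of this change:

// Illustrative only; mirrors the casts used in cuda_backend.cpp below.
#include <cuda_runtime.h>

struct HandleLike {
  void* cuda_stream;  // type-erased cudaStream_t, as in AOTIDelegateHandle
};

cudaError_t attach_stream(HandleLike& h) {
  cudaStream_t stream = nullptr;
  cudaError_t err = cudaStreamCreate(&stream);   // typed on the CUDA side
  if (err == cudaSuccess) {
    h.cuda_stream = static_cast<void*>(stream);  // erased for the shared header
  }
  return err;
}

cudaError_t detach_stream(HandleLike& h) {
  auto stream = static_cast<cudaStream_t>(h.cuda_stream);  // recover the real type
  cudaError_t err = cudaStreamDestroy(stream);
  h.cuda_stream = nullptr;
  return err;
}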
22 changes: 20 additions & 2 deletions backends/cuda/runtime/cuda_backend.cpp
@@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cuda_runtime.h>
#include <dlfcn.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
@@ -16,14 +17,14 @@

#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Include our shim layer headers
#include <executorch/backends/aoti/aoti_model_container.h>
#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/utils.h>

namespace executorch {
namespace backends {
@@ -182,6 +183,12 @@ class ET_EXPERIMENTAL CudaBackend final
handle->so_handle = so_handle;
handle->so_path = so_path.string();
handle->container_handle = container_handle;

// Create a CUDA stream for asynchronous execution
cudaStream_t cuda_stream;
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
handle->cuda_stream = static_cast<void*>(cuda_stream);

return (DelegateHandle*)handle; // Return the handle post-processing
}

@@ -288,7 +295,7 @@ class ET_EXPERIMENTAL CudaBackend final
n_inputs,
gpu_outputs.data(), // Use GPU output tensors
n_outputs,
nullptr, // Pass the actual CUDA stream!
handle->cuda_stream, // Pass the actual CUDA stream
nullptr); // proxy_executor_handle can remain nullptr

if (error != Error::Ok) {
@@ -334,6 +341,17 @@ class ET_EXPERIMENTAL CudaBackend final
}
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;

// Destroy the CUDA stream if it exists
if (handle->cuda_stream != nullptr) {
cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
ET_CHECK_OR_LOG_ERROR(
stream_err == cudaSuccess,
"Failed to destroy CUDA stream: %s",
cudaGetErrorString(stream_err));
handle->cuda_stream = nullptr;
}

// Delete the container BEFORE closing the shared library
if (handle->container_handle != nullptr) {
AOTIRuntimeError delete_result =
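
The execute() change above swaps the hard-coded nullptr for the handle's stream, so work the AOTInductor container submits is tied to a stream the delegate owns and can synchronize on, instead of implicitly using the default stream (a null stream handle). As a self-contained illustration of what owning a stream buys, using only generic CUDA runtime calls that are not part of this PR's API:

// Generic CUDA illustration, unrelated to the AOTI container API.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaStream_t stream = nullptr;
  if (cudaStreamCreate(&stream) != cudaSuccess) {
    return 1;
  }

  void* buf = nullptr;
  cudaMalloc(&buf, 1024);
  // Work enqueued on `stream` runs asynchronously with respect to the host;
  // the owner of the stream decides when to wait for it to finish.
  cudaMemsetAsync(buf, 0, 1024, stream);
  cudaStreamSynchronize(stream);

  cudaFree(buf);
  cudaStreamDestroy(stream);
  printf("done\n");
  return 0;
}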
22 changes: 22 additions & 0 deletions runtime/platform/log.h
@@ -181,6 +181,20 @@ using ::executorch::runtime::LogLevel;
##__VA_ARGS__); \
} \
} while (0)

/**
* Check a condition and log an error message if the condition is false.
*
* @param[in] _condition The condition to check.
* @param[in] _format Log message format string.
*/
#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) \
do { \
if (!(_condition)) { \
ET_LOG(Error, _format, ##__VA_ARGS__); \
} \
} while (0)

#else // ET_LOG_ENABLED

/**
@@ -191,4 +205,12 @@ using ::executorch::runtime::LogLevel;
*/
#define ET_LOG(_level, _format, ...) ((void)0)

/**
* Check a condition and log an error message if the condition is false.
*
* @param[in] _condition The condition to check.
* @param[in] _format Log message format string.
*/
#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) ((void)0)

#endif // ET_LOG_ENABLED
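
For context, a minimal usage sketch of the new macro (the function and status value below are made up for the example): unlike an assertion that aborts on failure, a false condition here only emits an error log and lets execution continue, which is what the CUDA stream teardown above relies on during destroy().

// Illustrative sketch; names are hypothetical.
#include <executorch/runtime/platform/log.h>

void release_resource(int status) {
  // Logs an error when ET_LOG_ENABLED is set, but never aborts or returns early,
  // so the surrounding cleanup path keeps running. With logging disabled the
  // macro compiles to ((void)0) and the condition is not evaluated at all.
  ET_CHECK_OR_LOG_ERROR(status == 0, "release failed with status %d", status);
}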