Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ find_package_torch()

# CUDA-specific AOTI functionality
set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp
runtime/shims/tensor_attribute.cpp
runtime/shims/tensor_attribute.cpp runtime/guard.cpp
)
add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
target_include_directories(
Expand Down
2 changes: 2 additions & 0 deletions backends/cuda/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ oncall("executorch")
runtime.cxx_library(
name = "runtime_shims",
srcs = [
"guard.cpp",
"shims/memory.cpp",
"shims/tensor_attribute.cpp",
],
headers = [
"guard.h",
"shims/memory.h",
"shims/tensor_attribute.h",
"utils.h",
Expand Down
151 changes: 151 additions & 0 deletions backends/cuda/runtime/guard.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cuda/runtime/guard.h>
#include <executorch/runtime/platform/log.h>

namespace executorch {
namespace backends {
namespace cuda {

namespace {
// Thread-local stream storage (private to this file)
thread_local std::unordered_map<DeviceIndex, cudaStream_t> current_streams_;
} // namespace

Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index) {
if (device_index == -1) {
// Get current device if not specified
int current_device;
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&current_device));
device_index = current_device;
}

current_streams_[device_index] = stream;
return Error::Ok;
}

Result<cudaStream_t> getCurrentCUDAStream(DeviceIndex device_index) {
if (device_index == -1) {
int current_device;
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&current_device));
device_index = current_device;
}

auto it = current_streams_.find(device_index);
if (it != current_streams_.end()) {
return it->second;
}

cudaStream_t stream;
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&stream));
setCurrentCUDAStream(stream, device_index);
return stream;
}

CUDAGuard::CUDAGuard(CUDAGuard&& other) noexcept
: original_device_index_(other.original_device_index_),
current_device_index_(other.current_device_index_) {
// Mark the moved-from object as "already restored" so its destructor doesn't
// try to restore the device
other.original_device_index_ = other.current_device_index_;
}

CUDAGuard::~CUDAGuard() {
if (original_device_index_ != current_device_index_) {
cudaError_t err = cudaSetDevice(original_device_index_);
if (err != cudaSuccess) {
ET_LOG(
Error,
"~CUDAGuard: Failed to restore device to %d: %s",
original_device_index_,
cudaGetErrorString(err));
}
}
}

Error CUDAGuard::set_index(DeviceIndex device_index) {
int orig_index = -1;
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&orig_index));

original_device_index_ = orig_index;
current_device_index_ = device_index;

if (current_device_index_ != original_device_index_) {
ET_CUDA_CHECK_OR_RETURN_ERROR(cudaSetDevice(current_device_index_));
}

return Error::Ok;
}

Result<CUDAGuard> CUDAGuard::create(DeviceIndex device_index) {
CUDAGuard guard; // Fixed: Removed () to create a variable, not a function
ET_CHECK_OK_OR_RETURN_ERROR(guard.set_index(device_index));
return guard;
}

CUDAStreamGuard::CUDAStreamGuard(CUDAStreamGuard&& other) noexcept
: device_guard_(std::move(other.device_guard_)),
original_stream_(other.original_stream_),
current_stream_(other.current_stream_),
device_index_(other.device_index_) {
// Mark the moved-from object as "already restored" so its destructor doesn't
// try to restore the stream
other.original_stream_ = other.current_stream_;
}

CUDAStreamGuard::~CUDAStreamGuard() {
// Restore the original stream unless this object was moved-from.
// After a move, original_stream_ == current_stream_, which indicates
// the moved-from object should not restore.
// Note: nullptr is a valid stream value (represents the default stream),
// so we must restore even if original_stream_ is nullptr.
if (original_stream_ != current_stream_) {
Error err = setCurrentCUDAStream(original_stream_, device_index_);
if (err != Error::Ok) {
ET_LOG(
Error,
"~CUDAStreamGuard: Failed to restore stream for device %d",
device_index_);
}
}
}

Error CUDAStreamGuard::set_stream(
cudaStream_t stream,
DeviceIndex device_index) {
auto result = getCurrentCUDAStream(device_index);
if (!result.ok()) {
ET_LOG(Error, "Failed to get current stream for device %d", device_index);
return result.error();
}

original_stream_ = result.get();
current_stream_ = stream;
device_index_ = device_index;

ET_CHECK_OK_OR_RETURN_ERROR(setCurrentCUDAStream(stream, device_index));

return Error::Ok;
}

Result<CUDAStreamGuard> CUDAStreamGuard::create(
cudaStream_t stream,
DeviceIndex device_index) {
auto guard_result = CUDAGuard::create(device_index);
ET_CHECK_OK_OR_RETURN_ERROR(guard_result.error());

CUDAStreamGuard stream_guard(std::move(guard_result.get()));
ET_CHECK_OK_OR_RETURN_ERROR(stream_guard.set_stream(stream, device_index));

return stream_guard;
}

} // namespace cuda
} // namespace backends
} // namespace executorch
195 changes: 195 additions & 0 deletions backends/cuda/runtime/guard.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cuda_runtime.h>
#include <executorch/backends/cuda/runtime/utils.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/result.h>
#include <cstdint>

namespace executorch {
namespace backends {
namespace cuda {

using executorch::runtime::Error;
using executorch::runtime::Result;

// Type alias for device index
using DeviceIndex = int32_t;

/**
* Set the current CUDA stream for the specified device.
*
* @param stream The CUDA stream to set as current
* @param device_index The device index (-1 to use current device)
* @return Error code indicating success or failure
*/
Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index = -1);

/**
* Get the current CUDA stream for the specified device.
* If no stream has been set, creates a new stream and sets it as current.
*
* @param device_index The device index (-1 to use current device)
* @return Result containing the current stream on success, or an error code on
* failure
*/
Result<cudaStream_t> getCurrentCUDAStream(DeviceIndex device_index = -1);

/**
* RAII guard that sets the current CUDA device and restores it on destruction.
* This ensures that the device is properly restored even if an exception
* occurs.
*
*/
class CUDAGuard {
private:
/**
* Private constructor - use create() factory method instead.
*/
explicit CUDAGuard()
: original_device_index_(-1), current_device_index_(-1) {}

public:
/**
* Factory method to create a CUDAGuard.
*
* @param device_index The device index to set as current
* @return Result containing the guard on success, or an error code on failure
*/
static Result<CUDAGuard> create(DeviceIndex device_index);

// Copy is not allowed
CUDAGuard(const CUDAGuard&) = delete;
CUDAGuard& operator=(const CUDAGuard&) = delete;

// Move constructor and assignment
CUDAGuard(CUDAGuard&& other) noexcept;
CUDAGuard& operator=(CUDAGuard&& other) = delete;

/**
* Destructor that restores the original device if necessary.
*/
~CUDAGuard();

/**
* Sets the CUDA device to the given device index.
*
* @param device_index The device index to set as current
* @return Error code indicating success or failure
*/
Error set_index(DeviceIndex device_index);

/**
* Get the original device index before the guard was created.
*
* @return The original device index
*/
DeviceIndex original_device() const {
return original_device_index_;
}

/**
* Get the current device index.
*
* @return The current device index
*/
DeviceIndex current_device() const {
return current_device_index_;
}

private:
/// The original device before this guard was created
DeviceIndex original_device_index_;
/// The current device managed by this guard
DeviceIndex current_device_index_;
};

/**
* RAII guard that sets the current CUDA device and stream, restoring both on
* destruction. This is useful for temporarily switching to a different device
* and stream.
*
*/
class CUDAStreamGuard {
private:
// Private constructor that takes a CUDAGuard
explicit CUDAStreamGuard(CUDAGuard&& guard)
: device_guard_(std::move(guard)),
original_stream_(nullptr),
current_stream_(nullptr),
device_index_(-1) {}

public:
/**
* Factory method to create a CUDAStreamGuard.
*
* @param stream The CUDA stream to set as current
* @param device_index The device index for the stream
* @return Result containing the guard on success, or an error code on failure
*/
static Result<CUDAStreamGuard> create(
cudaStream_t stream,
DeviceIndex device_index);

// Copy is not allowed
CUDAStreamGuard(const CUDAStreamGuard&) = delete;
CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete;

// Move constructor and assignment
CUDAStreamGuard(CUDAStreamGuard&& other) noexcept;
CUDAStreamGuard& operator=(CUDAStreamGuard&& other) noexcept = delete;

/**
* Destructor that restores the original stream and device.
*/
~CUDAStreamGuard();

/**
* Sets the CUDA stream to the given stream on the specified device.
*
* @param stream The CUDA stream to set as current
* @param device_index The device index for the stream
* @return Error code indicating success or failure
*/
Error set_stream(cudaStream_t stream, DeviceIndex device_index);

/**
* Get the current guarded stream.
*
* @return The current stream
*/
cudaStream_t stream() const {
return current_stream_;
}

/**
* Get the device index being guarded.
*
* @return The device index
*/
DeviceIndex device_index() const {
return device_index_;
}

private:
/// The device guard that handles device switching
CUDAGuard device_guard_;
/// The original stream that was current before this guard
cudaStream_t original_stream_ = nullptr;
/// The current stream being guarded
cudaStream_t current_stream_ = nullptr;
/// The device index for this stream guard
DeviceIndex device_index_;
};

} // namespace cuda
} // namespace backends
} // namespace executorch
6 changes: 6 additions & 0 deletions backends/cuda/runtime/tests/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
Loading
Loading