|
1 | 1 | from typing import TYPE_CHECKING, List, Optional, Tuple
|
2 | 2 |
|
3 | 3 | import ray
|
| 4 | +import os |
4 | 5 |
|
5 | 6 | if TYPE_CHECKING:
|
6 | 7 | import torch
|
@@ -95,22 +96,66 @@ def get_actor_node(actor: Optional["ray.actor.ActorHandle"]) -> str:
|
95 | 96 | )
|
96 | 97 |
|
97 | 98 |
|
98 |
def get_cuda_devices() -> List["torch.device"]:
    """Gets the correct torch cuda device list configured for this process.

    Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
    superset of the `ray.get_gpu_ids()`.

    Returns:
        A list of ``torch.device("cuda:<i>")`` objects, one per GPU ID that
        Ray assigned to this worker, where ``<i>`` is the GPU's position
        inside ``CUDA_VISIBLE_DEVICES``. If no GPU IDs are assigned (e.g.
        called on the driver), returns ``[torch.device("cuda:0")]``.

    Raises:
        RuntimeError: If a Ray-assigned GPU ID is missing from
            ``CUDA_VISIBLE_DEVICES``.
    """
    # Note: currently this method replicates the logic from
    # `CUDATorchDeviceManager.get_devices()`.
    # TODO(rui): tailor and clean up the logic for proper use in
    # Compiled Graphs.
    import torch

    # GPU IDs are assigned by Ray after you specify "use_gpu".
    # `ray.get_gpu_ids()` may return ints or may return strings.
    # We should always convert to strings.
    gpu_ids = [str(gpu_id) for gpu_id in ray.get_gpu_ids()]

    device_ids = []

    if len(gpu_ids) > 0:
        cuda_visible_str = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        # "NoDevFiles" is a sentinel some environments use to mean "no GPUs".
        if cuda_visible_str and cuda_visible_str != "NoDevFiles":
            cuda_visible_list = cuda_visible_str.split(",")
        else:
            cuda_visible_list = []

        # By default, there should only be one GPU ID if `use_gpu=True`.
        # If there are multiple GPUs, return a list of devices.
        # If using fractional GPUs, these IDs are not guaranteed
        # to be unique across different processes.
        for gpu_id in gpu_ids:
            try:
                device_ids.append(cuda_visible_list.index(gpu_id))
            except ValueError:
                # BUGFIX: `list.index` raises ValueError (not IndexError)
                # when the value is absent; catching IndexError here let the
                # raw ValueError escape and the message below never fired.
                raise RuntimeError(
                    "CUDA_VISIBLE_DEVICES set incorrectly. "
                    f"Got {cuda_visible_str}, expected to include {gpu_id}. "
                    "Did you override the `CUDA_VISIBLE_DEVICES` environment"
                    " variable? If not, please help file an issue on Github."
                )

    else:
        # If called on the driver or outside of Ray Train, return the
        # 0th device.
        device_ids.append(0)

    return [torch.device(f"cuda:{device_id}") for device_id in device_ids]
def get_devices() -> List["torch.device"]:
    """Gets the correct torch device list configured for this process.

    Returns a list of torch devices allocated for the current worker.
    If no devices are assigned, then it returns a list with a single CPU device.
    """
    # Imported lazily so that importing this module does not require torch.
    import torch

    # Ray hands out GPU IDs only when this process was allocated GPUs;
    # an empty list means we are CPU-only.
    assigned_gpus = [str(gpu_id) for gpu_id in ray.get_gpu_ids()]
    if not assigned_gpus:
        return [torch.device("cpu")]
    return get_cuda_devices()
0 commit comments