Skip to content

Commit 91db1af

Browse files
authored
[core][cgraph] Remove air/torch_util deps in doc tests (#51312)
Removing the AIR dependency as it is not user facing. Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
1 parent 2f72014 commit 91db1af

File tree

2 files changed

+2
-12
lines changed

2 files changed

+2
-12
lines changed

doc/source/ray-core/doc_code/cgraph_overlap.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,14 @@
33
import time
44
import torch
55
from ray.dag import InputNode, MultiOutputNode
6-
from ray.experimental.channel.torch_tensor_type import TorchTensorType
7-
from ray.air._internal import torch_utils
86

97

108
@ray.remote(num_cpus=0, num_gpus=1)
119
class TorchTensorWorker:
12-
def __init__(self):
13-
self.device = torch_utils.get_devices()[0]
14-
1510
def send(self, shape, dtype, value: int, send_tensor=True):
1611
if not send_tensor:
1712
return 1
18-
return torch.ones(shape, dtype=dtype, device=self.device) * value
13+
return torch.ones(shape, dtype=dtype, device="cuda") * value
1914

2015
def recv_and_matmul(self, two_d_tensor):
2116
"""
@@ -27,7 +22,6 @@ def recv_and_matmul(self, two_d_tensor):
2722
# Check that tensor got loaded to the correct device.
2823
assert two_d_tensor.dim() == 2
2924
assert two_d_tensor.size(0) == two_d_tensor.size(1)
30-
assert two_d_tensor.device == self.device
3125
torch.matmul(two_d_tensor, two_d_tensor)
3226
return (two_d_tensor[0][0].item(), two_d_tensor.shape, two_d_tensor.dtype)
3327

doc/source/ray-core/doc_code/cgraph_profiling.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,12 @@
2323
import ray
2424
import torch
2525
from ray.dag import InputNode
26-
from ray.air._internal import torch_utils
2726

2827

2928
@ray.remote(num_gpus=1, runtime_env={"nsight": "default"})
3029
class RayActor:
31-
def __init__(self):
32-
self.device = torch_utils.get_devices()[0]
33-
3430
def send(self, shape, dtype, value: int):
35-
return torch.ones(shape, dtype=dtype, device=self.device) * value
31+
return torch.ones(shape, dtype=dtype, device="cuda") * value
3632

3733
def recv(self, tensor):
3834
return (tensor[0].item(), tensor.shape, tensor.dtype)

0 commit comments

Comments
 (0)