Skip to content

Commit 4c29b50

Browse files
authored
Add Ray executor (#687)
* Add Ray executor * Mypy fixes
1 parent 6ee7940 commit 4c29b50

File tree

5 files changed

+143
-0
lines changed

5 files changed

+143
-0
lines changed

.github/workflows/ray-tests.yml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
name: Ray tests
2+
3+
on:
4+
push:
5+
branches:
6+
- "main"
7+
pull_request:
8+
workflow_dispatch:
9+
10+
concurrency:
11+
group: ${{ github.workflow }}-${{ github.ref }}
12+
cancel-in-progress: true
13+
14+
jobs:
15+
test:
16+
runs-on: ${{ matrix.os }}
17+
strategy:
18+
fail-fast: false
19+
matrix:
20+
os: ["ubuntu-latest"]
21+
python-version: ["3.12"]
22+
23+
steps:
24+
- name: Checkout source
25+
uses: actions/checkout@v3
26+
with:
27+
fetch-depth: 0
28+
29+
- name: Set up Python
30+
uses: actions/setup-python@v3
31+
with:
32+
python-version: ${{ matrix.python-version }}
33+
architecture: x64
34+
35+
- name: Setup Graphviz
36+
uses: ts-graphviz/setup-graphviz@v2
37+
38+
- name: Install
39+
run: |
40+
python -m pip install --upgrade pip
41+
python -m pip install -e '.[test]' 'ray[default]'
42+
43+
- name: Run tests
44+
run: |
45+
pytest -vs

cubed/runtime/create.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def create_executor(name: str, executor_options: Optional[dict] = None) -> Execu
3030
from cubed.runtime.executors.local import ProcessesExecutor
3131

3232
return ProcessesExecutor(**executor_options)
33+
elif name == "ray":
34+
from cubed.runtime.executors.ray import RayExecutor
35+
36+
return RayExecutor(**executor_options)
3337
elif name == "single-threaded":
3438
from cubed.runtime.executors.local import SingleThreadedExecutor
3539

cubed/runtime/executors/ray.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import asyncio
2+
from typing import Optional, Sequence
3+
4+
import ray
5+
from networkx import MultiDiGraph
6+
7+
from cubed.runtime.asyncio import async_map_dag
8+
from cubed.runtime.backup import use_backups_default
9+
from cubed.runtime.types import Callback, DagExecutor
10+
from cubed.runtime.utils import asyncio_run, execute_with_stats
11+
from cubed.spec import Spec
12+
13+
14+
class RayExecutor(DagExecutor):
15+
"""An execution engine that uses Ray."""
16+
17+
def __init__(self, **kwargs):
18+
self.kwargs = kwargs
19+
20+
@property
21+
def name(self) -> str:
22+
return "ray"
23+
24+
def execute_dag(
25+
self,
26+
dag: MultiDiGraph,
27+
callbacks: Optional[Sequence[Callback]] = None,
28+
resume: Optional[bool] = None,
29+
spec: Optional[Spec] = None,
30+
compute_id: Optional[str] = None,
31+
**kwargs,
32+
) -> None:
33+
merged_kwargs = {**self.kwargs, **kwargs}
34+
35+
ray_init = merged_kwargs.pop("ray_init", None)
36+
if ray_init is not None:
37+
ray.init(**ray_init)
38+
39+
asyncio_run(
40+
self._async_execute_dag(
41+
dag,
42+
callbacks=callbacks,
43+
resume=resume,
44+
spec=spec,
45+
compute_id=compute_id,
46+
**merged_kwargs,
47+
)
48+
)
49+
50+
async def _async_execute_dag(
51+
self,
52+
dag: MultiDiGraph,
53+
callbacks: Optional[Sequence[Callback]] = None,
54+
resume: Optional[bool] = None,
55+
spec: Optional[Spec] = None,
56+
compute_arrays_in_parallel: Optional[bool] = None,
57+
**kwargs,
58+
) -> None:
59+
if spec is not None:
60+
if "use_backups" not in kwargs and use_backups_default(spec):
61+
kwargs["use_backups"] = True
62+
63+
allowed_mem = spec.allowed_mem if spec is not None else 2_000_000_000
64+
retries = kwargs.pop("retries", 2)
65+
66+
# note we can define the remote function here (doesn't need to be at top-level), and pass in memory and retries
67+
@ray.remote(memory=allowed_mem, max_retries=retries, retry_exceptions=True)
68+
def run_remotely(input, func=None, config=None, name=None, compute_id=None):
69+
# note we can't use the execution_stat decorator since it doesn't work with ray decorators
70+
result, stats = execute_with_stats(func, input, config=config)
71+
return result, stats
72+
73+
def create_futures_func(input, **kwargs):
74+
return [
75+
(i, asyncio.wrap_future(run_remotely.remote(i, **kwargs).future()))
76+
for i in input
77+
]
78+
79+
await async_map_dag(
80+
create_futures_func,
81+
dag=dag,
82+
callbacks=callbacks,
83+
resume=resume,
84+
compute_arrays_in_parallel=compute_arrays_in_parallel,
85+
**kwargs,
86+
)

cubed/tests/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@
5151
except ImportError:
5252
pass
5353

54+
try:
55+
ALL_EXECUTORS.append(create_executor("ray"))
56+
MAIN_EXECUTORS.append(create_executor("ray"))
57+
except ImportError:
58+
pass
59+
5460
MODAL_EXECUTORS = []
5561

5662
try:

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ ignore_missing_imports = True
6464
ignore_missing_imports = True
6565
[mypy-pytest.*]
6666
ignore_missing_imports = True
67+
[mypy-ray.*]
68+
ignore_missing_imports = True
6769
[mypy-rechunker.*]
6870
ignore_missing_imports = True
6971
[mypy-seaborn.*]

0 commit comments

Comments
 (0)