Skip to content

Commit 479f1bf

Browse files
authored
Adding compile_function as execute option. (#536)
* Adding `compile_function` as an execute option. This allows users to apply JIT or AOT compilation during the DAG finalization process within executors. It should work straightforwardly with jax/numba-style JIT compilation. It's possible, but maybe ugly, to perform jax-AOT-style compilation.
* Add numba for compilation tests.
* Single quotes not needed.
* Update function doc.
* Added another compile test for the failure case.
* Added another test to ensure config was applied.
* Improve tests: remove the todo for the new test and use pytest conventions.
* I don't think numba jit works well on jax arrays.
* Update plan.py: simplify the type to make mypy happy.
1 parent e8aaf3f commit 479f1bf

File tree

2 files changed

+90
-2
lines changed

2 files changed

+90
-2
lines changed

cubed/core/plan.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import atexit
2+
import dataclasses
23
import inspect
34
import shutil
45
import tempfile
@@ -30,6 +31,8 @@
3031
# Delete local context dirs when Python exits
3132
CONTEXT_DIRS = set()
3233

34+
Decorator = Callable
35+
3336

3437
def delete_on_exit(context_dir: str) -> None:
3538
if context_dir not in CONTEXT_DIRS and is_local_path(context_dir):
@@ -200,13 +203,45 @@ def _create_lazy_zarr_arrays(self, dag):
200203

201204
return dag
202205

206+
def _compile_blockwise(self, dag, compile_function: Decorator) -> nx.MultiDiGraph:
    """Apply ``compile_function`` to the function of every blockwise op in ``dag``.

    The dag is modified in place and then returned, so callers that still
    need the original should pass in a copy.

    ``compile_function`` is called with the op's function as its only
    positional argument. If it declares a keyword-only parameter named
    ``config``, the op's ``BlockwiseSpec`` is additionally passed through
    that keyword.
    """
    accepts_config = "config" in inspect.getfullargspec(compile_function).kwonlyargs

    for name in dag.nodes:
        attrs = dag.nodes[name]

        # Only nodes that carry a primitive op and whose pipeline is
        # configured by a BlockwiseSpec are compiled; the rest are untouched.
        if "primitive_op" not in attrs:
            continue
        pipeline = attrs["pipeline"]
        spec = pipeline.config
        if not isinstance(spec, BlockwiseSpec):
            continue

        extra = {"config": spec} if accepts_config else {}
        compiled = compile_function(spec.function, **extra)

        # Build updated copies with dataclasses.replace instead of assigning
        # to fields (the original note suggests these are frozen dataclasses).
        attrs["pipeline"] = dataclasses.replace(
            pipeline,
            config=dataclasses.replace(spec, function=compiled),
        )

    return dag
return dag
235+
203236
@lru_cache
204237
def _finalize_dag(
205-
self, optimize_graph: bool = True, optimize_function=None
238+
self, optimize_graph: bool = True, optimize_function=None, compile_function: Optional[Decorator] = None,
206239
) -> nx.MultiDiGraph:
207240
dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
208241
# create a copy since _create_lazy_zarr_arrays mutates the dag
209242
dag = dag.copy()
243+
if callable(compile_function):
244+
dag = self._compile_blockwise(dag, compile_function)
210245
dag = self._create_lazy_zarr_arrays(dag)
211246
return nx.freeze(dag)
212247

@@ -216,11 +251,12 @@ def execute(
216251
callbacks=None,
217252
optimize_graph=True,
218253
optimize_function=None,
254+
compile_function=None,
219255
resume=None,
220256
spec=None,
221257
**kwargs,
222258
):
223-
dag = self._finalize_dag(optimize_graph, optimize_function)
259+
dag = self._finalize_dag(optimize_graph, optimize_function, compile_function)
224260

225261
compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"
226262

cubed/tests/test_executor_features.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,55 @@ def test_check_runtime_memory_processes(spec, executor):
315315

316316
# OK if we use fewer workers
317317
c.compute(executor=executor, max_workers=max_workers // 2)
318+
319+
320+
# Compile functions exercised by the parametrized compilation test below.
# The identity function is always present; JIT compilers are appended only
# when they can actually be imported.
COMPILE_FUNCTIONS = [lambda fn: fn]

try:
    from numba import jit as numba_jit
except ImportError:
    # numba is optional. Catching ImportError (the parent of
    # ModuleNotFoundError) also covers a numba that is installed but fails
    # to import, which would otherwise break collection of this module.
    pass
else:
    COMPILE_FUNCTIONS.append(numba_jit)

# jax's jit is only added when the CUBED_BACKEND_ARRAY_API_MODULE
# environment variable mentions jax.
if "jax" in os.environ.get("CUBED_BACKEND_ARRAY_API_MODULE", ""):
    try:
        from jax import jit as jax_jit
    except ImportError:
        pass
    else:
        COMPILE_FUNCTIONS.append(jax_jit)
334+
335+
336+
@pytest.mark.parametrize("compile_function", COMPILE_FUNCTIONS)
def test_check_compilation(spec, executor, compile_function):
    """Adding two arrays gives the expected values when ``compile_function`` is supplied."""
    lhs = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
    rhs = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
    expected = np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]])

    result = xp.add(lhs, rhs).compute(
        executor=executor, compile_function=compile_function
    )

    assert_array_equal(result, expected)
344+
345+
346+
def test_compilation_can_fail(spec, executor):
    """An error raised by the compile function propagates out of ``compute``.

    The error message is expected to mention ``add``, showing that the
    compile function was handed the add operation's function.
    """

    def compile_function(func):
        raise NotImplementedError(f"Cannot compile {func}")

    lhs = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
    rhs = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
    total = xp.add(lhs, rhs)

    with pytest.raises(NotImplementedError) as excinfo:
        total.compute(executor=executor, compile_function=compile_function)

    message = str(excinfo.value)
    assert "add" in message, "Compile function was applied to add operation."
357+
358+
359+
def test_compilation_with_config_can_fail(spec, executor):
    """A compile function with a keyword-only ``config`` parameter receives the config.

    The compile function always raises; the error message is expected to
    contain ``BlockwiseSpec``, showing that a config was passed in.
    """

    def compile_function(func, *, config=None):
        raise NotImplementedError(f"Cannot compile {func} with {config}")

    lhs = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
    rhs = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
    total = xp.add(lhs, rhs)

    with pytest.raises(NotImplementedError) as excinfo:
        total.compute(executor=executor, compile_function=compile_function)

    message = str(excinfo.value)
    assert "BlockwiseSpec" in message, "Compile function was applied with a config argument."

0 commit comments

Comments
 (0)