Skip to content

Commit e1ea3d6

Browse files
maleadt authored and jpsamaroo committed
RFC: Add task hooks for create and switch events
Certain libraries are configured using global or thread-local state instead of passing handles to every function. CUDA, for example, has a `cudaSetDevice` function that binds a device to the current thread for all future API calls. This is at odds with Julia's task-based concurrency, which presents an execution environment that's local to the current task (e.g., in the case of CUDA, using a different device). This PR adds a hook mechanism that can be used to detect task creation and task switches, allowing synchronization of Julia's task-local environment with a library's global or thread-local state.
1 parent bd1a664 commit e1ea3d6

File tree

5 files changed

+74
-4
lines changed

5 files changed

+74
-4
lines changed

base/task.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,31 @@ function task_local_storage(body::Function, key, val)
299299
end
300300
end
301301

302+
# Register the native callback `hook` on `task`. The one-argument method registers it
# on the current task. The hook list is created lazily on the first attach, and a hook
# that is already present is not added a second time.
#
# Throws an `ArgumentError` when `hook` is a null pointer.
#
# N.B. `task` must either be the current task, or a known-stopped task
function attach_task_hook!(task::Task, hook::Ptr{Cvoid})
    # The runtime calls every stored pointer as a function (see `run_task_hooks` in
    # src/task.c), so reject a null pointer here instead of crashing at the next event.
    hook == C_NULL && throw(ArgumentError("cannot attach a null task hook"))
    stored = task.hooks
    if stored === nothing
        hooks = Ptr{Cvoid}[]
        task.hooks = hooks
    else
        # The field is untyped; assert the concrete type (as `detach_task_hook!` does)
        # so the membership test and `push!` below are statically dispatched.
        hooks = stored::Vector{Ptr{Cvoid}}
    end
    hook in hooks || push!(hooks, hook)
    return
end
attach_task_hook!(hook::Ptr{Cvoid}) = attach_task_hook!(current_task(), hook)
316+
317+
# Remove every occurrence of `hook` from `task`'s hook list. Does nothing when the task
# never had a hook list. The one-argument method operates on the current task.
# NOTE(review): presumably the same restriction on `task` as for attaching applies
# (current task or a known-stopped task) -- confirm.
function detach_task_hook!(task::Task, hook::Ptr{Cvoid})
    registered = task.hooks
    registered === nothing && return
    # Keep only the entries that differ from `hook`, mutating the list in place.
    filter!(!=(hook), registered::Vector{Ptr{Cvoid}})
    return
end
function detach_task_hook!(hook::Ptr{Cvoid})
    return detach_task_hook!(current_task(), hook)
end
326+
302327
# just wait for a task to be done, no error propagation
303328
function _wait(t::Task)
304329
if !istaskdone(t)

src/jltypes.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2769,7 +2769,7 @@ void jl_init_types(void) JL_GC_DISABLED
27692769
NULL,
27702770
jl_any_type,
27712771
jl_emptysvec,
2772-
jl_perm_symsvec(16,
2772+
jl_perm_symsvec(17,
27732773
"next",
27742774
"queue",
27752775
"storage",
@@ -2782,11 +2782,12 @@ void jl_init_types(void) JL_GC_DISABLED
27822782
"rngState2",
27832783
"rngState3",
27842784
"rngState4",
2785+
"hooks",
27852786
"_state",
27862787
"sticky",
27872788
"_isexception",
27882789
"priority"),
2789-
jl_svec(16,
2790+
jl_svec(17,
27902791
jl_any_type,
27912792
jl_any_type,
27922793
jl_any_type,
@@ -2799,6 +2800,7 @@ void jl_init_types(void) JL_GC_DISABLED
27992800
jl_uint64_type,
28002801
jl_uint64_type,
28012802
jl_uint64_type,
2803+
jl_any_type,
28022804
jl_uint8_type,
28032805
jl_bool_type,
28042806
jl_bool_type,

src/julia.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,6 +1925,7 @@ typedef struct _jl_task_t {
19251925
// 4 byte padding on 32-bit systems
19261926
// uint32_t padding0;
19271927
uint64_t rngState[JL_RNG_SIZE];
1928+
jl_value_t *hooks;
19281929
_Atomic(uint8_t) _state;
19291930
uint8_t sticky; // record whether this Task can be migrated to a new thread
19301931
_Atomic(uint8_t) _isexception; // set if `result` is an exception to throw or that we exited with
@@ -1979,6 +1980,9 @@ JL_DLLEXPORT void JL_NORETURN jl_no_exc_handler(jl_value_t *e, jl_task_t *ct);
19791980
JL_DLLEXPORT JL_CONST_FUNC jl_gcframe_t **(jl_get_pgcstack)(void) JL_GLOBALLY_ROOTED JL_NOTSAFEPOINT;
19801981
#define jl_current_task (container_of(jl_get_pgcstack(), jl_task_t, gcstack))
19811982

1983+
// Function-pointer type of a task hook. run_task_hooks() in task.c casts each stored
// hook pointer to this type and calls it with an event code and a value argument.
// NOTE(review): the first parameter is named `is_switch`, but the call sites in task.c
// pass the codes 0 through 4 rather than a boolean -- confirm the intended naming.
// The returned pointer is not used by run_task_hooks().
typedef void *(*jl_task_switch_hook_t)(uint8_t is_switch,
1984+
jl_value_t *t JL_PROPAGATES_ROOT);
1985+
19821986
#include "julia_locks.h" // requires jl_task_t definition
19831987

19841988
JL_DLLEXPORT void jl_enter_handler(jl_handler_t *eh);

src/task.c

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,35 @@ JL_NO_ASAN static void restore_stack2(jl_task_t *t, jl_ptls_t ptls, jl_task_t *l
290290
/* Rooted by the base module */
291291
static _Atomic(jl_function_t*) task_done_hook_func JL_GLOBALLY_ROOTED = NULL;
292292

293+
// Invoke every hook attached to the current task, passing `code` and `t` through
// unchanged. Does nothing when the task's `hooks` field is `nothing`.
//
// Event codes used by the call sites in this file:
//   0 - jl_new_task, with `t` set to the newly created task
//   1 - a task starts running (right after JL_PROBE_RT_START_TASK)
//   2 - jl_switch, before the current task is paused
//   3 - jl_switch, after JL_PROBE_RT_RUN_TASK
//   4 - top of jl_finish_task
// Every call site except code 0 passes jl_nothing for `t`.
//
// An exception thrown by a hook is caught, reported on stderr with a backtrace, and
// the remaining hooks still run.
void run_task_hooks(uint8_t code, jl_value_t *t)
{
    jl_task_t *ct = jl_current_task;
    jl_array_t *hooks = (jl_array_t *)ct->hooks;
    if ((jl_value_t *)hooks == jl_nothing)
        return;

    jl_ptls_t ptls = ct->ptls;
    // jl_switch raises an error while in_pure_callback is set, so a hook cannot
    // switch tasks. The previous value is restored once all hooks have run.
    int last_in = ptls->in_pure_callback;
    ptls->in_pure_callback = 1;
    // Length and data pointer are re-read on every iteration rather than cached.
    // NOTE(review): presumably because a hook may attach or detach hooks -- confirm.
    // The index is size_t to match the array length type (no signed/unsigned compare).
    for (size_t i = 0; i < jl_array_len(hooks); i++) {
        jl_task_switch_hook_t hook =
            ((jl_task_switch_hook_t *)jl_array_data(hooks))[i];
        // Never call through a null function pointer.
        if (hook == NULL)
            continue;
        JL_TRY {
            hook(code, t);
        }
        JL_CATCH {
            jl_printf((JL_STREAM*)STDERR_FILENO, "WARNING: task hook threw an error:\n");
            jl_static_show((JL_STREAM*)STDERR_FILENO, jl_current_exception());
            jl_printf((JL_STREAM*)STDERR_FILENO, "\n");
            jlbacktrace(); // written to STDERR_FILENO
        }
    }
    ptls->in_pure_callback = last_in;
}
317+
293318
void JL_NORETURN jl_finish_task(jl_task_t *t)
294319
{
295320
jl_task_t *ct = jl_current_task;
321+
run_task_hooks(4, jl_nothing);
296322
JL_PROBE_RT_FINISH_TASK(ct);
297323
JL_SIGATOMIC_BEGIN();
298324
if (jl_atomic_load_relaxed(&t->_isexception))
@@ -635,10 +661,12 @@ JL_DLLEXPORT void jl_switch(void) JL_NOTSAFEPOINT_LEAVE JL_NOTSAFEPOINT_ENTER
635661
if (ptls->in_finalizer)
636662
jl_error("task switch not allowed from inside gc finalizer");
637663
if (ptls->in_pure_callback)
638-
jl_error("task switch not allowed from inside staged nor pure functions");
664+
jl_error("task switch not allowed from inside staged nor pure functions or callbacks");
639665
if (!jl_set_task_tid(t, jl_atomic_load_relaxed(&ct->tid))) // manually yielding to a task
640666
jl_error("cannot switch to task running on another thread");
641667

668+
run_task_hooks(2, jl_nothing);
669+
642670
JL_PROBE_RT_PAUSE_TASK(ct);
643671

644672
// Store old values on the stack and reset
@@ -688,6 +716,7 @@ JL_DLLEXPORT void jl_switch(void) JL_NOTSAFEPOINT_LEAVE JL_NOTSAFEPOINT_ENTER
688716
jl_sigint_safepoint(ptls);
689717

690718
JL_PROBE_RT_RUN_TASK(ct);
719+
run_task_hooks(3, jl_nothing);
691720
jl_gc_unsafe_leave(ptls, gc_state);
692721
}
693722

@@ -1078,6 +1107,7 @@ JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion
10781107
t->next = jl_nothing;
10791108
t->queue = jl_nothing;
10801109
t->tls = jl_nothing;
1110+
t->hooks = jl_nothing;
10811111
jl_atomic_store_relaxed(&t->_state, JL_TASK_STATE_RUNNABLE);
10821112
t->start = start;
10831113
t->result = jl_nothing;
@@ -1119,6 +1149,11 @@ JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion
11191149
#ifdef _COMPILER_ASAN_ENABLED_
11201150
t->ctx.asan_fake_stack = NULL;
11211151
#endif
1152+
1153+
JL_GC_PUSH1(&t);
1154+
run_task_hooks(0, (jl_value_t *)t);
1155+
JL_GC_POP();
1156+
11221157
return t;
11231158
}
11241159

@@ -1234,6 +1269,9 @@ CFI_NORETURN
12341269

12351270
ct->started = 1;
12361271
JL_PROBE_RT_START_TASK(ct);
1272+
1273+
run_task_hooks(1, jl_nothing);
1274+
12371275
if (jl_atomic_load_relaxed(&ct->_isexception)) {
12381276
record_backtrace(ptls, 0);
12391277
jl_push_excstack(&ct->excstack, ct->result,
@@ -1670,6 +1708,7 @@ jl_task_t *jl_init_root_task(jl_ptls_t ptls, void *stack_lo, void *stack_hi)
16701708
ct->next = jl_nothing;
16711709
ct->queue = jl_nothing;
16721710
ct->tls = jl_nothing;
1711+
ct->hooks = jl_nothing;
16731712
jl_atomic_store_relaxed(&ct->_state, JL_TASK_STATE_RUNNABLE);
16741713
ct->start = NULL;
16751714
ct->result = jl_nothing;

test/staged.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ let gf_err, tsk = @async nothing # create a Task for yield to try to run
179179
yield()
180180
gf_err_ref[] += 1000
181181
end
182-
Expected = ErrorException("task switch not allowed from inside staged nor pure functions")
182+
Expected = ErrorException("task switch not allowed from inside staged nor pure functions or callbacks")
183183
@test_throws Expected gf_err()
184184
@test_throws Expected gf_err()
185185
@test gf_err_ref[] == 4

0 commit comments

Comments
 (0)