Skip to content

RFC: Add a hook for detecting task switches. #39994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/gc.c
Original file line number Diff line number Diff line change
Expand Up @@ -2799,6 +2799,8 @@ static void mark_roots(jl_gc_mark_cache_t *gc_cache, jl_gc_mark_sp_t *sp)
gc_mark_queue_obj(gc_cache, sp, jl_an_empty_vec_any);
if (jl_module_init_order != NULL)
gc_mark_queue_obj(gc_cache, sp, jl_module_init_order);
if (jl_task_switch_hooks != NULL)
gc_mark_queue_obj(gc_cache, sp, jl_task_switch_hooks);
for (size_t i = 0; i < jl_current_modules.size; i += 2) {
if (jl_current_modules.table[i + 1] != HT_NOTFOUND) {
gc_mark_queue_obj(gc_cache, sp, jl_current_modules.table[i]);
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
XX(jl_argumenterror_type) \
XX(jl_array_any_type) \
XX(jl_array_int32_type) \
XX(jl_array_voidpointer_type) \
XX(jl_array_symbol_type) \
XX(jl_array_type) \
XX(jl_array_typename) \
Expand Down
54 changes: 28 additions & 26 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2048,7 +2048,32 @@ void jl_init_types(void) JL_GC_DISABLED
jl_function_type->name->mt = NULL; // subtypes of Function have independent method tables
jl_builtin_type->name->mt = NULL; // so they don't share the Any type table

jl_svec_t *tv = jl_svec2(tvar("T"), tvar("N"));
// Ref{T} where {T}
jl_svec_t *tv = jl_svec1(tvar("T"));
jl_ref_type = (jl_unionall_t*)
jl_new_abstracttype((jl_value_t*)jl_symbol("Ref"), core, jl_any_type, tv)->name->wrapper;

// Ptr{T} <: Ref{T} where {T}
tv = jl_svec1(tvar("T"));
jl_pointer_type = (jl_unionall_t*)
jl_new_primitivetype((jl_value_t*)jl_symbol("Ptr"), core,
(jl_datatype_t*)jl_apply_type((jl_value_t*)jl_ref_type, jl_svec_data(tv), 1), tv,
sizeof(void*)*8)->name->wrapper;
jl_pointer_typename = ((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_pointer_type))->name;
// common subtypes
jl_voidpointer_type = (jl_datatype_t*)jl_apply_type1((jl_value_t*)jl_pointer_type, (jl_value_t*)jl_nothing_type);
jl_uint8pointer_type = (jl_datatype_t*)jl_apply_type1((jl_value_t*)jl_pointer_type, (jl_value_t*)jl_uint8_type);

// LLVMPtr{T, AS} where {T, AS}
tv = jl_svec2(tvar("T"), tvar("AS"));
jl_svec_t *tv_base = jl_svec1(tvar("T"));
jl_llvmpointer_type = (jl_unionall_t*)
jl_new_primitivetype((jl_value_t*)jl_symbol("LLVMPtr"), core,
(jl_datatype_t*)jl_apply_type((jl_value_t*)jl_ref_type, jl_svec_data(tv_base), 1), tv,
sizeof(void*)*8)->name->wrapper;
jl_llvmpointer_typename = ((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_llvmpointer_type))->name;

tv = jl_svec2(tvar("T"), tvar("N"));
jl_abstractarray_type = (jl_unionall_t*)
jl_new_abstracttype((jl_value_t*)jl_symbol("AbstractArray"), core,
jl_any_type, tv)->name->wrapper;
Expand All @@ -2073,6 +2098,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_array_symbol_type = jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_symbol_type, jl_box_long(1));
jl_array_uint8_type = jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_uint8_type, jl_box_long(1));
jl_array_int32_type = jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_int32_type, jl_box_long(1));
jl_array_voidpointer_type = jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_voidpointer_type, jl_box_long(1));
jl_an_empty_vec_any = (jl_value_t*)jl_alloc_vec_any(0); // used internally
jl_nonfunction_mt->leafcache = (jl_array_t*)jl_an_empty_vec_any;
jl_type_type_mt->leafcache = (jl_array_t*)jl_an_empty_vec_any;
Expand Down Expand Up @@ -2328,26 +2354,6 @@ void jl_init_types(void) JL_GC_DISABLED
jl_intrinsic_type = jl_new_primitivetype((jl_value_t*)jl_symbol("IntrinsicFunction"), core,
jl_builtin_type, jl_emptysvec, 32);

tv = jl_svec1(tvar("T"));
jl_ref_type = (jl_unionall_t*)
jl_new_abstracttype((jl_value_t*)jl_symbol("Ref"), core, jl_any_type, tv)->name->wrapper;

tv = jl_svec1(tvar("T"));
jl_pointer_type = (jl_unionall_t*)
jl_new_primitivetype((jl_value_t*)jl_symbol("Ptr"), core,
(jl_datatype_t*)jl_apply_type((jl_value_t*)jl_ref_type, jl_svec_data(tv), 1), tv,
sizeof(void*)*8)->name->wrapper;
jl_pointer_typename = ((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_pointer_type))->name;

// LLVMPtr{T, AS} where {T, AS}
tv = jl_svec2(tvar("T"), tvar("AS"));
jl_svec_t *tv_base = jl_svec1(tvar("T"));
jl_llvmpointer_type = (jl_unionall_t*)
jl_new_primitivetype((jl_value_t*)jl_symbol("LLVMPtr"), core,
(jl_datatype_t*)jl_apply_type((jl_value_t*)jl_ref_type, jl_svec_data(tv_base), 1), tv,
sizeof(void*)*8)->name->wrapper;
jl_llvmpointer_typename = ((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_llvmpointer_type))->name;

// Type{T} where T<:Tuple
tttvar = jl_new_typevar(jl_symbol("T"),
(jl_value_t*)jl_bottom_type,
Expand Down Expand Up @@ -2395,12 +2401,10 @@ void jl_init_types(void) JL_GC_DISABLED
jl_value_t *listt = jl_new_struct(jl_uniontype_type, jl_task_type, jl_nothing_type);
jl_svecset(jl_task_type->types, 0, listt);

jl_value_t *pointer_void = jl_apply_type1((jl_value_t*)jl_pointer_type, (jl_value_t*)jl_nothing_type);

tv = jl_svec2(tvar("A"), tvar("R"));
jl_opaque_closure_type = (jl_unionall_t*)jl_new_datatype(jl_symbol("OpaqueClosure"), core, jl_function_type, tv,
jl_perm_symsvec(6, "captures", "isva", "world", "source", "invoke", "specptr"),
jl_svec(6, jl_any_type, jl_bool_type, jl_long_type, jl_any_type, pointer_void, pointer_void), 0, 0, 6)->name->wrapper;
jl_svec(6, jl_any_type, jl_bool_type, jl_long_type, jl_any_type, jl_voidpointer_type, jl_voidpointer_type), 0, 0, 6)->name->wrapper;
jl_opaque_closure_typename = ((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_opaque_closure_type))->name;
jl_compute_field_offsets((jl_datatype_t*)jl_unwrap_unionall((jl_value_t*)jl_opaque_closure_type));

Expand All @@ -2411,8 +2415,6 @@ void jl_init_types(void) JL_GC_DISABLED
0, 0, 5);

// complete builtin type metadata
jl_voidpointer_type = (jl_datatype_t*)pointer_void;
jl_uint8pointer_type = (jl_datatype_t*)jl_apply_type1((jl_value_t*)jl_pointer_type, (jl_value_t*)jl_uint8_type);
jl_svecset(jl_datatype_type->types, 6, jl_voidpointer_type);
jl_svecset(jl_datatype_type->types, 7, jl_int32_type);
jl_svecset(jl_datatype_type->types, 8, jl_int32_type);
Expand Down
4 changes: 4 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ extern JL_DLLIMPORT jl_value_t *jl_array_uint8_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_value_t *jl_array_any_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_value_t *jl_array_symbol_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_value_t *jl_array_int32_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_value_t *jl_array_voidpointer_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_expr_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_globalref_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_linenumbernode_type JL_GLOBALLY_ROOTED;
Expand Down Expand Up @@ -1827,6 +1828,9 @@ JL_DLLEXPORT void JL_NORETURN jl_sig_throw(void);
JL_DLLEXPORT void JL_NORETURN jl_rethrow_other(jl_value_t *e JL_MAYBE_UNROOTED);
JL_DLLEXPORT void JL_NORETURN jl_no_exc_handler(jl_value_t *e);

typedef void *(*jl_task_switch_hook_t)(jl_task_t *t JL_PROPAGATES_ROOT);
JL_DLLEXPORT void jl_hook_task_switch(jl_task_switch_hook_t hook);

#include "locks.h" // requires jl_task_t definition

JL_DLLEXPORT void jl_enter_handler(jl_handler_t *eh);
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ void jl_gc_track_malloced_array(jl_ptls_t ptls, jl_array_t *a) JL_NOTSAFEPOINT;
void jl_gc_count_allocd(size_t sz) JL_NOTSAFEPOINT;
void jl_gc_run_all_finalizers(jl_ptls_t ptls);
void jl_release_task_stack(jl_ptls_t ptls, jl_task_t *task);
extern jl_array_t *jl_task_switch_hooks JL_GLOBALLY_ROOTED;

void gc_queue_binding(jl_binding_t *bnd) JL_NOTSAFEPOINT;
void gc_setmark_buf(jl_ptls_t ptls, void *buf, uint8_t, size_t) JL_NOTSAFEPOINT;
Expand Down
13 changes: 3 additions & 10 deletions src/stackwalk.c
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ NOINLINE size_t rec_backtrace(jl_bt_element_t *bt_data, size_t maxsize, int skip
return bt_size;
}

static jl_value_t *array_ptr_void_type JL_ALWAYS_LEAFTYPE = NULL;
// Return backtrace information as an svec of (bt1, bt2, [sp])
//
// The stack pointers `sp` are returned only when `returnsp` evaluates to true.
Expand All @@ -231,11 +230,8 @@ JL_DLLEXPORT jl_value_t *jl_backtrace_from_here(int returnsp, int skip)
jl_array_t *sp = NULL;
jl_array_t *bt2 = NULL;
JL_GC_PUSH3(&ip, &sp, &bt2);
if (array_ptr_void_type == NULL) {
array_ptr_void_type = jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_voidpointer_type, jl_box_long(1));
}
ip = jl_alloc_array_1d(array_ptr_void_type, 0);
sp = returnsp ? jl_alloc_array_1d(array_ptr_void_type, 0) : NULL;
ip = jl_alloc_array_1d(jl_array_voidpointer_type, 0);
sp = returnsp ? jl_alloc_array_1d(jl_array_voidpointer_type, 0) : NULL;
bt2 = jl_alloc_array_1d(jl_array_any_type, 0);
const size_t maxincr = 1000;
bt_context_t context;
Expand Down Expand Up @@ -290,10 +286,7 @@ static void decode_backtrace(jl_bt_element_t *bt_data, size_t bt_size,
jl_array_t **bt2out JL_REQUIRE_ROOTED_SLOT)
{
jl_array_t *bt, *bt2;
if (array_ptr_void_type == NULL) {
array_ptr_void_type = jl_apply_type2((jl_value_t*)jl_array_type, (jl_value_t*)jl_voidpointer_type, jl_box_long(1));
}
bt = *btout = jl_alloc_array_1d(array_ptr_void_type, bt_size);
bt = *btout = jl_alloc_array_1d(jl_array_voidpointer_type, bt_size);
static_assert(sizeof(jl_bt_element_t) == sizeof(void*),
"jl_bt_element_t is presented as Ptr{Cvoid} on julia side");
memcpy(bt->data, bt_data, bt_size * sizeof(jl_bt_element_t));
Expand Down
3 changes: 2 additions & 1 deletion src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ extern "C" {
// TODO: put WeakRefs on the weak_refs list during deserialization
// TODO: handle finalizers

#define NUM_TAGS 147
#define NUM_TAGS 148

// An array of references that need to be restored from the sysimg
// This is a manually constructed dual of the gvars array, which would be produced by codegen for Julia code, for C.
Expand Down Expand Up @@ -110,6 +110,7 @@ jl_value_t **const*const get_tags(void) {
INSERT_TAG(jl_array_symbol_type);
INSERT_TAG(jl_array_uint8_type);
INSERT_TAG(jl_array_int32_type);
INSERT_TAG(jl_array_voidpointer_type);
INSERT_TAG(jl_int32_type);
INSERT_TAG(jl_int64_type);
INSERT_TAG(jl_bool_type);
Expand Down
24 changes: 23 additions & 1 deletion src/task.c
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,17 @@ static jl_ptls_t NOINLINE refetch_ptls(void)
return jl_get_ptls_states();
}

jl_array_t *jl_task_switch_hooks JL_GLOBALLY_ROOTED = NULL;
JL_DLLEXPORT void jl_hook_task_switch(jl_task_switch_hook_t hook)
{
if (jl_task_switch_hooks == NULL) {
jl_task_switch_hooks = jl_alloc_array_1d(jl_array_voidpointer_type, 0);
}
jl_array_grow_end(jl_task_switch_hooks, 1);
((jl_task_switch_hook_t *)jl_array_data(
jl_task_switch_hooks))[jl_array_len(jl_task_switch_hooks) - 1] = hook;
}

JL_DLLEXPORT void jl_switch(void)
{
jl_ptls_t ptls = jl_get_ptls_states();
Expand All @@ -497,7 +508,7 @@ JL_DLLEXPORT void jl_switch(void)
if (ptls->in_finalizer)
jl_error("task switch not allowed from inside gc finalizer");
if (ptls->in_pure_callback)
jl_error("task switch not allowed from inside staged nor pure functions");
jl_error("task switch not allowed from inside staged nor pure functions or callbacks");
if (t->sticky && jl_atomic_load_acquire(&t->tid) == -1) {
// manually yielding to a task
if (jl_atomic_compare_exchange(&t->tid, -1, ptls->tid) != -1)
Expand All @@ -507,6 +518,17 @@ JL_DLLEXPORT void jl_switch(void)
jl_error("cannot switch to task running on another thread");
}

if (jl_task_switch_hooks) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a lock (or some atomics) here? If so, wouldn't it increase the overhead of the task switch even if there are no hooks?

int last_in = ptls->in_pure_callback;
ptls->in_pure_callback = 1;
for (int i = 0; i < jl_array_len(jl_task_switch_hooks); i++) {
jl_task_switch_hook_t hook =
((jl_task_switch_hook_t *)jl_array_data(jl_task_switch_hooks))[i];
hook(t);
}
ptls->in_pure_callback = last_in;
}

// Store old values on the stack and reset
sig_atomic_t defer_signal = ptls->defer_signal;
int8_t gc_state = jl_gc_unsafe_enter(ptls);
Expand Down