From 1106bfb9cae251392a82978e9308c01d97e2aed1 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 9 Apr 2025 19:38:04 -0700 Subject: [PATCH] Sch: Don't overwrite error result --- src/sch/Sch.jl | 4 ++++ src/sch/util.jl | 25 +++++++++++++++++++------ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 0335fe3e0..a4f1f1f5c 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -599,6 +599,10 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options) state.transfer_rate[] = (state.transfer_rate[] + metadata.transfer_rate) รท 2 end end + if thunk_failed && res isa RemoteException + res = res.captured + end + @assert !haskey(state.cache, node) state.cache[node] = res state.errored[node] = thunk_failed if node.options !== nothing && node.options.checkpoint !== nothing diff --git a/src/sch/util.jl b/src/sch/util.jl index dd148d336..f2d99e7ab 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -178,16 +178,29 @@ end "Marks `thunk` and all dependent thunks as failed." function set_failed!(state, origin, thunk=origin) - filter!(x->x!==thunk, state.ready) - ex = state.cache[origin] - if ex isa RemoteException - ex = ex.captured + @assert islocked(state.lock) + + if haskey(state.cache, thunk) + @assert state.errored[thunk] + # We've already been called previously with this thunk + return + end + + filter!(x -> x !== thunk, state.ready) + # N.B. If origin === thunk, we assume that the caller has already set the error + if origin !== thunk + origin_ex = state.cache[origin] + if origin_ex isa RemoteException + origin_ex = origin_ex.captured + end + state.cache[thunk] = DTaskFailedException(thunk, origin, origin_ex) + state.errored[thunk] = true end - state.cache[thunk] = DTaskFailedException(thunk, origin, ex) - state.errored[thunk] = true finish_failed!(state, thunk, origin) end function finish_failed!(state, thunk, origin=nothing) + @assert islocked(state.lock) + # FIXME: This is duplicative with finish_task! fill_registered_futures!(state, thunk, true) if haskey(state.waiting_data, thunk) for dep in state.waiting_data[thunk]