Skip to content

Commit 2385e89

Browse files
committed
Sch: Track unique Chunk per DRef
1 parent a67aa20 commit 2385e89

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

src/sch/Sch.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ Fields:
5858
- `running_on::Dict{Thunk,OSProc}` - Map from `Thunk` to the OS process executing it
5959
- `thunk_dict::Dict{Int, WeakThunk}` - Maps from thunk IDs to a `Thunk`
6060
- `node_order::Any` - Function that returns the order of a thunk
61+
- `equiv_chunks::WeakKeyDict{DRef,Chunk}` - Cache mapping from `DRef` to a `Chunk` which contains it
6162
- `worker_time_pressure::Dict{Int,Dict{Processor,UInt64}}` - Maps from worker ID to processor pressure
6263
- `worker_storage_pressure::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}` - Maps from worker ID to storage resource pressure
6364
- `worker_storage_capacity::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}` - Maps from worker ID to storage resource capacity
@@ -84,6 +85,7 @@ struct ComputeState
8485
running_on::Dict{Thunk,OSProc}
8586
thunk_dict::Dict{Int, WeakThunk}
8687
node_order::Any
88+
equiv_chunks::WeakKeyDict{DRef,Chunk}
8789
worker_time_pressure::Dict{Int,Dict{Processor,UInt64}}
8890
worker_storage_pressure::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}
8991
worker_storage_capacity::Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}
@@ -113,6 +115,7 @@ function start_state(deps::Dict, node_order, chan)
113115
Dict{Thunk,OSProc}(),
114116
Dict{Int, WeakThunk}(),
115117
node_order,
118+
WeakKeyDict{DRef,Chunk}(),
116119
Dict{Int,Dict{Processor,UInt64}}(),
117120
Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(),
118121
Dict{Int,Dict{Union{StorageResource,Nothing},UInt64}}(),
@@ -428,6 +431,11 @@ function scheduler_run(ctx, state::ComputeState, d::Thunk, options::SchedulerOpt
428431
state.transfer_rate[] = (state.transfer_rate[] + metadata.transfer_rate) ÷ 2
429432
end
430433
end
434+
if res isa Chunk
    # `equiv_chunks` is keyed by `DRef`, so membership must be tested with the
    # Chunk's handle rather than the Chunk itself. `get!` registers `res` only
    # when no Chunk is cached for this handle yet.
    get!(state.equiv_chunks, res.handle::DRef, res)
end
431439
store_result!(state, node, res; error=thunk_failed)
432440
if node.options !== nothing && node.options.checkpoint !== nothing
433441
try

src/submission.jl

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,18 +101,17 @@ function eager_submit_internal!(ctx, state, task, tid, payload::AnyPayload; uid_
101101
elseif valuetype(arg) <: Chunk
102102
# N.B. Different Chunks with the same DRef handle will hash to the same slot,
103103
# so we just pick an equivalent Chunk as our upstream
104+
# NOTE: The equivalent Chunk is looked up in state.equiv_chunks, so we no longer iterate over state.waiting_data here.
104105
chunk = value(arg)::Chunk
105-
if haskey(state.waiting_data, chunk)
106-
newchunk = nothing
107-
for other in keys(state.waiting_data)
108-
if other isa Chunk && other.handle == chunk.handle
109-
newchunk = other
110-
break
111-
end
106+
# Return the `Chunk` registered in `state.equiv_chunks` for `chunk.handle`,
# registering `chunk` itself when that handle has no entry yet.
# The result is asserted to have the same concrete type `C` as the input, so a
# cached `Chunk` of a different concrete type raises a `TypeError`.
function find_equivalent_chunk(state, chunk::C) where {C<:Chunk}
    # `get!` performs the lookup-or-insert as a single dictionary operation
    # instead of separate `haskey`/`getindex`/`setindex!` calls.
    return get!(state.equiv_chunks, chunk.handle, chunk)::C
end
114+
chunk = find_equivalent_chunk(state, chunk)
116115
#=FIXME:UNIQUE=#
117116
@inbounds fargs[idx] = Argument(arg.pos, WeakChunk(chunk))
118117
end
@@ -193,6 +192,7 @@ function eager_submit_internal!(ctx, state, task, tid, payload::AnyPayload; uid_
193192

194193
return thunk_id
195194
end
195+
empty!(equiv_chunks)
196196
end
197197
end
198198
struct UnrefThunk
@@ -238,16 +238,12 @@ function eager_submit!(payload::AnyPayload)
238238
return remotecall_fetch(1, payload) do payload
239239
Sch.init_eager()
240240
state = Dagger.Sch.EAGER_STATE[]
241-
@lock state.lock begin
242-
eager_submit_internal!(payload)
243-
end
241+
@lock state.lock eager_submit_internal!(payload)
244242
end
245243
else
246244
Sch.init_eager()
247245
state = Dagger.Sch.EAGER_STATE[]
248-
return lock(state.lock) do
249-
eager_submit_internal!(payload)
250-
end
246+
return @lock state.lock eager_submit_internal!(payload)
251247
end
252248
end
253249

0 commit comments

Comments
 (0)