Skip to content

Commit a2633e8

Browse files
committed
TEMP Optimizations
1 parent fb1dcbd commit a2633e8

19 files changed

+1099
-770
lines changed

src/Dagger.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ if !isdefined(Base, :ScopedValues)
2121
else
2222
import Base.ScopedValues: ScopedValue, with
2323
end
24+
import TaskLocalValues: TaskLocalValue
2425

2526
if !isdefined(Base, :get_extension)
2627
import Requires: @require
@@ -32,9 +33,13 @@ import TimespanLogging: timespan_start, timespan_finish
3233
include("lib/util.jl")
3334
include("utils/dagdebug.jl")
3435

36+
# Logging Basics
37+
include("utils/logging.jl")
38+
3539
# Distributed data
3640
include("utils/locked-object.jl")
3741
include("utils/tasks.jl")
42+
include("utils/reuse.jl")
3843

3944
import MacroTools: @capture
4045
include("options.jl")
@@ -46,6 +51,7 @@ include("task-tls.jl")
4651
include("scopes.jl")
4752
include("utils/scopes.jl")
4853
include("dtask.jl")
54+
include("argument.jl")
4955
include("queue.jl")
5056
include("thunk.jl")
5157
include("submission.jl")
@@ -62,34 +68,34 @@ include("sch/Sch.jl"); using .Sch
6268
# Data dependency task queue
6369
include("datadeps.jl")
6470

71+
# File IO
72+
include("file-io.jl")
73+
6574
# Array computations
6675
include("array/darray.jl")
6776
include("array/alloc.jl")
6877
include("array/map-reduce.jl")
6978
include("array/copy.jl")
70-
71-
# File IO
72-
include("file-io.jl")
73-
79+
include("array/random.jl")
7480
include("array/operators.jl")
7581
include("array/indexing.jl")
7682
include("array/setindex.jl")
7783
include("array/matrix.jl")
7884
include("array/sparse_partition.jl")
85+
include("array/parallel-blocks.jl")
7986
include("array/sort.jl")
8087
include("array/linalg.jl")
8188
include("array/mul.jl")
8289
include("array/cholesky.jl")
8390

91+
# Custom Logging Events
92+
include("utils/logging-events.jl")
93+
8494
# Visualization
8595
include("visualization.jl")
8696
include("ui/gantt-common.jl")
8797
include("ui/gantt-text.jl")
8898

89-
# Logging
90-
include("utils/logging-events.jl")
91-
include("utils/logging.jl")
92-
9399
# Precompilation
94100
import PrecompileTools: @compile_workload
95101
include("precompile.jl")

src/argument.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""
    ArgPosition

Describes where an argument sits in a call: either a positional slot
(`positional == true`, identified by `idx`) or a keyword slot
(`positional == false`, identified by `kw`).

The accessors `pos_idx` and `pos_kw` expect the unused field to hold its
placeholder value: `0` for `idx` and `:NULL` for `kw`.
"""
mutable struct ArgPosition
    positional::Bool
    idx::Int
    kw::Symbol
end

# Default: a positional placeholder with no index assigned yet.
ArgPosition() = ArgPosition(true, 0, :NULL)

# Field-by-field copy, so the result can be mutated independently.
function ArgPosition(pos::ArgPosition)
    return ArgPosition(pos.positional, pos.idx, pos.kw)
end

ispositional(pos::ArgPosition) = pos.positional
iskw(pos::ArgPosition) = !ispositional(pos)

"""
    pos_idx(pos::ArgPosition) -> Int

Return the positional index of `pos`. Asserts that `pos` is positional, has a
positive index, and carries the `:NULL` keyword placeholder.
"""
function pos_idx(pos::ArgPosition)
    @assert pos.positional
    @assert pos.idx > 0
    @assert pos.kw == :NULL
    return pos.idx
end

"""
    pos_kw(pos::ArgPosition) -> Symbol

Return the keyword name of `pos`. Asserts that `pos` is a keyword position,
has index `0`, and carries a keyword other than `:NULL`.
"""
function pos_kw(pos::ArgPosition)
    @assert !pos.positional
    @assert pos.idx == 0
    @assert pos.kw != :NULL
    return pos.kw
end
22+
"""
    Argument

Pairs a value with the `ArgPosition` describing where it appears in a call.
"""
mutable struct Argument
    pos::ArgPosition
    value
end

# Positional argument at index `pos`.
function Argument(pos::Integer, value)
    return Argument(ArgPosition(true, pos, :NULL), value)
end

# Keyword argument named `kw`.
function Argument(kw::Symbol, value)
    return Argument(ArgPosition(false, 0, kw), value)
end

# Position queries are forwarded to the wrapped `ArgPosition`.
ispositional(arg::Argument) = ispositional(arg.pos)
iskw(arg::Argument) = iskw(arg.pos)
pos_idx(arg::Argument) = pos_idx(arg.pos)
pos_kw(arg::Argument) = pos_kw(arg.pos)

# Value accessors.
value(arg::Argument) = arg.value
valuetype(arg::Argument) = typeof(arg.value)
34+
# Iteration yields exactly two items, `arg.pos` then `arg.value`, so an
# `Argument` can be destructured as `pos, value = arg`.
# The `Bool` state is `true` while the value is still pending.
Base.iterate(arg::Argument) = (arg.pos, true)
Base.iterate(arg::Argument, state::Bool) = state ? (arg.value, false) : nothing
42+
43+
# Copy the position object but reuse the same value reference.
function Base.copy(arg::Argument)
    return Argument(ArgPosition(arg.pos), arg.value)
end

# The chunk type of an argument is the chunk type of the value it wraps.
function chunktype(arg::Argument)
    return chunktype(value(arg))
end

src/array/darray.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ domainchunks(d::DArray) = d.subdomains
173173
size(x::DArray) = size(domain(x))
174174
stage(ctx, c::DArray) = c
175175

176-
function Base.collect(d::DArray; tree=false)
176+
function Base.collect(d::DArray{T,N}; tree=false, copyto=false) where {T,N}
177177
a = fetch(d)
178178
if isempty(d.chunks)
179179
return Array{eltype(d)}(undef, size(d)...)
@@ -183,6 +183,13 @@ function Base.collect(d::DArray; tree=false)
183183
return fetch(a.chunks[1])
184184
end
185185

186+
if copyto
187+
C = Array{T,N}(undef, size(a))
188+
DC = view(C, Blocks(size(a)...))
189+
copyto!(DC, a)
190+
return C
191+
end
192+
186193
dimcatfuncs = [(x...) -> d.concat(x..., dims=i) for i in 1:ndims(d)]
187194
if tree
188195
collect(fetch(treereduce_nd(map(x -> ((args...,) -> Dagger.@spawn x(args...)) , dimcatfuncs), a.chunks)))

src/array/indexing.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import TaskLocalValues: TaskLocalValue
2-
31
### getindex
42

53
struct GetIndex{T,N} <: ArrayOp{T,N}

src/array/parallel-blocks.jl

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
export ParallelBlocks
2+
3+
using Statistics
4+
5+
"""
    ParallelBlocks{N}

Partitioning descriptor holding a block count `n`.

`ParallelBlocks(n)` builds a descriptor with `N == 0`, and `ParallelBlocks()`
takes `n` from `Dagger.num_processors()`. `ParallelBlocks{N}(dist)` and
`convert(ParallelBlocks{N}, dist)` re-tag an existing descriptor with a
different `N` while keeping `n`.
"""
struct ParallelBlocks{N} <: Dagger.AbstractSingleBlocks{N}
    n::Int
end

ParallelBlocks(n::Integer) = ParallelBlocks{0}(n)
ParallelBlocks() = ParallelBlocks(Dagger.num_processors())

function ParallelBlocks{N}(dist::ParallelBlocks) where {N}
    return ParallelBlocks{N}(dist.n)
end

function Base.convert(::Type{ParallelBlocks{N}}, dist::ParallelBlocks) where {N}
    return ParallelBlocks{N}(dist.n)
end
14+
15+
"""
    wrap_chunks(chunks, N, dist::ParallelBlocks)
    wrap_chunks(chunks, N, n::Integer)

Reshape the vector `chunks` into an `N`-dimensional `Array{Any}` whose first
dimension has length `n` (or `dist.n`) and whose other dimensions have
length 1.
"""
function wrap_chunks(chunks::Vector{<:Dagger.Chunk}, N::Integer, dist::ParallelBlocks)
    return wrap_chunks(chunks, N, dist.n)
end

function wrap_chunks(chunks::Vector{<:Dagger.Chunk}, N::Integer, n::Integer)
    shape = ntuple(d -> d == 1 ? n : 1, N)
    return convert(Array{Any}, reshape(chunks, shape))
end
19+
20+
"""
    _finish_allocation(f, dist::ParallelBlocks, dims::NTuple{N,Int})

Build a `DArray` made of `dist.n` chunks, each created by calling `f(dims)`.

Every chunk is described by the same full-size domain (`1:d` along each
dimension in `dims`). The chunk and subdomain arrays are `N`-dimensional, with
length `dist.n` in the first dimension and 1 elsewhere. The element type is
taken from the first allocated block.
"""
function _finish_allocation(f::Function, dist::ParallelBlocks, dims::NTuple{N,Int}) where N
    # Read the block count once. `dist` is never reassigned, so the closures
    # created by the comprehensions below do not force it into a `Core.Box`.
    nblocks = dist.n
    d = ArrayDomain(map(x->1:x, dims))
    s = reshape([d for _ in 1:nblocks],
                ntuple(i->i == 1 ? nblocks : 1, N))
    data = [f(dims) for _ in 1:nblocks]
    # Re-tag the partitioning with the array's dimensionality under a new name.
    new_dist = ParallelBlocks{N}(dist)
    chunks = wrap_chunks(map(Dagger.tochunk, data), N, new_dist)
    return Dagger.DArray(eltype(first(data)), d, s, chunks, new_dist)
end
29+
30+
# Define `rand`, `randn`, `zeros` and `ones` methods that take a
# `ParallelBlocks` distribution as their first argument.
#
# The method taking `dims::Dims` does the work; the others are convenience
# entry points. They normalize `dims` to a tuple of `Int` before forwarding, so
# the call always reaches the `Dims` method. Forwarding the tuple unchanged
# would re-dispatch to the same `dims::Tuple` method for any tuple that is not
# an `NTuple{N,Int}` (for example `Int32` dimensions) and recurse without bound.
for fn in [:rand, :randn, :zeros, :ones]
    @eval begin
        function Base.$fn(dist::ParallelBlocks, ::Type{ET}, dims::Dims) where {ET}
            f(block) = $fn(ET, block)
            return _finish_allocation(f, dist, dims)
        end
        Base.$fn(dist::ParallelBlocks, T::Type, dims::Integer...) =
            $fn(dist, T, map(Int, dims))
        Base.$fn(dist::ParallelBlocks, T::Type, dims::Tuple) =
            $fn(dist, T, map(Int, dims))
        # Element type defaults to `Float64` when not given.
        Base.$fn(dist::ParallelBlocks, dims::Integer...) =
            $fn(dist, Float64, map(Int, dims))
        Base.$fn(dist::ParallelBlocks, dims::Tuple) =
            $fn(dist, Float64, map(Int, dims))
    end
end
42+
# FIXME: sprand
43+
44+
"""
    Dagger.distribute(data::AbstractArray, dist::ParallelBlocks)

Build a `DArray` holding `dist.n` chunks, each wrapping its own `copy(data)`.
The array domain spans the full size of `data`, and the subdomains are a
single-block `DomainBlocks` covering that same extent.
"""
function Dagger.distribute(data::AbstractArray{T,N}, dist::ParallelBlocks) where {T,N}
    extents = size(data)
    full_domain = ArrayDomain(map(len -> 1:len, extents))
    subdomains = Dagger.DomainBlocks(ntuple(_ -> 1, N),
                                     ntuple(i -> [extents[i]], N))
    # One independent copy of the input per block.
    replicas = [Dagger.tochunk(copy(data)) for _ in 1:dist.n]
    typed_dist = ParallelBlocks{N}(dist)
    wrapped = wrap_chunks(replicas, N, dist)
    return Dagger.DArray(T, full_domain, subdomains, wrapped, typed_dist)
end
53+
54+
"""
    _invalid_call_pblocks(f::Symbol)

Throw an error stating that the operation named `f` is not supported on a
`DArray` partitioned with `ParallelBlocks`, and suggesting `Dagger.pmap`.
"""
function _invalid_call_pblocks(f::Symbol)
    return error("`$f` is not valid for a `DArray` partitioned with `ParallelBlocks`.\nConsider `Dagger.pmap($f, x)` instead.")
end
56+
57+
# `collect`, `getindex` and `setindex!` are rejected for `ParallelBlocks`
# arrays. Each method throws via `_invalid_call_pblocks`.
function Base.collect(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N})
    return _invalid_call_pblocks(:collect)
end

function Base.getindex(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N}, x...)
    return _invalid_call_pblocks(:getindex)
end

function Base.setindex!(::Dagger.DArray{T,N,<:ParallelBlocks} where {T,N}, value, x...)
    return _invalid_call_pblocks(:setindex!)
end
63+
64+
"""
    pmap(f, A::DArray)

Spawn one Dagger task per chunk of `A`, each calling `f` on the whole chunk,
and return a new `DArray` over the task results. The result reuses `A`'s
domain, subdomains and partitioning; its element type is currently `Any`.
"""
function pmap(f::Function, A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
    # TODO: Chunks might not be `Array`s
    # FIXME
    #AT = Array{T,N}
    #ET = eltype(Base.promote_op(f, AT))
    ET = Any
    apply_to_chunk(chunk) = Dagger.@spawn f(chunk)
    new_chunks = map(apply_to_chunk, A.chunks)
    return DArray(ET, A.domain, A.subdomains, new_chunks, A.partitioning)
end
75+
# FIXME: More useful `show` method
76+
# Minimal display: print only the array's type.
function Base.show(io::IO, ::MIME"text/plain", A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
    return print(io, typeof(A))
end

"""
    pfetch(A::DArray)

Return an array with `fetch` applied to every chunk of `A`.
"""
function pfetch(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
    return map(fetch, A.chunks)
end

"""
    pcollect(A::DArray)

Return an array with `collect` applied to every fetched chunk of `A`.
"""
function pcollect(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
    fetched = pfetch(A)
    return map(collect, fetched)
end
82+
83+
# Elementwise `map` over a `ParallelBlocks` array. One Dagger task is spawned
# per chunk to run `map(f, chunk)`. The result reuses `A`'s domain, subdomains
# and partitioning, with the element type given by `Base.promote_op(f, T)`.
function Base.map(f::Function, A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
    ET = Base.promote_op(f, T)
    map_chunk(chunk) = Dagger.@spawn map(f, chunk)
    new_chunks = map(map_chunk, A.chunks)
    return DArray(ET, A.domain, A.subdomains, new_chunks, A.partitioning)
end
90+
# In-place elementwise `map!` between two `ParallelBlocks` arrays.
#
# For each block index `i`, a Dagger task runs
# `map!(f, x.chunks[i], y.chunks[i])`, so `x` is the destination and `y` the
# source. The call blocks until all spawned tasks have finished.
#
# Throws `ArgumentError` if the two arrays have different block counts.
#
# Returns the destination `x`, as `Base.map!` does. The loop previously fell
# off the end and returned `nothing`.
function Base.map!(f::Function,
                   x::Dagger.DArray{T1,N1,ParallelBlocks{N1}} where {T1,N1},
                   y::Dagger.DArray{T2,N2,ParallelBlocks{N2}} where {T2,N2})
    x_dist = x.partitioning
    y_dist = y.partitioning
    if x_dist.n != y_dist.n
        throw(ArgumentError("Can't `map!` over non-matching `ParallelBlocks` distributions: $(x_dist.n) != $(y_dist.n)"))
    end
    @sync for i in 1:x_dist.n
        Dagger.@spawn map!(f, x.chunks[i], y.chunks[i])
    end
    return x
end
102+
103+
#=
104+
function Base.reduce(f::Function, x::Dagger.DArray{T,N,ParallelBlocks{N}};
105+
dims=:) where {T,N}
106+
error("Out-of-place Reduce")
107+
if dims == Base.:(:)
108+
localpart = fetch(Dagger.reduce_async(f, x))
109+
return MPI.Allreduce(localpart, f, comm)
110+
elseif dims === nothing
111+
localpart = fetch(x.chunks[1])
112+
return MPI.Allreduce(localpart, f, comm)
113+
else
114+
error("Not yet implemented")
115+
end
116+
end
117+
=#
118+
"""
    allreduce!(op, x::DArray; nchunks=0)

Combine the chunks of `x` with `op`, in place.

A temporary `y = copy(x)` is created. For every slab along the last dimension
and every ordered pair of distinct chunks, the slab of the source chunk in `x`
is folded into the same slab of the destination chunk in `y` using `op`.
Finally each chunk of `y` is copied back over the matching chunk of `x`.
All work is submitted inside a single `Dagger.spawn_datadeps` region.

`nchunks` sets how many slabs the last dimension is split into. It defaults to
`x.partitioning.n` when `0`, and must equal that value.

Returns `x`.
"""
function allreduce!(op::Function, x::Dagger.DArray{T,N,ParallelBlocks{N}}; nchunks::Integer=0) where {T,N}
    if nchunks == 0
        nchunks = x.partitioning.n
    end
    @assert nchunks == x.partitioning.n "Number of chunks must match the number of partitions"

    # Split each chunk along the last dimension
    slab_len = cld(size(x, ndims(x)), nchunks)
    slab_dist = Blocks(ntuple(i -> i == N ? slab_len : size(x, i), N))
    slabs = partition(slab_dist, x.subdomains[1])
    nslabs = length(slabs)
    nreplicas = length(x.chunks)

    # Allocate temporary buffer
    y = copy(x)

    Dagger.spawn_datadeps() do
        # Ring-reduce into temporary buffer. The loop nesting (slab, source
        # chunk, ring step) fixes the order in which tasks are submitted.
        for j in 1:nslabs, i in 1:nreplicas, step in 1:(nreplicas-1)
            src_chunk = x.chunks[i]
            dst_chunk = y.chunks[mod1(i+step, nreplicas)]
            # Rotate the slab by source index, so different sources touch
            # different slabs for the same `j`.
            sd = slabs[mod1(j+i-1, nslabs)].indexes
            # FIXME: Specify aliasing based on `sd`
            Dagger.@spawn _reduce_view!(op,
                                        InOut(dst_chunk), sd,
                                        In(src_chunk), sd)
        end

        # Copy from temporary buffer back to origin
        for i in 1:nreplicas
            Dagger.@spawn copyto!(Out(x.chunks[i]), In(y.chunks[i]))
        end
    end

    return x
end
159+
"""
    _reduce_view!(op, to, to_view, from, from_view)

Take `view(to, to_view...)` and `view(from, from_view...)` and combine them
elementwise with `op`, writing the result into the view of `to`.
Returns `nothing`.
"""
function _reduce_view!(op, to, to_view, from, from_view)
    dest = view(to, to_view...)
    src = view(from, from_view...)
    reduce!(op, dest, src)
    return nothing
end

"""
    reduce!(op, to, from)

Overwrite `to` with the elementwise (broadcast) result of `op(to, from)` and
return `to`.
"""
function reduce!(op, to, from)
    return broadcast!(op, to, to, from)
end
168+
169+
# In-place mean across the chunks of a `ParallelBlocks` array: combine the
# chunks with `allreduce!(+, A)`, then divide every element by the number of
# chunks. Returns `A`.
function Statistics.mean!(A::Dagger.DArray{T,N,ParallelBlocks{N}}) where {T,N}
    allreduce!(+, A)
    nreplicas = length(A.chunks)
    map!(v -> v ./ nreplicas, A, A)
    return A
end

src/array/random.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using Random
2+
3+
# Fill a `DArray` in place by spawning one `randn!` task per chunk inside a
# `Dagger.spawn_datadeps` region. Each chunk is passed as `InOut`.
# Returns `A`.
function Random.randn!(A::DArray{T}) where T
    Dagger.spawn_datadeps() do
        for chunk in A.chunks
            Dagger.@spawn randn!(InOut(chunk))
        end
    end
    return A
end

src/chunks.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,9 +300,13 @@ function unwrap_weak_checked(c::WeakChunk)
300300
@assert cw !== nothing "WeakChunk expired: ($(c.wid), $(c.id))"
301301
return cw
302302
end
303+
# Wrap a `Chunk` in a `WeakChunk`.
function wrap_weak(c::Chunk)
    return WeakChunk(c)
end

# Weakness predicate: `true` for `WeakChunk`, `false` for `Chunk`.
isweak(::WeakChunk) = true
isweak(::Chunk) = false

is_task_or_chunk(::WeakChunk) = true

# A `WeakChunk` must never be serialized.
function Serialization.serialize(io::AbstractSerializer, wc::WeakChunk)
    return error("Cannot serialize a WeakChunk")
end

# Chunk type of the wrapped chunk. `unwrap_weak_checked` asserts that the weak
# reference has not expired.
function chunktype(c::WeakChunk)
    return chunktype(unwrap_weak_checked(c))
end
306310

307311
Base.@deprecate_binding AbstractPart Union{Chunk, Thunk}
308312
Base.@deprecate_binding Part Chunk

0 commit comments

Comments
 (0)