Skip to content

Commit ab4904c

Browse files
committed
Fix broadcasting on Julia 1.0
1 parent f8c26bd commit ab4904c

File tree

4 files changed

+25
-33
lines changed

4 files changed

+25
-33
lines changed

src/darray.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,9 @@ function Base.reshape(A::DArray{T,1,S}, d::Dims) where {T,S<:Array}
568568
i2 = CartesianIndices(sztail)[i]
569569
globalidx = [ I[j][i2[j-1]] for j=2:nd ]
570570

571-
a = sub2ind(d, d1offs, globalidx...)
571+
a = LinearIndices(d)[d1offs, globalidx...]
572572

573-
B[:,i] = A[a:(a+nr-1)]
573+
B[:,i] = Array(A[a:(a+nr-1)])
574574
end
575575
B
576576
end
@@ -706,15 +706,15 @@ end
706706
Base.size(P::ProductIndices) = P.sz
707707
# This gets passed to map to avoid breaking propagation of inbounds
708708
Base.@propagate_inbounds propagate_getindex(A, I...) = A[I...]
709-
Base.@propagate_inbounds Base.getindex(P::ProductIndices{_,N}, I::Vararg{Int, N}) where {_,N} =
709+
Base.@propagate_inbounds Base.getindex(P::ProductIndices{J,N}, I::Vararg{Int, N}) where {J,N} =
710710
Bool((&)(map(propagate_getindex, P.indices, I)...))
711711

712712
struct MergedIndices{I,N} <: AbstractArray{CartesianIndex{N}, N}
713713
indices::I
714714
sz::NTuple{N,Int}
715715
end
716716
Base.size(M::MergedIndices) = M.sz
717-
Base.@propagate_inbounds Base.getindex(M::MergedIndices{_,N}, I::Vararg{Int, N}) where {_,N} =
717+
Base.@propagate_inbounds Base.getindex(M::MergedIndices{J,N}, I::Vararg{Int, N}) where {J,N} =
718718
CartesianIndex(map(propagate_getindex, M.indices, I))
719719
# Additionally, we optimize bounds checking when using MergedIndices as an
720720
# array index since checking, e.g., A[1:500, 1:500] is *way* faster than

src/mapreduce.jl

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,37 +20,25 @@ end
2020
Base.BroadcastStyle(::Type{<:DArray}) = Broadcast.ArrayStyle{DArray}()
2121
Base.BroadcastStyle(::Type{<:DArray}, ::Any) = Broadcast.ArrayStyle{DArray}()
2222

23-
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}, ::Type{ElType}) where {ElType}
24-
DA = find_darray(bc)
25-
DArray(I -> Array{ElType}(undef, map(length,I)), DA)
26-
end
27-
28-
"`DA = find_darray(As)` returns the first DArray among the arguments."
29-
find_darray(bc::Base.Broadcast.Broadcasted) = find_darray(bc.args)
30-
find_darray(args::Tuple) = find_darray(find_darray(args[1]), Base.tail(args))
31-
find_darray(x) = x
32-
find_darray(a::DArray, rest) = a
33-
find_darray(::Any, rest) = find_darray(rest)
34-
35-
function Base.copyto!(dest::DArray, bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}})
36-
@sync for p in procs(dest)
37-
@async remotecall_fetch(p) do
38-
copyto!(localpart(dest), rewrite_local(bc))
39-
end
23+
function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}})
24+
T = Base.Broadcast.combine_eltypes(bc.f, bc.args)
25+
shape = Base.Broadcast.combine_axes(bc.args...)
26+
iter = Base.CartesianIndices(shape)
27+
D = DArray(map(length, shape)) do I
28+
A = map(bc.args) do a
29+
if isa(a, Union{Number,Ref})
30+
return a
31+
else
32+
return localtype(a)(
33+
a[ntuple(i -> i > ndims(a) ? 1 : (size(a, i) == 1 ? (1:1) : I[i]), length(shape))...]
34+
)
35+
end
36+
end
37+
broadcast(bc.f, A...)
4038
end
41-
dest
39+
return D
4240
end
4341

44-
"""
45-
Transform a Broadcasted{Broadcast.ArrayStyle{DArray}} object into an equivalent
46-
Broadcasted{Broadcast.DefaultArrayStyle} object for the localparts.
47-
"""
48-
rewrite_local(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{DArray}}) = Broadcast.broadcasted(bc.f, rewrite_local(bc.args)...)
49-
rewrite_local(args::Tuple) = map(rewrite_local, args)
50-
rewrite_local(a::DArray) = localpart(a)
51-
rewrite_local(x) = x
52-
53-
5442
function Base.reduce(f, d::DArray)
5543
results = asyncmap(procs(d)) do p
5644
remotecall_fetch(p) do

test/darray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,10 @@ check_leaks()
835835
c = a .- m
836836
d = convert(Array, a) .- convert(Array, m)
837837
@test c == d
838+
e = @DArray [ones(10) for i=1:4]
839+
f = 2 .* e
840+
@test Array(f) == 2 .* Array(e)
841+
@test Array(map(x -> sum(x) .+ 2, e)) == map(x -> sum(x) .+ 2, e)
838842
d_closeall()
839843
end
840844

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ end
2020
@everywhere using Random
2121
@everywhere using LinearAlgebra
2222

23-
@everywhere srand(1234 + myid())
23+
@everywhere Random.seed!(1234 + myid())
2424

2525
const MYID = myid()
2626
const OTHERIDS = filter(id-> id != MYID, procs())[rand(1:(nprocs()-1))]

0 commit comments

Comments
 (0)