Skip to content

Commit 8dfd805

Browse files
authored
mapreduce: don't use broadcast when only dealing with a single arg. (#564)
1 parent bafce58 commit 8dfd805

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

src/host/broadcast.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,8 @@ end
9595
allequal(x) = true
9696
allequal(x, y, z...) = x == y && allequal(y, z...)
9797

98-
function Base.map(f, x::AnyGPUArray, xs::AbstractArray...)
98+
function Base.map(f, xs::AnyGPUArray...)
9999
# if argument sizes match, their shape needs to be preserved
100-
xs = (x, xs...)
101100
if allequal(size.(xs)...)
102101
return f.(xs...)
103102
end
@@ -112,12 +111,12 @@ function Base.map(f, x::AnyGPUArray, xs::AbstractArray...)
112111
# see `broadcast`
113112
ElType = BrokenBroadcast{ElType}
114113
end
115-
dest = similar(x, ElType, common_length)
114+
dest = similar(first(xs), ElType, common_length)
116115

117116
return map!(f, dest, xs...)
118117
end
119118

120-
function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
119+
function Base.map!(f, dest::AnyGPUArray, xs::AnyGPUArray...)
121120
# custom broadcast, ignoring the container size mismatches
122121
# (avoids the reshape + view that our mapreduce impl has to do)
123122
indices = LinearIndices.((dest, xs...))

src/host/mapreduce.jl

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,6 @@ Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::Abs
3131
dims=:, init=nothing) = _mapreduce(f, op, A, As...; dims=dims, init=init)
3232

3333
function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,N,D}
34-
# mapreduce should apply `f` like `map` does, consuming elements like iterators
35-
bc = if allequal(size.(As)...)
36-
Broadcast.instantiate(Broadcast.broadcasted(f, As...))
37-
else
38-
# TODO: can we avoid the reshape + view?
39-
indices = LinearIndices.(As)
40-
common_length = minimum(length.(indices))
41-
Bs = map(As) do A
42-
view(reshape(A, length(A)), 1:common_length)
43-
end
44-
Broadcast.instantiate(Broadcast.broadcasted(f, Bs...))
45-
end
46-
4734
# figure out the destination container type by looking at the initializer element,
4835
# or by relying on inference to reason through the map and reduce functions
4936
if init === nothing
@@ -57,16 +44,39 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
5744
ET = typeof(init)
5845
end
5946

60-
sz = size(bc)
47+
# apply the mapping function to the input arrays
48+
if N == 1
49+
# ... with only a single input, we can defer this to the reduce step
50+
A = only(As)
51+
else
52+
# mapreduce should apply `f` like `map` does, consuming elements like iterators
53+
A = if allequal(size.(As)...)
54+
Broadcast.instantiate(Broadcast.broadcasted(f, As...))
55+
else
56+
# TODO: can we avoid the reshape + view?
57+
indices = LinearIndices.(As)
58+
common_length = minimum(length.(indices))
59+
Bs = map(As) do A
60+
view(reshape(A, length(A)), 1:common_length)
61+
end
62+
Broadcast.instantiate(Broadcast.broadcasted(f, Bs...))
63+
end
64+
f = identity
65+
end
66+
67+
# allocate an output container
68+
sz = size(A)
6169
red = ntuple(i->(dims==Colon() || i in dims) ? 1 : sz[i], length(sz))
62-
R = similar(bc, ET, red)
70+
R = similar(A, ET, red)
6371

72+
# perform the reduction
6473
if prod(sz) == 0
6574
fill!(R, init)
6675
else
67-
mapreducedim!(identity, op, R, bc; init=init)
76+
mapreducedim!(f, op, R, A; init)
6877
end
6978

79+
# return the result
7080
if dims === Colon()
7181
@allowscalar R[]
7282
else

0 commit comments

Comments
 (0)