@@ -31,19 +31,6 @@ Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::Abs
31
31
dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
32
32
33
33
function _mapreduce (f:: F , op:: OP , As:: Vararg{Any,N} ; dims:: D , init) where {F,OP,N,D}
34
- # mapreduce should apply `f` like `map` does, consuming elements like iterators
35
- bc = if allequal (size .(As)... )
36
- Broadcast. instantiate (Broadcast. broadcasted (f, As... ))
37
- else
38
- # TODO : can we avoid the reshape + view?
39
- indices = LinearIndices .(As)
40
- common_length = minimum (length .(indices))
41
- Bs = map (As) do A
42
- view (reshape (A, length (A)), 1 : common_length)
43
- end
44
- Broadcast. instantiate (Broadcast. broadcasted (f, Bs... ))
45
- end
46
-
47
34
# figure out the destination container type by looking at the initializer element,
48
35
# or by relying on inference to reason through the map and reduce functions
49
36
if init === nothing
@@ -57,16 +44,39 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
57
44
ET = typeof (init)
58
45
end
59
46
60
- sz = size (bc)
47
+ # apply the mapping function to the input arrays
48
+ if N == 1
49
+ # ... with only a single input, we can defer this to the reduce step
50
+ A = only (As)
51
+ else
52
+ # mapreduce should apply `f` like `map` does, consuming elements like iterators
53
+ A = if allequal (size .(As)... )
54
+ Broadcast. instantiate (Broadcast. broadcasted (f, As... ))
55
+ else
56
+ # TODO : can we avoid the reshape + view?
57
+ indices = LinearIndices .(As)
58
+ common_length = minimum (length .(indices))
59
+ Bs = map (As) do A
60
+ view (reshape (A, length (A)), 1 : common_length)
61
+ end
62
+ Broadcast. instantiate (Broadcast. broadcasted (f, Bs... ))
63
+ end
64
+ f = identity
65
+ end
66
+
67
+ # allocate an output container
68
+ sz = size (A)
61
69
red = ntuple (i-> (dims== Colon () || i in dims) ? 1 : sz[i], length (sz))
62
- R = similar (bc , ET, red)
70
+ R = similar (A , ET, red)
63
71
72
+ # perform the reduction
64
73
if prod (sz) == 0
65
74
fill! (R, init)
66
75
else
67
- mapreducedim! (identity , op, R, bc; init = init)
76
+ mapreducedim! (f , op, R, A; init)
68
77
end
69
78
79
+ # return the result
70
80
if dims === Colon ()
71
81
@allowscalar R[]
72
82
else
0 commit comments