Skip to content

Commit 6ff3631

Browse files
committed
Decouple sum/prod promotion from reduce
1 parent 8cba5cd commit 6ff3631

File tree

4 files changed

+75
-66
lines changed

4 files changed

+75
-66
lines changed

base/reduce.jl

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,29 @@
55
###### Generic (map)reduce functions ######
66

77
if Int === Int32
8-
const SmallSigned = Union{Int8,Int16}
9-
const SmallUnsigned = Union{UInt8,UInt16}
8+
const SmallSigned = Union{Int8,Int16}
9+
const SmallUnsigned = Union{UInt8,UInt16}
1010
else
11-
const SmallSigned = Union{Int8,Int16,Int32}
12-
const SmallUnsigned = Union{UInt8,UInt16,UInt32}
11+
const SmallSigned = Union{Int8,Int16,Int32}
12+
const SmallUnsigned = Union{UInt8,UInt16,UInt32}
1313
end
1414

15-
const CommonReduceResult = Union{UInt64,UInt128,Int64,Int128,Float16,Float32,Float64}
16-
const WidenReduceResult = Union{SmallSigned, SmallUnsigned}
17-
18-
promote_sys_size{T}(::Type{T}) = T
19-
promote_sys_size{T<:SmallSigned}(::Type{T}) = Int
20-
promote_sys_size{T<:SmallUnsigned}(::Type{T}) = UInt
21-
# r_promote_type: promote T to the type of reduce(op, ::Array{T})
22-
# (some "extra" methods are required here to avoid ambiguity warnings)
23-
r_promote_type(op, ::Type{T}) where {T} = T
24-
r_promote_type(op, ::Type{T}) where {T<:WidenReduceResult} = promote_sys_size(T)
25-
r_promote_type(::typeof(+), ::Type{T}) where {T<:WidenReduceResult} = promote_sys_size(T)
26-
r_promote_type(::typeof(*), ::Type{T}) where {T<:WidenReduceResult} = promote_sys_size(T)
27-
r_promote_type(::typeof(+), ::Type{T}) where {T<:Number} = typeof(zero(T)+zero(T))
28-
r_promote_type(::typeof(*), ::Type{T}) where {T<:Number} = typeof(one(T)*one(T))
29-
r_promote_type(::typeof(scalarmax), ::Type{T}) where {T<:WidenReduceResult} = T
30-
r_promote_type(::typeof(scalarmin), ::Type{T}) where {T<:WidenReduceResult} = T
31-
r_promote_type(::typeof(max), ::Type{T}) where {T<:WidenReduceResult} = T
32-
r_promote_type(::typeof(min), ::Type{T}) where {T<:WidenReduceResult} = T
33-
34-
# r_promote: promote x to the type of reduce(op, [x])
35-
r_promote(op, x::T) where {T} = convert(r_promote_type(op, T), x)
15+
# Certain reductions like sum and prod may wish to promote the items being
16+
# reduced over to an appropriate size.
17+
promote_sys_size(x) = x
18+
promote_sys_size(x::SmallSigned) = Int(x)
19+
promote_sys_size(x::SmallUnsigned) = UInt(x)
3620

3721
## foldl && mapfoldl
3822

3923
@noinline function mapfoldl_impl(f, op, v0, itr, i)
4024
# Unroll the while loop once; if v0 is known, the call to op may
4125
# be evaluated at compile time
4226
if done(itr, i)
43-
return r_promote(op, v0)
27+
return v0
4428
else
4529
(x, i) = next(itr, i)
46-
v = op(r_promote(op, v0), f(x))
30+
v = op(v0, f(x))
4731
while !done(itr, i)
4832
@inbounds (x, i) = next(itr, i)
4933
v = op(v, f(x))
@@ -108,10 +92,10 @@ function mapfoldr_impl(f, op, v0, itr, i::Integer)
10892
# Unroll the while loop once; if v0 is known, the call to op may
10993
# be evaluated at compile time
11094
if isempty(itr) || i == 0
111-
return r_promote(op, v0)
95+
return v0
11296
else
11397
x = itr[i]
114-
v = op(f(x), r_promote(op, v0))
98+
v = op(f(x), v0)
11599
while i > 1
116100
x = itr[i -= 1]
117101
v = op(f(x), v)
@@ -180,12 +164,12 @@ foldr(op, itr) = mapfoldr(identity, op, itr)
180164
@noinline function mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer, blksize::Int)
181165
if ifirst == ilast
182166
@inbounds a1 = A[ifirst]
183-
return r_promote(op, f(a1))
167+
return f(a1)
184168
elseif ifirst + blksize > ilast
185169
# sequential portion
186170
@inbounds a1 = A[ifirst]
187171
@inbounds a2 = A[ifirst+1]
188-
v = op(r_promote(op, f(a1)), r_promote(op, f(a2)))
172+
v = op(f(a1), f(a2))
189173
@simd for i = ifirst + 2 : ilast
190174
@inbounds ai = A[i]
191175
v = op(v, f(ai))
@@ -245,13 +229,14 @@ pairwise_blocksize(::typeof(abs2), ::typeof(+)) = 4096
245229
# handling empty arrays
246230
_empty_reduce_error() = throw(ArgumentError("reducing over an empty collection is not allowed"))
247231
mr_empty(f, op, T) = _empty_reduce_error()
248-
# use zero(T)::T to improve type information when zero(T) is not defined
249-
mr_empty(::typeof(identity), op::typeof(+), T) = r_promote(op, zero(T)::T)
250-
mr_empty(::typeof(abs), op::typeof(+), T) = r_promote(op, abs(zero(T)::T))
251-
mr_empty(::typeof(abs2), op::typeof(+), T) = r_promote(op, abs2(zero(T)::T))
252-
mr_empty(::typeof(identity), op::typeof(*), T) = r_promote(op, one(T)::T)
253-
mr_empty(::typeof(abs), op::typeof(scalarmax), T) = abs(zero(T)::T)
254-
mr_empty(::typeof(abs2), op::typeof(scalarmax), T) = abs2(zero(T)::T)
232+
mr_empty(::typeof(identity), op::typeof(+), T) = zero(T)
233+
mr_empty(::typeof(abs), op::typeof(+), T) = abs(zero(T))
234+
mr_empty(::typeof(abs2), op::typeof(+), T) = abs2(zero(T))
235+
mr_empty(::typeof(identity), op::typeof(*), T) = one(T)
236+
mr_empty(::typeof(promote_sys_size), op, T) =
237+
promote_sys_size(mr_empty(identity, op, T))
238+
mr_empty(::typeof(abs), op::typeof(scalarmax), T) = abs(zero(T))
239+
mr_empty(::typeof(abs2), op::typeof(scalarmax), T) = abs2(zero(T))
255240
mr_empty(::typeof(abs), op::typeof(max), T) = mr_empty(abs, scalarmax, T)
256241
mr_empty(::typeof(abs2), op::typeof(max), T) = mr_empty(abs2, scalarmax, T)
257242
mr_empty(f, op::typeof(&), T) = true
@@ -271,12 +256,12 @@ function _mapreduce(f, op, ::IndexLinear, A::AbstractArray{T}) where T
271256
return mr_empty(f, op, T)
272257
elseif n == 1
273258
@inbounds a1 = A[inds[1]]
274-
return r_promote(op, f(a1))
259+
return f(a1)
275260
elseif n < 16 # process short array here, avoid mapreduce_impl() compilation
276261
@inbounds i = inds[1]
277262
@inbounds a1 = A[i]
278263
@inbounds a2 = A[i+=1]
279-
s = op(r_promote(op, f(a1)), r_promote(op, f(a2)))
264+
s = op(f(a1), f(a2))
280265
while i < last(inds)
281266
@inbounds Ai = A[i+=1]
282267
s = op(s, f(Ai))
@@ -353,7 +338,7 @@ julia> sum(abs2, [2; 3; 4])
353338
29
354339
```
355340
"""
356-
sum(f::Callable, a) = mapreduce(f, +, a)
341+
sum(f::Callable, a) = mapreduce(promote_sys_size f, +, a)
357342

358343
"""
359344
sum(itr)
@@ -365,7 +350,7 @@ julia> sum(1:20)
365350
210
366351
```
367352
"""
368-
sum(a) = mapreduce(identity, +, a)
353+
sum(a) = mapreduce(promote_sys_size, +, a)
369354
sum(a::AbstractArray{Bool}) = countnz(a)
370355

371356

@@ -380,7 +365,7 @@ summation algorithm for additional accuracy.
380365
"""
381366
function sum_kbn(A)
382367
T = _default_eltype(typeof(A))
383-
c = r_promote(+, zero(T)::T)
368+
c = promote_sys_size(zero(T)::T)
384369
i = start(A)
385370
if done(A, i)
386371
return c
@@ -411,7 +396,7 @@ julia> prod(abs2, [2; 3; 4])
411396
576
412397
```
413398
"""
414-
prod(f::Callable, a) = mapreduce(f, *, a)
399+
prod(f::Callable, a) = mapreduce(promote_sys_size f, *, a)
415400

416401
"""
417402
prod(itr)
@@ -423,7 +408,7 @@ julia> prod(1:20)
423408
2432902008176640000
424409
```
425410
"""
426-
prod(a) = mapreduce(identity, *, a)
411+
prod(a) = mapreduce(promote_sys_size, *, a)
427412

428413
## maximum & minimum
429414

base/reducedim.jl

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,19 +116,22 @@ reducedim_init(f, op::typeof(|), A::AbstractArray, region) = reducedim_initarray
116116

117117
# specialize to make initialization more efficient for common cases
118118

119-
for (IT, RT) in ((CommonReduceResult, :(eltype(A))), (SmallSigned, :Int), (SmallUnsigned, :UInt))
120-
T = Union{[AbstractArray{t} for t in uniontypes(IT)]..., [AbstractArray{Complex{t}} for t in uniontypes(IT)]...}
121-
@eval begin
122-
reducedim_init(f::typeof(identity), op::typeof(+), A::$T, region) =
123-
reducedim_initarray(A, region, zero($RT))
124-
reducedim_init(f::typeof(identity), op::typeof(*), A::$T, region) =
125-
reducedim_initarray(A, region, one($RT))
126-
reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(+), A::$T, region) =
127-
reducedim_initarray(A, region, real(zero($RT)))
128-
reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(*), A::$T, region) =
129-
reducedim_initarray(A, region, real(one($RT)))
130-
end
119+
let
120+
const BitIntFloat = Union{BitInteger, Math.IEEEFloat}
121+
const T = Union{
122+
[AbstractArray{t} for t in uniontypes(BitIntFloat)]...,
123+
[AbstractArray{Complex{t}} for t in uniontypes(BitIntFloat)]...}
124+
125+
global reducedim_init(f::typeof(identity), op::typeof(+), A::T, region) =
126+
reducedim_initarray(A, region, zero(eltype(A)))
127+
global reducedim_init(f::typeof(identity), op::typeof(*), A::T, region) =
128+
reducedim_initarray(A, region, one(eltype(A)))
129+
global reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(+), A::T, region) =
130+
reducedim_initarray(A, region, real(zero(eltype(A))))
131+
global reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(*), A::T, region) =
132+
reducedim_initarray(A, region, real(one(eltype(A))))
131133
end
134+
132135
reducedim_init(f::Union{typeof(identity),typeof(abs),typeof(abs2)}, op::typeof(+), A::AbstractArray{Bool}, region) =
133136
reducedim_initarray(A, region, 0)
134137

@@ -574,14 +577,22 @@ any!(r, A)
574577
for (fname, op) in [(:sum, :+), (:prod, :*),
575578
(:maximum, :scalarmax), (:minimum, :scalarmin),
576579
(:all, :&), (:any, :|)]
580+
function compose_pss(x)
581+
if fname in [:sum, :prod]
582+
:(promote_sys_size $x)
583+
else
584+
x
585+
end
586+
end
577587
fname! = Symbol(fname, '!')
578588
@eval begin
589+
# TODO: check this
579590
$(fname!)(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) =
580591
mapreducedim!(f, $(op), initarray!(r, $(op), init), A)
581592
$(fname!)(r::AbstractArray, A::AbstractArray; init::Bool=true) = $(fname!)(identity, r, A; init=init)
582593

583594
$(fname)(f::Function, A::AbstractArray, region) =
584-
mapreducedim(f, $(op), A, region)
595+
mapreducedim($(compose_pss(:f)), $(op), A, region)
585596
$(fname)(A::AbstractArray, region) = $(fname)(identity, A, region)
586597
end
587598
end

base/tuple.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,13 @@ reverse(t::Tuple) = revargs(t...)
294294

295295
# TODO: these definitions cannot yet be combined, since +(x...)
296296
# where x might be any tuple matches too many methods.
297+
# TODO: this is inconsistent with the regular sum in cases where the arguments
298+
# require size promotion to system size.
297299
sum(x::Tuple{Any, Vararg{Any}}) = +(x...)
298300

299301
# NOTE: should remove, but often used on array sizes
302+
# TODO: this is inconsistent with the regular prod in cases where the arguments
303+
# require size promotion to system size.
300304
prod(x::Tuple{}) = 1
301305
prod(x::Tuple{Any, Vararg{Any}}) = *(x...)
302306

test/reduce.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# fold(l|r) & mapfold(l|r)
44
@test foldl(+, Int64[]) === Int64(0) # In reference to issues #7465/#20144 (PR #20160)
5-
@test foldl(+, Int16[]) === Int(0)
5+
@test foldl(+, Int16[]) === Int16(0) # In reference to issues #21536
66
@test foldl(-, 1:5) == -13
77
@test foldl(-, 10, 1:5) == -5
88

@@ -19,7 +19,7 @@
1919
@test Base.mapfoldl((x)-> x true, |, false, [true false true false false]) == true
2020

2121
@test foldr(+, Int64[]) === Int64(0) # In reference to issue #20144 (PR #20160)
22-
@test foldr(+, Int16[]) === Int(0)
22+
@test foldr(+, Int16[]) === Int16(0) # In reference to issues #21536
2323
@test foldr(-, 1:5) == 3
2424
@test foldr(-, 10, 1:5) == -7
2525
@test foldr(+, [1]) == 1 # Issue #21493
@@ -29,7 +29,7 @@
2929

3030
# reduce
3131
@test reduce(+, Int64[]) === Int64(0) # In reference to issue #20144 (PR #20160)
32-
@test reduce(+, Int16[]) === Int(0)
32+
@test reduce(+, Int16[]) === Int16(0) # In reference to issues #21536
3333
@test reduce((x,y)->"($x+$y)", 9:11) == "((9+10)+11)"
3434
@test reduce(max, [8 6 7 5 3 0 9]) == 9
3535
@test reduce(+, 1000, 1:5) == (1000 + 1 + 2 + 3 + 4 + 5)
@@ -69,12 +69,21 @@
6969
typeof(mapreduce(abs, +, Float32[10, 11, 12, 13]))
7070

7171
# sum
72+
@testset "sums promote to at least machine size" begin
73+
@testset for T in [Int8, Int16, Int32]
74+
@test sum(T[]) === Int(0)
75+
end
76+
@testset for T in [UInt8, UInt16, UInt32]
77+
@test sum(T[]) === UInt(0)
78+
end
79+
@testset for T in [Int, Int64, Int128, UInt, UInt64, UInt128,
80+
Float16, Float32, Float64]
81+
@test sum(T[]) === T(0)
82+
end
83+
@test sum(BigInt[]) == big(0) && sum(BigInt[]) isa BigInt
84+
end
7285

73-
@test sum(Int8[]) === Int(0)
74-
@test sum(Int[]) === Int(0)
75-
@test sum(Float64[]) === 0.0
76-
77-
@test sum(Int8(3)) === Int8(3)
86+
@test sum(Int8(3)) === Int(3)
7887
@test sum(3) === 3
7988
@test sum(3.0) === 3.0
8089

0 commit comments

Comments
 (0)