Skip to content

Commit 4c19973

Browse files
committed
Add mapreduce_single function
Since the demise of `r_promote` in #22825, there is now a type-instability in `mapreduce` if the operator does not give an element of the same type as the input. This arose during my implementation of Kahan summation using a reduction operator, see: JuliaMath/KahanSummation.jl#7 This adds a `mapreduce_single` function which defines what the result should be in these cases.
1 parent 2043060 commit 4c19973

File tree

4 files changed

+117
-64
lines changed

4 files changed

+117
-64
lines changed

base/reduce.jl

Lines changed: 97 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,25 @@ else
1212
const SmallUnsigned = Union{UInt8,UInt16,UInt32}
1313
end
1414

15-
# Certain reductions like sum and prod may wish to promote the items being reduced over to
16-
# an appropriate size. Note we need x + zero(x) because some types like Bool have their sum
17-
# lie in a larger type.
18-
promote_sys_size(T::Type) = T
19-
promote_sys_size(::Type{<:SmallSigned}) = Int
20-
promote_sys_size(::Type{<:SmallUnsigned}) = UInt
21-
22-
promote_sys_size_add(x) = convert(promote_sys_size(typeof(x + zero(x))), x)
23-
promote_sys_size_mul(x) = convert(promote_sys_size(typeof(x * one(x))), x)
24-
const _PromoteSysSizeFunction = Union{typeof(promote_sys_size_add),
25-
typeof(promote_sys_size_mul)}
15+
"""
16+
Base.add_sum(x,y)
17+
18+
The reduction operator used in `sum`. The main difference from [`+`](@ref) is that small
19+
integers are promoted to `Int`/`UInt`.
20+
"""
21+
add_sum(x,y) = x + y
22+
add_sum(x::SmallSigned,y::SmallSigned) = Int(x) + Int(y)
23+
add_sum(x::SmallUnsigned,y::SmallUnsigned) = UInt(x) + UInt(y)
24+
25+
"""
26+
Base.mul_prod(x,y)
27+
28+
The reduction operator used in `prod`. The main difference from [`*`](@ref) is that small
29+
integers are promoted to `Int`/`UInt`.
30+
"""
31+
mul_prod(x,y) = x * y
32+
mul_prod(x::SmallSigned,y::SmallSigned) = Int(x) * Int(y)
33+
mul_prod(x::SmallUnsigned,y::SmallUnsigned) = UInt(x) * UInt(y)
2634

2735
## foldl && mapfoldl
2836

@@ -64,7 +72,7 @@ function mapfoldl(f, op, itr)
6472
return Base.mapreduce_empty_iter(f, op, itr, iteratoreltype(itr))
6573
end
6674
(x, i) = next(itr, i)
67-
v0 = f(x)
75+
v0 = mapreduce_first(f, op, x)
6876
mapfoldl_impl(f, op, v0, itr, i)
6977
end
7078

@@ -133,7 +141,7 @@ function mapfoldr(f, op, itr)
133141
if isempty(itr)
134142
return Base.mapreduce_empty_iter(f, op, itr, iteratoreltype(itr))
135143
end
136-
return mapfoldr_impl(f, op, f(itr[i]), itr, i-1)
144+
return mapfoldr_impl(f, op, mapreduce_first(f, op, itr[i]), itr, i-1)
137145
end
138146

139147
"""
@@ -174,7 +182,7 @@ foldr(op, itr) = mapfoldr(identity, op, itr)
174182
@noinline function mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer, blksize::Int)
175183
if ifirst == ilast
176184
@inbounds a1 = A[ifirst]
177-
return f(a1)
185+
return mapreduce_first(f, op, a1)
178186
elseif ifirst + blksize > ilast
179187
# sequential portion
180188
@inbounds a1 = A[ifirst]
@@ -238,34 +246,88 @@ pairwise_blocksize(::typeof(abs2), ::typeof(+)) = 4096
238246

239247
# handling empty arrays
240248
_empty_reduce_error() = throw(ArgumentError("reducing over an empty collection is not allowed"))
249+
250+
"""
251+
Base.reduce_empty(op, T)
252+
253+
The value to be returned when calling [`reduce`](@ref), [`foldl`](@ref) or [`foldr`](@ref)
254+
with reduction `op` over an empty array with element type of `T`.
255+
256+
If not defined, this will throw an `ArgumentError`.
257+
"""
241258
reduce_empty(op, T) = _empty_reduce_error()
242259
reduce_empty(::typeof(+), T) = zero(T)
260+
reduce_empty(::typeof(+), ::Type{Bool}) = zero(Int)
243261
reduce_empty(::typeof(*), T) = one(T)
262+
reduce_empty(::typeof(*), ::Type{Char}) = ""
244263
reduce_empty(::typeof(&), ::Type{Bool}) = true
245264
reduce_empty(::typeof(|), ::Type{Bool}) = false
246265

266+
reduce_empty(::typeof(add_sum), T) = reduce_empty(+, T)
267+
reduce_empty(::typeof(add_sum), ::Type{T}) where {T<:SmallSigned} = zero(Int)
268+
reduce_empty(::typeof(add_sum), ::Type{T}) where {T<:SmallUnsigned} = zero(UInt)
269+
reduce_empty(::typeof(mul_prod), T) = reduce_empty(*, T)
270+
reduce_empty(::typeof(mul_prod), ::Type{T}) where {T<:SmallSigned} = one(Int)
271+
reduce_empty(::typeof(mul_prod), ::Type{T}) where {T<:SmallUnsigned} = one(UInt)
272+
273+
"""
274+
Base.mapreduce_empty(f, op, T)
275+
276+
The value to be returned when calling [`mapreduce`](@ref), [`mapfoldl`](@ref) or
277+
[`mapfoldr`](@ref) with map `f` and reduction `op` over an empty array with element type
278+
of `T`.
279+
280+
If not defined, this will throw an `ArgumentError`.
281+
"""
247282
mapreduce_empty(f, op, T) = _empty_reduce_error()
248283
mapreduce_empty(::typeof(identity), op, T) = reduce_empty(op, T)
249-
mapreduce_empty(f::_PromoteSysSizeFunction, op, T) =
250-
f(mapreduce_empty(identity, op, T))
251-
mapreduce_empty(::typeof(abs), ::typeof(+), T) = abs(zero(T))
252-
mapreduce_empty(::typeof(abs2), ::typeof(+), T) = abs2(zero(T))
253-
mapreduce_empty(::typeof(abs), ::Union{typeof(scalarmax), typeof(max)}, T) =
254-
abs(zero(T))
255-
mapreduce_empty(::typeof(abs2), ::Union{typeof(scalarmax), typeof(max)}, T) =
256-
abs2(zero(T))
257-
258-
# Allow mapreduce_empty to “see through” promote_sys_size
259-
let ComposedFunction = typename(typeof(identity identity)).wrapper
260-
global mapreduce_empty(f::ComposedFunction{<:_PromoteSysSizeFunction}, op, T) =
261-
f.f(mapreduce_empty(f.g, op, T))
262-
end
284+
mapreduce_empty(::typeof(abs), op, T) = abs(reduce_empty(op, T))
285+
mapreduce_empty(::typeof(abs2), op, T) = abs2(reduce_empty(op, T))
286+
287+
mapreduce_empty(f::typeof(abs), ::Union{typeof(scalarmax), typeof(max)}, T) = abs(zero(T))
288+
mapreduce_empty(f::typeof(abs2), ::Union{typeof(scalarmax), typeof(max)}, T) = abs2(zero(T))
263289

264290
mapreduce_empty_iter(f, op, itr, ::HasEltype) = mapreduce_empty(f, op, eltype(itr))
265291
mapreduce_empty_iter(f, op::typeof(&), itr, ::EltypeUnknown) = true
266292
mapreduce_empty_iter(f, op::typeof(|), itr, ::EltypeUnknown) = false
267293
mapreduce_empty_iter(f, op, itr, ::EltypeUnknown) = _empty_reduce_error()
268294

295+
# handling of single-element iterators
296+
"""
297+
Base.reduce_first(op, x)
298+
299+
The value to be returned when calling [`reduce`](@ref), [`foldl`](@ref) or
300+
[`foldr`](@ref) with reduction `op` over an iterator which contains a single element
301+
`x`. This value may also be used to initialise the recursion, so that `reduce(op, [x, y])`
302+
may call `op(reduce_first(op, x), y)`.
303+
304+
The default is `x` for most types. The main purpose is to ensure type stability, so
305+
additional methods should only be defined for cases where `op` gives a result with
306+
different types than its inputs.
307+
"""
308+
reduce_first(op, x) = x
309+
reduce_first(::typeof(+), x::Bool) = Int(x)
310+
reduce_first(::typeof(*), x::Char) = string(x)
311+
312+
reduce_first(::typeof(add_sum), x) = reduce_first(+, x)
313+
reduce_first(::typeof(add_sum), x::SmallSigned) = Int(x)
314+
reduce_first(::typeof(add_sum), x::SmallUnsigned) = UInt(x)
315+
reduce_first(::typeof(mul_prod), x) = reduce_first(*, x)
316+
reduce_first(::typeof(mul_prod), x::SmallSigned) = Int(x)
317+
reduce_first(::typeof(mul_prod), x::SmallUnsigned) = UInt(x)
318+
319+
"""
320+
Base.mapreduce_first(f, op, x)
321+
322+
The value to be returned when calling [`mapreduce`](@ref), [`mapfoldl`](@ref) or
323+
[`mapfoldr`](@ref) with map `f` and reduction `op` over an iterator which contains a
324+
single element `x`. This value may also be used to initialise the recursion, so that
325+
`mapreduce(f, op, [x, y])` may call `op(mapreduce_first(f, op, x), f(y))`.
326+
327+
The default is `reduce_first(op, f(x))`.
328+
"""
329+
mapreduce_first(f, op, x) = reduce_first(op, f(x))
330+
269331
_mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, IndexStyle(A), A)
270332

271333
function _mapreduce(f, op, ::IndexLinear, A::AbstractArray{T}) where T
@@ -275,7 +337,7 @@ function _mapreduce(f, op, ::IndexLinear, A::AbstractArray{T}) where T
275337
return mapreduce_empty(f, op, T)
276338
elseif n == 1
277339
@inbounds a1 = A[inds[1]]
278-
return f(a1)
340+
return mapreduce_first(f, op, a1)
279341
elseif n < 16 # process short array here, avoid mapreduce_impl() compilation
280342
@inbounds i = inds[1]
281343
@inbounds a1 = A[i]
@@ -294,7 +356,7 @@ end
294356
_mapreduce(f, op, ::IndexCartesian, A::AbstractArray) = mapfoldl(f, op, A)
295357

296358
mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, IndexStyle(A), A)
297-
mapreduce(f, op, a::Number) = f(a)
359+
mapreduce(f, op, a::Number) = mapreduce_first(f, op, a)
298360

299361
"""
300362
reduce(op, v0, itr)
@@ -372,7 +434,7 @@ In the former case, the integers are widened to system word size and therefore
372434
the result is 128. In the latter case, no such widening happens and integer
373435
overflow results in -128.
374436
"""
375-
sum(f::Callable, a) = mapreduce(promote_sys_size_add f, +, a)
437+
sum(f, a) = mapreduce(f, add_sum, a)
376438

377439
"""
378440
sum(itr)
@@ -388,7 +450,7 @@ julia> sum(1:20)
388450
210
389451
```
390452
"""
391-
sum(a) = mapreduce(promote_sys_size_add, +, a)
453+
sum(a) = sum(identity, a)
392454
sum(a::AbstractArray{Bool}) = count(a)
393455

394456
## prod
@@ -406,7 +468,7 @@ julia> prod(abs2, [2; 3; 4])
406468
576
407469
```
408470
"""
409-
prod(f::Callable, a) = mapreduce(promote_sys_size_mul f, *, a)
471+
prod(f, a) = mapreduce(f, mul_prod, a)
410472

411473
"""
412474
prod(itr)
@@ -422,7 +484,7 @@ julia> prod(1:20)
422484
2432902008176640000
423485
```
424486
"""
425-
prod(a) = mapreduce(promote_sys_size_mul, *, a)
487+
prod(a) = mapreduce(identity, mul_prod, a)
426488

427489
## maximum & minimum
428490

@@ -433,7 +495,7 @@ function mapreduce_impl(f, op::Union{typeof(scalarmax),
433495
A::AbstractArray, first::Int, last::Int)
434496
# locate the first non NaN number
435497
@inbounds a1 = A[first]
436-
v = f(a1)
498+
v = mapreduce_first(f, op, a1)
437499
i = first + 1
438500
while (v == v) && (i <= last)
439501
@inbounds ai = A[i]

base/reducedim.jl

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ end
6262
###### Generic reduction functions #####
6363

6464
## initialization
65-
66-
for (Op, initfun) in ((:(typeof(+)), :zero), (:(typeof(*)), :one))
65+
# initarray! is only called by sum!, prod!, etc.
66+
for (Op, initfun) in ((:(typeof(add_sum)), :zero), (:(typeof(mul_prod)), :one))
6767
@eval initarray!(a::AbstractArray{T}, ::$(Op), init::Bool, src::AbstractArray) where {T} = (init && fill!(a, $(initfun)(T)); a)
6868
end
6969

@@ -75,6 +75,7 @@ for (Op, initval) in ((:(typeof(&)), true), (:(typeof(|)), false))
7575
@eval initarray!(a::AbstractArray, ::$(Op), init::Bool, src::AbstractArray) = (init && fill!(a, $initval); a)
7676
end
7777

78+
# reducedim_initarray is called by the reducedim_init methods to allocate the reduced-size output array and fill it with v0
7879
reducedim_initarray(A::AbstractArray, region, v0, ::Type{R}) where {R} = fill!(similar(A,R,reduced_indices(A,region)), v0)
7980
reducedim_initarray(A::AbstractArray, region, v0::T) where {T} = reducedim_initarray(A, region, v0, T)
8081

@@ -104,10 +105,10 @@ reducedim_initarray0_empty(A::AbstractArray, region,::typeof(identity), ops) = m
104105
promote_union(T::Union) = promote_type(promote_union(T.a), promote_union(T.b))
105106
promote_union(T) = T
106107

107-
function reducedim_init(f, op::typeof(+), A::AbstractArray, region)
108+
function reducedim_init(f, op::Union{typeof(+),typeof(add_sum)}, A::AbstractArray, region)
108109
_reducedim_init(f, op, zero, sum, A, region)
109110
end
110-
function reducedim_init(f, op::typeof(*), A::AbstractArray, region)
111+
function reducedim_init(f, op::Union{typeof(*),typeof(mul_prod)}, A::AbstractArray, region)
111112
_reducedim_init(f, op, one, prod, A, region)
112113
end
113114
function _reducedim_init(f, op, fv, fop, A, region)
@@ -143,19 +144,12 @@ let
143144
[AbstractArray{t} for t in uniontypes(BitIntFloat)]...,
144145
[AbstractArray{Complex{t}} for t in uniontypes(BitIntFloat)]...}
145146

146-
global reducedim_init(f::typeof(identity), op::typeof(+), A::T, region) =
147-
reducedim_initarray(A, region, zero(eltype(A)))
148-
global reducedim_init(f::typeof(identity), op::typeof(*), A::T, region) =
149-
reducedim_initarray(A, region, one(eltype(A)))
150-
global reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(+), A::T, region) =
151-
reducedim_initarray(A, region, real(zero(eltype(A))))
152-
global reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(*), A::T, region) =
153-
reducedim_initarray(A, region, real(one(eltype(A))))
147+
global reducedim_init(f, op::Union{typeof(+),typeof(add_sum)}, A::T, region) =
148+
reducedim_initarray(A, region, mapreduce_first(f, op, zero(eltype(A))))
149+
global reducedim_init(f, op::Union{typeof(*),typeof(mul_prod)}, A::T, region) =
150+
reducedim_initarray(A, region, mapreduce_first(f, op, one(eltype(A))))
154151
end
155152

156-
reducedim_init(f::Union{typeof(identity),typeof(abs),typeof(abs2)}, op::typeof(+), A::AbstractArray{Bool}, region) =
157-
reducedim_initarray(A, region, 0)
158-
159153
## generic (map)reduction
160154

161155
has_fast_linear_indexing(a::AbstractArray) = false
@@ -610,26 +604,17 @@ julia> any!([1 1], A)
610604
"""
611605
any!(r, A)
612606

613-
for (fname, op) in [(:sum, :+), (:prod, :*),
607+
for (fname, op) in [(:sum, :add_sum), (:prod, :mul_prod),
614608
(:maximum, :scalarmax), (:minimum, :scalarmin),
615609
(:all, :&), (:any, :|)]
616-
function compose_promote_sys_size(x)
617-
if fname === :sum
618-
:(promote_sys_size_add $x)
619-
elseif fname === :prod
620-
:(promote_sys_size_mul $x)
621-
else
622-
x
623-
end
624-
end
625610
fname! = Symbol(fname, '!')
626611
@eval begin
627612
$(fname!)(f::Function, r::AbstractArray, A::AbstractArray; init::Bool=true) =
628613
mapreducedim!(f, $(op), initarray!(r, $(op), init, A), A)
629614
$(fname!)(r::AbstractArray, A::AbstractArray; init::Bool=true) = $(fname!)(identity, r, A; init=init)
630615

631616
$(fname)(f::Function, A::AbstractArray, region) =
632-
mapreducedim($(compose_promote_sys_size(:f)), $(op), A, region)
617+
mapreducedim(f, $(op), A, region)
633618
$(fname)(A::AbstractArray, region) = $(fname)(identity, A, region)
634619
end
635620
end

test/reduce.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,8 @@ test18695(r) = sum( t^2 for t in r )
391391
# test neutral element not picked incorrectly for &, |
392392
@test @inferred(foldl(&, Int[1])) === 1
393393
@test_throws ArgumentError foldl(&, Int[])
394+
395+
# prod on Chars
396+
@test prod(Char[]) == ""
397+
@test prod(Char['a']) == "a"
398+
@test prod(Char['a','b']) == "ab"

test/reducedim.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,10 @@ for region in Any[-1, 0, (-1, 2), [0, 1], (1,-2,3), [0 1;
339339
end
340340

341341
# check type of result
342-
under_test = [UInt8, Int8, Int32, Int64, BigInt]
343-
@testset "type of sum(::Array{$T}" for T in under_test
342+
@testset "type of sum(::Array{$T}" for T in [UInt8, Int8, Int32, Int64, BigInt]
344343
result = sum(T[1 2 3; 4 5 6; 7 8 9], 2)
345344
@test result == hcat([6, 15, 24])
346-
@test eltype(result) === typeof(Base.promote_sys_size_add(zero(T)))
345+
@test eltype(result) === (T <: Base.SmallSigned ? Int :
346+
T <: Base.SmallUnsigned ? UInt :
347+
T)
347348
end

0 commit comments

Comments
 (0)