Skip to content

Commit 937b954

Browse files
committed
simplify sum, prod behaviour
1 parent df2820a commit 937b954

File tree

1 file changed

+40
-33
lines changed

1 file changed

+40
-33
lines changed

base/reduce.jl

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,14 @@ else
1313
end
1414

1515
# Certain reductions like sum and prod may wish to promote the items being reduced over to
16-
# an appropriate size. Note we need x + zero(x) because some types like Bool have their sum
17-
# lie in a larger type.
18-
promote_sys_size(T::Type) = T
19-
promote_sys_size(::Type{<:SmallSigned}) = Int
20-
promote_sys_size(::Type{<:SmallUnsigned}) = UInt
16+
# a wider type.
17+
add_tosys(x,y) = x + y
18+
add_tosys(x::SmallSigned,y::SmallSigned) = Int(x) + Int(y)
19+
add_tosys(x::SmallUnsigned,y::SmallUnsigned) = UInt(x) + UInt(y)
2120

22-
promote_sys_size_add(x) = convert(promote_sys_size(typeof(x + zero(x))), x)
23-
promote_sys_size_mul(x) = convert(promote_sys_size(typeof(x * one(x))), x)
24-
const _PromoteSysSizeFunction = Union{typeof(promote_sys_size_add),
25-
typeof(promote_sys_size_mul)}
21+
mul_tosys(x,y) = x * y
22+
mul_tosys(x::SmallSigned,y::SmallSigned) = Int(x) * Int(y)
23+
mul_tosys(x::SmallUnsigned,y::SmallUnsigned) = UInt(x) * UInt(y)
2624

2725
## foldl && mapfoldl
2826

@@ -253,6 +251,12 @@ reduce_empty(::typeof(*), T) = one(T)
253251
reduce_empty(::typeof(&), ::Type{Bool}) = true
254252
reduce_empty(::typeof(|), ::Type{Bool}) = false
255253

254+
reduce_empty(::typeof(add_tosys), T) = zero(T)
255+
reduce_empty(::typeof(add_tosys), ::Type{T}) where {T<:SmallSigned} = zero(Int)
256+
reduce_empty(::typeof(add_tosys), ::Type{T}) where {T<:SmallUnsigned} = zero(UInt)
257+
reduce_empty(::typeof(mul_tosys), T) = one(T)
258+
reduce_empty(::typeof(mul_tosys), ::Type{T}) where {T<:SmallSigned} = one(Int)
259+
reduce_empty(::typeof(mul_tosys), ::Type{T}) where {T<:SmallUnsigned} = one(UInt)
256260

257261
"""
258262
Base.mapreduce_empty(f, op, T)
@@ -263,39 +267,42 @@ array with element type of `T`.
263267
If not defined, this will throw an `ArgumentError`.
264268
"""
265269
mapreduce_empty(f, op, T) = _empty_reduce_error()
266-
mapreduce_empty(::typeof(identity), op, T) = reduce_empty(op, T)
267-
mapreduce_empty(f::_PromoteSysSizeFunction, op, T) =
268-
f(mapreduce_empty(identity, op, T))
269-
mapreduce_empty(::typeof(abs), ::typeof(+), T) = abs(zero(T))
270-
mapreduce_empty(::typeof(abs2), ::typeof(+), T) = abs2(zero(T))
271-
mapreduce_empty(::typeof(abs), ::Union{typeof(scalarmax), typeof(max)}, T) =
272-
abs(zero(T))
273-
mapreduce_empty(::typeof(abs2), ::Union{typeof(scalarmax), typeof(max)}, T) =
274-
abs2(zero(T))
275-
276-
# Allow mapreduce_empty to “see through” promote_sys_size
277-
let ComposedFunction = typename(typeof(identity identity)).wrapper
278-
global mapreduce_empty(f::ComposedFunction{<:_PromoteSysSizeFunction}, op, T) =
279-
f.f(mapreduce_empty(f.g, op, T))
280-
end
270+
mapreduce_empty(f::typeof(identity), op, T) = f(reduce_empty(op, T))
271+
mapreduce_empty(f::typeof(abs), op, T) = f(reduce_empty(op, T))
272+
mapreduce_empty(f::typeof(abs2), op, T) = f(reduce_empty(op, T))
273+
274+
mapreduce_empty(f::typeof(abs), ::Union{typeof(scalarmax), typeof(max)}, T) = abs(zero(T))
275+
mapreduce_empty(f::typeof(abs2), ::Union{typeof(scalarmax), typeof(max)}, T) = abs2(zero(T))
281276

282277
mapreduce_empty_iter(f, op, itr, ::HasEltype) = mapreduce_empty(f, op, eltype(itr))
283278
mapreduce_empty_iter(f, op::typeof(&), itr, ::EltypeUnknown) = true
284279
mapreduce_empty_iter(f, op::typeof(|), itr, ::EltypeUnknown) = false
285280
mapreduce_empty_iter(f, op, itr, ::EltypeUnknown) = _empty_reduce_error()
286281

287282
# handling of single-element iterators
283+
"""
284+
Base.reduce_single(f, op, x)
285+
286+
The value to be returned when calling [`reduce`] with `op` over an iterator which contains
287+
a single element `x`.
288+
289+
The default is `x`.
290+
"""
291+
reduce_single(op, x) = x
292+
reduce_single(::typeof(add_tosys), x::SmallSigned) = Int(x)
293+
reduce_single(::typeof(add_tosys), x::SmallUnsigned) = UInt(x)
294+
reduce_single(::typeof(mul_tosys), x::SmallSigned) = Int(x)
295+
reduce_single(::typeof(mul_tosys), x::SmallUnsigned) = UInt(x)
296+
288297
"""
289298
Base.mapreduce_single(f, op, x)
290299
291300
The value to be returned when calling [`mapreduce`] with `f` and `op` over an iterator
292301
which contains a single element `x`.
293302
294-
The default is `f(x)`.
303+
The default is `f(reduce_single(op, x))`.
295304
"""
296-
mapreduce_single(f, op, x) = f(x)
297-
298-
305+
mapreduce_single(f, op, x) = f(reduce_single(op, x))
299306

300307
_mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, IndexStyle(A), A)
301308

@@ -325,7 +332,7 @@ end
325332
_mapreduce(f, op, ::IndexCartesian, A::AbstractArray) = mapfoldl(f, op, A)
326333

327334
mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, IndexStyle(A), A)
328-
mapreduce(f, op, a::Number) = f(a)
335+
mapreduce(f, op, a::Number) = mapreduce_single(f, op, a)
329336

330337
"""
331338
reduce(op, v0, itr)
@@ -403,7 +410,7 @@ In the former case, the integers are widened to system word size and therefore
403410
the result is 128. In the latter case, no such widening happens and integer
404411
overflow results in -128.
405412
"""
406-
sum(f::Callable, a) = mapreduce(promote_sys_size_add f, +, a)
413+
sum(f, a) = mapreduce(f, add_tosys, a)
407414

408415
"""
409416
sum(itr)
@@ -419,7 +426,7 @@ julia> sum(1:20)
419426
210
420427
```
421428
"""
422-
sum(a) = mapreduce(promote_sys_size_add, +, a)
429+
sum(a) = sum(identity, a)
423430
sum(a::AbstractArray{Bool}) = count(a)
424431

425432
## prod
@@ -437,7 +444,7 @@ julia> prod(abs2, [2; 3; 4])
437444
576
438445
```
439446
"""
440-
prod(f::Callable, a) = mapreduce(promote_sys_size_mul f, *, a)
447+
prod(f::Callable, a) = mapreduce(f, mul_tosys, a)
441448

442449
"""
443450
prod(itr)
@@ -453,7 +460,7 @@ julia> prod(1:20)
453460
2432902008176640000
454461
```
455462
"""
456-
prod(a) = mapreduce(promote_sys_size_mul, *, a)
463+
prod(a) = mapreduce(identity, mul_tosys, a)
457464

458465
## maximum & minimum
459466

0 commit comments

Comments
 (0)