Skip to content

Commit 55a9bda

Browse files
committed
Modifications to work well with Bool
1 parent 49f2cb8 commit 55a9bda

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

base/reduce.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,17 @@ else
1212
const SmallUnsigned = Union{UInt8,UInt16,UInt32}
1313
end
1414

15-
# Certain reductions like sum and prod may wish to promote the items being
16-
# reduced over to an appropriate size.
17-
promote_sys_size(x) = x
18-
promote_sys_size(x::Union{Bool, SmallSigned}) = Int(x)
19-
promote_sys_size(x::SmallUnsigned) = UInt(x)
15+
# Certain reductions like sum and prod may wish to promote the items being reduced over to
16+
# an appropriate size. Note we need x + zero(x) because some types like Bool have their sum
17+
# lie in a larger type.
18+
promote_sys_size(T::Type) = T
19+
promote_sys_size(::Type{<:SmallSigned}) = Int
20+
promote_sys_size(::Type{<:SmallUnsigned}) = UInt
21+
22+
promote_sys_size_add(x) = convert(promote_sys_size(typeof(x + zero(x))), x)
23+
promote_sys_size_mul(x) = convert(promote_sys_size(typeof(x * one(x))), x)
24+
const _PromoteSysSizeFunction = Union{typeof(promote_sys_size_add),
25+
typeof(promote_sys_size_mul)}
2026

2127
## foldl && mapfoldl
2228

@@ -240,8 +246,8 @@ reduce_empty(::typeof(|), ::Type{Bool}) = false
240246

241247
mapreduce_empty(f, op, T) = _empty_reduce_error()
242248
mapreduce_empty(::typeof(identity), op, T) = reduce_empty(op, T)
243-
mapreduce_empty(::typeof(promote_sys_size), op, T) =
244-
promote_sys_size(mapreduce_empty(identity, op, T))
249+
mapreduce_empty(f::_PromoteSysSizeFunction, op, T) =
250+
f(mapreduce_empty(identity, op, T))
245251
mapreduce_empty(::typeof(abs), ::typeof(+), T) = abs(zero(T))
246252
mapreduce_empty(::typeof(abs2), ::typeof(+), T) = abs2(zero(T))
247253
mapreduce_empty(::typeof(abs), ::Union{typeof(scalarmax), typeof(max)}, T) =
@@ -251,8 +257,8 @@ mapreduce_empty(::typeof(abs2), ::Union{typeof(scalarmax), typeof(max)}, T) =
251257

252258
# Allow mapreduce_empty to “see through” promote_sys_size
253259
let ComposedFunction = typename(typeof(identity identity)).wrapper
254-
global mapreduce_empty(f::ComposedFunction{typeof(promote_sys_size)}, op, T) =
255-
promote_sys_size(mapreduce_empty(f.g, op, T))
260+
global mapreduce_empty(f::ComposedFunction{<:_PromoteSysSizeFunction}, op, T) =
261+
f.f(mapreduce_empty(f.g, op, T))
256262
end
257263

258264
mapreduce_empty_iter(f, op, itr, ::HasEltype) = mapreduce_empty(f, op, eltype(itr))
@@ -366,7 +372,7 @@ In the former case, the integers are widened to system word size and therefore
366372
the result is 128. In the latter case, no such widening happens and integer
367373
overflow results in -128.
368374
"""
369-
sum(f::Callable, a) = mapreduce(promote_sys_size f, +, a)
375+
sum(f::Callable, a) = mapreduce(promote_sys_size_add f, +, a)
370376

371377
"""
372378
sum(itr)
@@ -382,7 +388,7 @@ julia> sum(1:20)
382388
210
383389
```
384390
"""
385-
sum(a) = mapreduce(promote_sys_size, +, a)
391+
sum(a) = mapreduce(promote_sys_size_add, +, a)
386392
sum(a::AbstractArray{Bool}) = count(a)
387393

388394

@@ -397,7 +403,7 @@ summation algorithm for additional accuracy.
397403
"""
398404
function sum_kbn(A)
399405
T = _default_eltype(typeof(A))
400-
c = promote_sys_size(zero(T)::T)
406+
c = promote_sys_size_add(zero(T)::T)
401407
i = start(A)
402408
if done(A, i)
403409
return c
@@ -432,7 +438,7 @@ julia> prod(abs2, [2; 3; 4])
432438
576
433439
```
434440
"""
435-
prod(f::Callable, a) = mapreduce(promote_sys_size f, *, a)
441+
prod(f::Callable, a) = mapreduce(promote_sys_size_mul f, *, a)
436442

437443
"""
438444
prod(itr)
@@ -448,7 +454,7 @@ julia> prod(1:20)
448454
2432902008176640000
449455
```
450456
"""
451-
prod(a) = mapreduce(promote_sys_size, *, a)
457+
prod(a) = mapreduce(promote_sys_size_mul, *, a)
452458

453459
## maximum & minimum
454460

base/reducedim.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,8 +614,10 @@ for (fname, op) in [(:sum, :+), (:prod, :*),
614614
(:maximum, :scalarmax), (:minimum, :scalarmin),
615615
(:all, :&), (:any, :|)]
616616
function compose_promote_sys_size(x)
617-
if fname in [:sum, :prod]
618-
:(promote_sys_size $x)
617+
if fname === :sum
618+
:(promote_sys_size_add $x)
619+
elseif fname === :prod
620+
:(promote_sys_size_mul $x)
619621
else
620622
x
621623
end

0 commit comments

Comments
 (0)