Skip to content

Commit e31685d

Browse files
committed
reducedim working
1 parent 1c050f3 commit e31685d

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

base/reducedim.jl

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ end
6262
###### Generic reduction functions #####
6363

6464
## initialization
65-
66-
for (Op, initfun) in ((:(typeof(+)), :zero), (:(typeof(*)), :one))
65+
# initarray! is only called by sum!, prod!, etc.
66+
for (Op, initfun) in ((:(typeof(add_sum)), :zero), (:(typeof(mul_prod)), :one))
6767
@eval initarray!(a::AbstractArray{T}, ::$(Op), init::Bool, src::AbstractArray) where {T} = (init && fill!(a, $(initfun)(T)); a)
6868
end
6969

@@ -75,6 +75,7 @@ for (Op, initval) in ((:(typeof(&)), true), (:(typeof(|)), false))
7575
@eval initarray!(a::AbstractArray, ::$(Op), init::Bool, src::AbstractArray) = (init && fill!(a, $initval); a)
7676
end
7777

78+
# reducedim_initarray is called by
7879
reducedim_initarray(A::AbstractArray, region, v0, ::Type{R}) where {R} = fill!(similar(A,R,reduced_indices(A,region)), v0)
7980
reducedim_initarray(A::AbstractArray, region, v0::T) where {T} = reducedim_initarray(A, region, v0, T)
8081

@@ -104,10 +105,10 @@ reducedim_initarray0_empty(A::AbstractArray, region,::typeof(identity), ops) = m
104105
promote_union(T::Union) = promote_type(promote_union(T.a), promote_union(T.b))
105106
promote_union(T) = T
106107

107-
function reducedim_init(f, op::typeof(+), A::AbstractArray, region)
108+
function reducedim_init(f, op::Union{typeof(+),typeof(add_sum)}, A::AbstractArray, region)
108109
_reducedim_init(f, op, zero, sum, A, region)
109110
end
110-
function reducedim_init(f, op::typeof(*), A::AbstractArray, region)
111+
function reducedim_init(f, op::Union{typeof(*),typeof(mul_prod)}, A::AbstractArray, region)
111112
_reducedim_init(f, op, one, prod, A, region)
112113
end
113114
function _reducedim_init(f, op, fv, fop, A, region)
@@ -143,19 +144,12 @@ let
143144
[AbstractArray{t} for t in uniontypes(BitIntFloat)]...,
144145
[AbstractArray{Complex{t}} for t in uniontypes(BitIntFloat)]...}
145146

146-
global reducedim_init(f::typeof(identity), op::typeof(+), A::T, region) =
147-
reducedim_initarray(A, region, zero(eltype(A)))
148-
global reducedim_init(f::typeof(identity), op::typeof(*), A::T, region) =
149-
reducedim_initarray(A, region, one(eltype(A)))
150-
global reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(+), A::T, region) =
151-
reducedim_initarray(A, region, real(zero(eltype(A))))
152-
global reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(*), A::T, region) =
153-
reducedim_initarray(A, region, real(one(eltype(A))))
147+
global reducedim_init(f, op::Union{typeof(+),typeof(add_sum)}, A::T, region) =
148+
reducedim_initarray(A, region, mapreduce_single(f, op, zero(eltype(A))))
149+
global reducedim_init(f, op::Union{typeof(*),typeof(mul_prod)}, A::T, region) =
150+
reducedim_initarray(A, region, mapreduce_single(f, op, one(eltype(A))))
154151
end
155152

156-
reducedim_init(f::Union{typeof(identity),typeof(abs),typeof(abs2)}, op::typeof(+), A::AbstractArray{Bool}, region) =
157-
reducedim_initarray(A, region, 0)
158-
159153
## generic (map)reduction
160154

161155
has_fast_linear_indexing(a::AbstractArray) = false
@@ -610,7 +604,7 @@ julia> any!([1 1], A)
610604
"""
611605
any!(r, A)
612606

613-
for (fname, op) in [(:sum, :add_tosys), (:prod, :mul_tosys),
607+
for (fname, op) in [(:sum, :add_sum), (:prod, :mul_prod),
614608
(:maximum, :scalarmax), (:minimum, :scalarmin),
615609
(:all, :&), (:any, :|)]
616610
fname! = Symbol(fname, '!')

test/reducedim.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,10 @@ for region in Any[-1, 0, (-1, 2), [0, 1], (1,-2,3), [0 1;
339339
end
340340

341341
# check type of result
342-
under_test = [UInt8, Int8, Int32, Int64, BigInt]
343-
@testset "type of sum(::Array{$T}" for T in under_test
342+
@testset "type of sum(::Array{$T}" for T in [UInt8, Int8, Int32, Int64, BigInt]
344343
result = sum(T[1 2 3; 4 5 6; 7 8 9], 2)
345344
@test result == hcat([6, 15, 24])
346-
@test eltype(result) === (under_test in (Int64, BigInt) ? under_test : Int)
345+
@test eltype(result) === (T <: Base.SmallSigned ? Int :
346+
T <: Base.SmallUnsigned ? UInt :
347+
T)
347348
end

0 commit comments

Comments
 (0)