Skip to content

Commit e015b3e

Browse files
committed
simplify min/max init
1 parent 05ddddc commit e015b3e

File tree

2 files changed

+18
-49
lines changed

2 files changed

+18
-49
lines changed

base/reducedim.jl

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ end
9191
# reducedim_initarray is called by
9292
reducedim_initarray(A::AbstractArrayOrBroadcasted, region, init, ::Type{R}) where {R} = fill!(similar(A,R,reduced_indices(A,region)), init)
9393
reducedim_initarray(A::AbstractArrayOrBroadcasted, region, init::T) where {T} = reducedim_initarray(A, region, init, T)
94-
# TODO: extend this to minimum and maximum
9594
reducedim_initarray(A::AbstractArrayOrBroadcasted, region, ::UndefInitializer, ::Type{R}) where {R} = similar(A,R,reduced_indices(A,region))
9695
# TODO: better way to handle reducedim initialization
9796
#
@@ -126,45 +125,21 @@ function _reducedim_init(f, op, fv, fop, A, region)
126125
end
127126

128127
# initialization when computing minima and maxima requires a little care
129-
for (f1, f2, initval, typeextreme) in ((:min, :max, :Inf, :typemax), (:max, :min, :(-Inf), :typemin))
130-
@eval function reducedim_init(f, op::typeof($f1), A::AbstractArray, region)
131-
# First compute the reduce indices. This will throw an ArgumentError
132-
# if any region is invalid
133-
ri = reduced_indices(A, region)
128+
function reducedim_init(f::F, ::Union{typeof(min),typeof(max)}, A::AbstractArray, region) where {F}
129+
# First compute the reduce indices. This will throw an ArgumentError
130+
# if any region is invalid
131+
ri = reduced_indices(A, region)
134132

135-
# Next, throw if reduction is over a region with length zero
136-
any(i -> isempty(axes(A, i)), region) && _empty_reduce_error()
133+
# Next, throw if reduction is over a region with length zero
134+
any(i -> isempty(axes(A, i)), region) && _empty_reduce_error()
137135

138-
# Make a view of the first slice of the region
139-
A1 = view(A, ri...)
136+
# Make a view of the first slice of the region
137+
A1 = view(A, ri...)
140138

141-
if isempty(A1)
142-
# If the slice is empty just return non-view version as the initial array
143-
return copy(A1)
144-
else
145-
# otherwise use the min/max of the first slice as initial value
146-
v0 = mapreduce(f, $f2, A1)
147-
148-
T = _realtype(f, promote_union(eltype(A)))
149-
Tr = v0 isa T ? T : typeof(v0)
150-
151-
# but NaNs and missing need to be avoided as initial values
152-
if (v0 == v0) === false
153-
# v0 is NaN
154-
v0 = $initval
155-
elseif isunordered(v0)
156-
# v0 is missing or a third-party unordered value
157-
Tnm = nonmissingtype(Tr)
158-
# TODO: Some types, like BigInt, don't support typemin/typemax.
159-
# So a Matrix{Union{BigInt, Missing}} can still error here.
160-
v0 = $typeextreme(Tnm)
161-
end
162-
# v0 may have changed type.
163-
Tr = v0 isa T ? T : typeof(v0)
139+
# calculate the output type
140+
T = promote_typejoin_union(_return_type(f, Tuple{eltype(A)}))
164141

165-
return reducedim_initarray(A, region, v0, Tr)
166-
end
167-
end
142+
map!(f, reducedim_initarray(A,region,undef,T), A1)
168143
end
169144
reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(max), A::AbstractArray{T}, region) where {T} =
170145
reducedim_initarray(A, region, zero(f(zero(T))), _realtype(f, T))

test/reduce.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -378,35 +378,29 @@ A = circshift(reshape(1:24,2,3,4), (0,1,1))
378378
@test size(extrema(A,dims=(1,2,3))) == size(maximum(A,dims=(1,2,3)))
379379
@test extrema(x->div(x, 2), A, dims=(2,3)) == reshape([(0,11),(1,12)],2,1,1)
380380

381-
# TODO: drop `a′` once `minimum` and `maximum` is fixed
382-
# (the following test_broken pass)
383-
function test_extrema(a, a′ = a; dims_test = ((), 1, 2, (1,2), 3))
381+
function test_extrema(a; dims_test = ((), 1, 2, (1,2), 3))
384382
for dims in dims_test
385383
vext = extrema(a; dims)
386-
vmin, vmax = minimum(a; dims), maximum(a; dims)
387-
@test all(x -> isequal(x[1], x[2:3]), zip(vext,vmin,vmax)) || foreach(i -> display(i),(eltype(a), vext,vmin,vmax))
384+
vmin, vmax = minimum(a; dims), maximum(a; dims)
385+
@test all(x -> isequal(x[1], x[2:3]), zip(vext,vmin,vmax))
388386
end
389387
end
390-
@test_broken minimum([missing BigInt(1)], dims = 2)[1] === missing
391388
@testset "0.0,-0.0 test for extrema with dims" begin
392389
@test extrema([-0.0;0.0], dims = 1)[1] === (-0.0,0.0)
393390
@test tuple(extrema([-0.0;0.0], dims = 2)...) === ((-0.0, -0.0), (0.0, 0.0))
394391
end
395392
@testset "NaN/missing test for extrema with dims #43599" begin
396393
for sz = (3, 10, 100)
397394
for T in (Int, BigInt, Float64, BigFloat)
398-
Aₘ = Matrix{Union{Float64, Missing}}(rand(-sz:sz, sz, sz))
395+
Aₘ = Matrix{Union{T, Missing}}(rand(-sz:sz, sz, sz))
399396
Aₘ[rand(1:sz*sz, sz)] .= missing
400-
ATₘ = Matrix{Union{T, Missing}}(Aₘ)
401-
test_extrema(ATₘ, Aₘ)
397+
test_extrema(Aₘ)
402398
if T <: AbstractFloat
403399
Aₙ = map(i -> ismissing(i) ? T(NaN) : i, Aₘ)
404-
ATₙ = map(i -> ismissing(i) ? T(NaN) : i, ATₘ)
405-
test_extrema(ATₙ, Aₙ)
400+
test_extrema(Aₙ)
406401
p = rand(1:sz*sz, sz)
407402
Aₘ[p] .= NaN
408-
ATₘ[p] .= NaN
409-
test_extrema(ATₘ, Aₘ)
403+
test_extrema(Aₘ)
410404
end
411405
end
412406
end

0 commit comments

Comments
 (0)