Skip to content

Commit b92ca0c

Browse files
committed
Make findmax(A; dims) consistent with findmax(A) again
findmax(A; dims) and friends can handle `missing` now, too.
1 parent 1f25724 commit b92ca0c

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

base/reducedim.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,18 @@ for (f1, f2, initval) in ((:min, :max, :Inf), (:max, :min, :(-Inf)))
144144
# otherwise use the min/max of the first slice as initial value
145145
v0 = mapreduce(f, $f2, A1)
146146

147-
# but NaNs need to be avoided as initial values
148-
v0 = v0 != v0 ? typeof(v0)($initval) : v0
149-
150147
T = _realtype(f, promote_union(eltype(A)))
148+
149+
# but NaNs and missing need to be avoided as initial values
150+
if (v0 != v0) === true
151+
v0 = typeof(v0)($initval)
152+
elseif ismissing(v0)
153+
# If it's a union type, pick the initval from the other type.
154+
if typeof(T) == Union
155+
v0 = (T.a == Missing ? T.b : T.a)($initval)
156+
end
157+
end
158+
151159
Tr = v0 isa T ? T : typeof(v0)
152160
return reducedim_initarray(A, region, v0, Tr)
153161
end
@@ -926,7 +934,7 @@ function findminmax!(f, Rval, Rind, A::AbstractArray{T,N}) where {T,N}
926934
for i in axes(A,1)
927935
k, kss = y::Tuple
928936
tmpAv = A[i,IA]
929-
if tmpRi == zi || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv)))
937+
if tmpRi == zi || f(tmpRv, tmpAv)
930938
tmpRv = tmpAv
931939
tmpRi = k
932940
end
@@ -943,7 +951,7 @@ function findminmax!(f, Rval, Rind, A::AbstractArray{T,N}) where {T,N}
943951
tmpAv = A[i,IA]
944952
tmpRv = Rval[i,IR]
945953
tmpRi = Rind[i,IR]
946-
if tmpRi == zi || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv)))
954+
if tmpRi == zi || f(tmpRv, tmpAv)
947955
Rval[i,IR] = tmpAv
948956
Rind[i,IR] = k
949957
end
@@ -963,7 +971,7 @@ dimensions of `rval` and `rind`, and store the results in `rval` and `rind`.
963971
"""
964972
function findmin!(rval::AbstractArray, rind::AbstractArray, A::AbstractArray;
965973
init::Bool=true)
966-
findminmax!(isless, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A)
974+
findminmax!(isgreater, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A)
967975
end
968976

969977
"""
@@ -996,13 +1004,11 @@ function _findmin(A, region)
9961004
end
9971005
(similar(A, ri), zeros(eltype(keys(A)), ri))
9981006
else
999-
findminmax!(isless, fill!(similar(A, ri), first(A)),
1007+
findminmax!(isgreater, fill!(similar(A, ri), first(A)),
10001008
zeros(eltype(keys(A)), ri), A)
10011009
end
10021010
end
10031011

1004-
_isgreater(a, b) = isless(b,a)
1005-
10061012
"""
10071013
findmax!(rval, rind, A) -> (maxval, index)
10081014
@@ -1012,7 +1018,7 @@ dimensions of `rval` and `rind`, and store the results in `rval` and `rind`.
10121018
"""
10131019
function findmax!(rval::AbstractArray, rind::AbstractArray, A::AbstractArray;
10141020
init::Bool=true)
1015-
findminmax!(_isgreater, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A)
1021+
findminmax!(isless, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A)
10161022
end
10171023

10181024
"""
@@ -1045,7 +1051,7 @@ function _findmax(A, region)
10451051
end
10461052
similar(A, ri), zeros(eltype(keys(A)), ri)
10471053
else
1048-
findminmax!(_isgreater, fill!(similar(A, ri), first(A)),
1054+
findminmax!(isless, fill!(similar(A, ri), first(A)),
10491055
zeros(eltype(keys(A)), ri), A)
10501056
end
10511057
end

test/reducedim.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,26 @@ for (tup, rval, rind) in [((1,), [5.0 5.0 6.0], [CartesianIndex(2,1) CartesianIn
219219
@test isequal(maximum!(copy(rval), A, init=false), rval)
220220
end
221221

222+
@testset "missing in findmin/findmax" begin
223+
B = [1.0 missing NaN;
224+
5.0 NaN missing]
225+
for (tup, rval, rind) in [(1, [5.0 missing missing], [CartesianIndex(2, 1) CartesianIndex(1, 2) CartesianIndex(2, 3)]),
226+
(2, [missing; missing], [CartesianIndex(1, 2) CartesianIndex(2, 3)] |> permutedims)]
227+
(rval′, rind′) = findmax(B, dims=tup)
228+
@test all(rval′ .=== rval)
229+
@test all(rind′ .== rind)
230+
@test all(maximum(B, dims=tup) .=== rval)
231+
end
232+
233+
for (tup, rval, rind) in [(1, [1.0 missing missing], [CartesianIndex(1, 1) CartesianIndex(1, 2) CartesianIndex(2, 3)]),
234+
(2, [missing; missing], [CartesianIndex(1, 2) CartesianIndex(2, 3)] |> permutedims)]
235+
(rval′, rind′) = findmin(B, dims=tup)
236+
@test all(rval′ .=== rval)
237+
@test all(rind′ .== rind)
238+
@test all(minimum(B, dims=tup) .=== rval)
239+
end
240+
end
241+
222242
#issue #23209
223243

224244
A = [1.0 3.0 6.0;

0 commit comments

Comments
 (0)