Skip to content

Commit 099fab8

Browse files
committed
implement extrema with mapreduce machinery
Update multidimensional.jl. Extend the `minimum`/`maximum` optimization to much shorter lengths. Add an eager `NaN` break for `extrema`. Performance optimization and code cleanup.
1 parent 8717fbe commit 099fab8

File tree

2 files changed

+117
-94
lines changed

2 files changed

+117
-94
lines changed

base/multidimensional.jl

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1746,41 +1746,80 @@ of `A`.
17461746
This method requires Julia 1.2 or later.
17471747
"""
17481748
extrema(f, A::AbstractArray; dims=:) = _extrema_dims(f, A, dims)
1749-
1750-
_extrema_dims(f, A::AbstractArray, ::Colon) = _extrema_itr(f, A)
1751-
1752-
function _extrema_dims(f, A::AbstractArray, dims)
1753-
sz = size(A)
1754-
for d in dims
1755-
sz = setindex(sz, 1, d)
1756-
end
1757-
T = promote_op(f, eltype(A))
1758-
B = Array{Tuple{T,T}}(undef, sz...)
1759-
return extrema!(f, B, A)
1760-
end
1761-
1762-
@noinline function extrema!(f, B, A)
1763-
require_one_based_indexing(B, A)
1764-
sA = size(A)
1765-
sB = size(B)
1766-
for I in CartesianIndices(sB)
1767-
fAI = f(A[I])
1768-
B[I] = (fAI, fAI)
1769-
end
1770-
Bmax = CartesianIndex(sB)
1771-
@inbounds @simd for I in CartesianIndices(sA)
1772-
J = min(Bmax,I)
1773-
BJ = B[J]
1774-
fAI = f(A[I])
1775-
if fAI < BJ[1]
1776-
B[J] = (fAI, BJ[2])
1777-
elseif fAI > BJ[2]
1778-
B[J] = (BJ[1], fAI)
1749+
# `dims`-aware entry point: maps every element `x` to the pair `(f(x), f(x))`
# and reduces those pairs with `_extrema_op` over `dims` via `mapreduce`.
_extrema_dims(f, A::AbstractArray, dims) = mapreduce(x -> (fx = f(x); (fx, fx)), _extrema_op, A; dims)
1750+
# Two-argument form: forwards to the three-argument method with `identity` as `f`.
extrema!(B, A) = extrema!(identity, B, A)
1751+
# Maps every element `x` of `A` to `(f(x), f(x))` and hands the pairs, together
# with `B`, to `mapreduce!` using `_extrema_op` as the reduction operator.
extrema!(f, B, A) = mapreduce!(x -> (fx = f(x); (fx, fx)), _extrema_op, B, A)
1752+
# Generic reduction step on two `(lo, hi)` pairs: `min` of the first components,
# `max` of the second components.
_extrema_op((a, b), (c, d)) = min(a, c), max(b, d)
1753+
# Reduction step for two `(lo, hi)` pairs whose components share one IEEE float type.
# Both results are selected with `ifelse` instead of `min`/`max`:
#   * if `x1` or `y1` is NaN, both outputs are `x1-y1` (NaN propagates through the subtraction);
#   * otherwise the sign bit of the difference decides which operand is kept.
# NOTE(review): the NaN test for `z2` also inspects only `x1`/`y1`; this presumably relies on
# both components of a pair being NaN together — confirm against the callers.
function _extrema_op(x::NTuple{2,T}, y::NTuple{2,T}) where {T<:IEEEFloat}
1754+
(x1, x2), (y1, y2) = x, y
1755+
# lower bound: keep `x1` when `x1-y1` has its sign bit set, else `y1`
z1 = ifelse(isnan(x1)|isnan(y1), x1-y1, ifelse(signbit(x1-y1), x1, y1))
1756+
# upper bound: keep `y2` when `x2-y2` has its sign bit set, else `x2`
z2 = ifelse(isnan(x1)|isnan(y1), x1-y1, ifelse(signbit(x2-y2), y2, x2))
1757+
z1, z2
1758+
end
1759+
# avoid allocation for BigFloat
1760+
# Reduction step for two `(lo, hi)` pairs of any other `AbstractFloat` type.
# Uses early returns and comparisons rather than the subtraction used by the IEEEFloat method.
# Returns `x` (or `y`) unchanged as soon as its first component is NaN.
function _extrema_op(x::NTuple{2,T}, y::NTuple{2,T}) where {T<:AbstractFloat}
1761+
(x1, x2), (y1, y2) = x, y
1762+
isnan(x1) && return x
1763+
isnan(y1) && return y
1764+
# lower bound: `x1` wins if it compares smaller, or if only `x1` has its sign bit set
# (the latter covers the case where `<` is false, e.g. -0.0 versus 0.0)
z1 = x1 < y1 || signbit(x1) > signbit(y1) ? x1 : y1
1765+
# upper bound: the same test on the second components selects `y2`, otherwise `x2`
z2 = x2 < y2 || signbit(x2) > signbit(y2) ? y2 : x2
1766+
z1, z2
1767+
end
1768+
1769+
function reducedim_init(f, ::typeof(_extrema_op), A::AbstractArray, region)
1770+
ri = reduced_indices(A, region)
1771+
any(i -> isempty(axes(A, i)), region) && _empty_reduce_error()
1772+
A1 = view(A, ri...)
1773+
IT = eltype(A)
1774+
if missing isa IT
1775+
RT = promote_typejoin_union(_return_type(i -> f(i)[1], Tuple{nonmissingtype(IT)}))
1776+
T = Union{Tuple{RT,RT},Tuple{Missing,Missing}}
1777+
else
1778+
RT = promote_typejoin_union(_return_type(i -> f(i)[1], Tuple{IT}))
1779+
T = Union{Tuple{RT,RT}}
1780+
end
1781+
map!(f, reducedim_initarray(A,region,undef,T), A1)
1782+
end
1783+
1784+
# `mapreduce_impl` method specialised for `_extrema_op`, covering linear indices `fi` through `la` of `A`.
# Returns a `(lo, hi)` pair. When the inferred element type is not an `IEEEFloat` the call is
# forwarded with `invoke` to the method with `Any`-typed `f` and `op`.
# For IEEE floats with more than 8 remaining elements, batches of `N` elements are reduced with
# `_fast(min)` / `_fast(max)`; a NaN found in a batch ends the reduction immediately.
# NOTE(review): `sizeof(Eltype)` assumes `Eltype` is a single concrete float type — confirm it
# cannot be a Union here.
function mapreduce_impl(f, op::typeof(_extrema_op),
1785+
A::AbstractArrayOrBroadcasted, fi::Int, la::Int)
1786+
# `elf(i)` is the first component of `f(A[i])`, read with bounds checks disabled
@inline elf(i) = @inbounds f(A[i])[1]
1787+
Eltype = _return_type(elf, Tuple{Int})
1788+
Eltype <: IEEEFloat ||
1789+
return invoke(mapreduce_impl,Tuple{Any,Any,typeof(A),Int,Int},f,op,A,fi,la)
1790+
# `i` is the last index already consumed; `v` is the running `(lo, hi)` pair
ini, i = elf(fi), fi
1791+
v = ini, ini
1792+
if la - i >= 8
1793+
# Early-exit result: the first NaN of the batch as both bounds, paired with `fi`
# as the index so the caller can recognise the exit via `i == fi`.
@noinline firstnan(temp) = (x=temp[findfirst(isnan,temp)]; ((x, x), fi))
1794+
# Keeps `N` running minima and maxima and consumes `N` elements per iteration.
function simd_kernal(::Val{N}, ini, i) where {N}
1795+
vmins = ntuple(Returns(ini), Val(N))
1796+
vmaxs = vmins
1797+
index = ntuple(identity, Val(N))
1798+
for _ in 1:(la-i)÷N
1799+
# load elements i+1 .. i+N
temp = map(elf, i .+ index)
1800+
mapreduce(isnan,|,temp) && return firstnan(temp)
1801+
vmins = map(_fast(min), vmins, temp)
1802+
vmaxs = map(_fast(max), vmaxs, temp)
1803+
i += N
1804+
end
1805+
(reduce(_fast(min), vmins), reduce(_fast(max), vmaxs)), i
17791806
end
1807+
# first element is NaN: `v` is already `(NaN, NaN)`
isnan(ini) && return v
1808+
# batch width depends on how many elements remain
if la - i < 64
1809+
v, i = simd_kernal(Val(4), ini, i)
1810+
elseif la - i < 256
1811+
v, i = simd_kernal(Val(8), ini, i)
1812+
else
1813+
v, i = simd_kernal(Val(64÷sizeof(Eltype)), ini, i)
1814+
end
1815+
# `i == fi` means the kernel returned through `firstnan`
i == fi && return v
17801816
end
1781-
return B
1817+
# remaining elements, one at a time
while i < la
1818+
v′ = elf(i+=1)
1819+
v = _extrema_op(v, (v′,v′))
1820+
end
1821+
return v
17821822
end
1783-
extrema!(B, A) = extrema!(identity, B, A)
17841823

17851824
# Show for pairs() with Cartesian indices. Needs to be here rather than show.jl for bootstrap order
17861825
function Base.showarg(io::IO, r::Iterators.Pairs{<:Integer, <:Any, <:Any, T}, toplevel) where T <: Union{AbstractVector, Tuple}

base/reduce.jl

Lines changed: 46 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -605,71 +605,55 @@ julia> prod(1:5; init = 1.0)
605605
prod(a; kw...) = mapreduce(identity, mul_prod, a; kw...)
606606

607607
## maximum & minimum
608-
_fast(::typeof(min),x,y) = min(x,y)
609-
_fast(::typeof(max),x,y) = max(x,y)
610-
function _fast(::typeof(max), x::AbstractFloat, y::AbstractFloat)
611-
ifelse(isnan(x),
612-
x,
613-
ifelse(x > y, x, y))
614-
end
615-
616-
function _fast(::typeof(min),x::AbstractFloat, y::AbstractFloat)
617-
ifelse(isnan(x),
618-
x,
619-
ifelse(x < y, x, y))
620-
end
621608

622-
isbadzero(::typeof(max), x::AbstractFloat) = (x == zero(x)) & signbit(x)
623-
isbadzero(::typeof(min), x::AbstractFloat) = (x == zero(x)) & !signbit(x)
624-
isbadzero(op, x) = false
625-
isgoodzero(::typeof(max), x) = isbadzero(min, x)
626-
isgoodzero(::typeof(min), x) = isbadzero(max, x)
627-
628-
function mapreduce_impl(f, op::Union{typeof(max), typeof(min)},
629-
A::AbstractArrayOrBroadcasted, first::Int, last::Int)
630-
# 1. This optimization gives different result from general fallback, if the inputs `f.(A)`
631-
# contains both 'missing' and 'Nan'.
632-
# 2. For Integer cases, general fallback seems faster.
633-
# Based the above reasons, only use this for AbstractFloat cases.
634-
Eltype = _return_type(i -> f(A[i]), Tuple{Int})
635-
Eltype <: AbstractFloat ||
636-
return invoke(mapreduce_impl,Tuple{Any,Any,AbstractArrayOrBroadcasted,Int,Int},f,op,A,first,last)
637-
a1 = @inbounds A[first]
638-
v1 = mapreduce_first(f, op, a1)
639-
v2 = v3 = v4 = v1
640-
chunk_len = 256
641-
start = first + 1
642-
simdstop = start + chunk_len - 4
643-
while simdstop <= last - 3
644-
# short circuit in case of NaN or missing
645-
v1 == v1 || return v1
646-
v2 == v2 || return v2
647-
v3 == v3 || return v3
648-
v4 == v4 || return v4
649-
@inbounds for i in start:4:simdstop
650-
v1 = _fast(op, v1, f(A[i+0]))
651-
v2 = _fast(op, v2, f(A[i+1]))
652-
v3 = _fast(op, v3, f(A[i+2]))
653-
v4 = _fast(op, v4, f(A[i+3]))
609+
# Optimization for min/max reduction
610+
# One-argument form: wraps `op` in a two-argument closure that calls `_fast(op, x, y)`.
_fast(op) = (x, y) -> _fast(op, x, y)
611+
# Fallback for arguments without a more specific method: simply apply `op`.
_fast(op, x, y) = op(x, y)
612+
613+
# used in optimized mapreduce_impl for IEEEFloat
614+
# where NaN inputs have already been handled
615+
# Returns `x` when `x - y` has its sign bit set, otherwise `y`; the choice is made with
# `ifelse` rather than a comparison. NaN operands are not treated specially here.
_fast(::typeof(min), x::T, y::T) where {T<:IEEEFloat} = ifelse(signbit(x-y), x, y)
616+
# Returns `y` when `x - y` has its sign bit set, otherwise `x`; the choice is made with
# `ifelse` rather than a comparison. NaN operands are not treated specially here.
_fast(::typeof(max), x::T, y::T) where {T<:IEEEFloat} = ifelse(signbit(x-y), y, x)
617+
618+
# `mapreduce_impl` method specialised for `max` and `min`, covering linear indices `fi` through `la` of `A`.
# When the inferred element type is not an `IEEEFloat` the call is forwarded with `invoke`
# to the method with `Any`-typed `f` and `op`.
# For IEEE floats with more than 8 remaining elements, batches of `N` elements are reduced with
# `_fast(op)`; a NaN found in a batch is returned immediately.
# NOTE(review): `sizeof(Eltype)` assumes `Eltype` is a single concrete float type — confirm it
# cannot be a Union here.
function mapreduce_impl(f, op::Union{typeof(max),typeof(min)},
619+
A::AbstractArrayOrBroadcasted, fi::Int, la::Int)
620+
# `elf(i)` is `f(A[i])`, read with bounds checks disabled
@inline elf(i) = @inbounds f(A[i])
621+
Eltype = _return_type(elf, Tuple{Int})
622+
# For Integer input, general fallback is about 2x faster.
623+
# Thus limit this optimization to IEEEFloat.
624+
Eltype <: IEEEFloat ||
625+
return invoke(mapreduce_impl,Tuple{Any,Any,typeof(A),Int,Int},f,op,A,fi,la)
626+
# `i` is the last index already consumed; `v` is the running result
v, i = elf(fi), fi
627+
if la - i >= 8
628+
# we always return the first nan
# (paired with `fi` as the index so the caller can recognise the exit via `i == fi`)
629+
@noinline firstnan(temp) = temp[findfirst(isnan, temp)], fi
630+
# Keeps `N` running results and consumes `N` elements per iteration.
function simd_kernal(::Val{N}, ini, i) where {N}
631+
vs = ntuple(Returns(ini), Val(N)) # initial values (non nan)
632+
index = ntuple(identity, Val(N))
633+
for _ in 1:(la-i)÷N
634+
# load elements i+1 .. i+N
temp = map(elf, i .+ index)
635+
# check the whole batch for NaN at once; doing the checks together is faster
636+
mapreduce(isnan,|,temp) && return firstnan(temp)
637+
# since temp has no nan, we can use _fast(op) safely
638+
vs = map(_fast(op), vs, temp)
639+
i += N
640+
end
641+
reduce(_fast(op), vs), i
654642
end
655-
checkbounds(A, simdstop+3)
656-
start += chunk_len
657-
simdstop += chunk_len
658-
end
659-
v = op(op(v1,v2),op(v3,v4))
660-
for i in start:last
661-
@inbounds ai = A[i]
662-
v = op(v, f(ai))
663-
end
664-
665-
# enforce correct order of 0.0 and -0.0
666-
# e.g. maximum([0.0, -0.0]) === 0.0
667-
# should hold
668-
if isbadzero(op, v)
669-
for i in first:last
670-
x = @inbounds A[i]
671-
isgoodzero(op,x) && return x
643+
# first element is NaN: return it directly
isnan(v) && return v
644+
# pick a proper unroll-size
645+
if la - i < 64
646+
v, i = simd_kernal(Val(4), v, i)
647+
elseif la - i < 256
648+
v, i = simd_kernal(Val(8), v, i)
649+
else
650+
# fill the cache-line
651+
v, i = simd_kernal(Val(64÷sizeof(Eltype)), v, i)
672652
end
653+
i == fi && return v # return by `firstnan`
654+
end
655+
# remaining elements, one at a time, with the regular `op`
while i < la
656+
v = op(v, elf(i+=1))
673657
end
674658
return v
675659
end

0 commit comments

Comments
 (0)