Skip to content

Commit 6e47ff5

Browse files
authored
Add extrema support (#392)
On 1.8 `Base.extrema` is `mapreduce` based.
1 parent caff08e commit 6e47ff5

File tree

3 files changed

+27
-8
lines changed

3 files changed

+27
-8
lines changed

src/host/mapreduce.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ neutral_element(::typeof(Base.:(*)), T) = one(T)
2222
neutral_element(::typeof(Base.mul_prod), T) = one(T)
2323
neutral_element(::typeof(Base.min), T) = typemax(T)
2424
neutral_element(::typeof(Base.max), T) = typemin(T)
25+
if VERSION >= v"1.8.0-DEV.1465"
26+
neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = typemax(T), typemin(T)
27+
end
2528

2629
# resolve ambiguities
2730
Base.mapreduce(f, op, A::AnyGPUArray, As::AbstractArrayOrBroadcasted...;

test/testsuite.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,20 @@ using Adapt
1717
struct ArrayAdaptor{AT} end
1818
Adapt.adapt_storage(::ArrayAdaptor{AT}, xs::AbstractArray) where {AT} = AT(xs)
1919

20+
test_result(a::Number, b::Number) = a b
21+
function test_result(a::AbstractArray{T}, b::AbstractArray{T}) where {T<:Number}
22+
collect(a) collect(b)
23+
end
24+
function test_result(a::AbstractArray{T}, b::AbstractArray{T}) where {T<:NTuple{N,<:Number} where {N}}
25+
ET = eltype(T)
26+
reinterpret(ET, collect(a)) reinterpret(ET, collect(b))
27+
end
28+
function test_result(as::NTuple{N,Any}, bs::NTuple{N,Any}) where {N}
29+
all(zip(as, bs)) do (a, b)
30+
test_result(a, b)
31+
end
32+
end
33+
2034
function compare(f, AT::Type{<:AbstractGPUArray}, xs...; kwargs...)
2135
# copy on the CPU, adapt on the GPU, but keep Ref's
2236
cpu_in = map(x -> isa(x, Base.RefValue) ? x[] : deepcopy(x), xs)
@@ -25,13 +39,7 @@ function compare(f, AT::Type{<:AbstractGPUArray}, xs...; kwargs...)
2539
cpu_out = f(cpu_in...; kwargs...)
2640
gpu_out = f(gpu_in...; kwargs...)
2741

28-
if cpu_out isa Tuple && gpu_out isa Tuple
29-
all(zip(cpu_out,gpu_out)) do (cpu, gpu)
30-
collect(cpu) collect(gpu)
31-
end
32-
else
33-
collect(cpu_out) collect(gpu_out)
34-
end
42+
test_result(cpu_out, gpu_out)
3543
end
3644

3745
function compare(f, AT::Type{<:Array}, xs...; kwargs...)

test/testsuite/reductions.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ end
7777
end
7878
end
7979

80-
@testsuite "reductions/minimum maximum" (AT, eltypes)->begin
80+
@testsuite "reductions/minimum maximum extrema" (AT, eltypes)->begin
8181
@testset "$ET" for ET in eltypes
8282
range = ET <: Real ? (ET(1):ET(10)) : ET
8383
for (sz,dims) in [(10,)=>[1], (10,10)=>[1,2], (10,10,10)=>[1,2,3], (10,10,10)=>[],
@@ -90,6 +90,11 @@ end
9090
@test compare(A->maximum(A), AT, rand(range, sz))
9191
@test compare(A->maximum(x->x*x, A), AT, rand(range, sz))
9292
@test compare(A->maximum(A; dims=dims), AT, rand(range, sz))
93+
if VERSION >= v"1.8.0-DEV.1465"
94+
@test compare(A->extrema(A), AT, rand(range, sz))
95+
@test compare(A->extrema(x->x*x, A), AT, rand(range, sz))
96+
@test compare(A->extrema(A; dims=dims), AT, rand(range, sz))
97+
end
9398
end
9499
end
95100

@@ -98,6 +103,9 @@ end
98103
if !(ET <: Complex)
99104
@test compare((A,R)->minimum!(R, A), AT, rand(range, sz), fill(typemax(ET), red))
100105
@test compare((A,R)->maximum!(R, A), AT, rand(range, sz), fill(typemin(ET), red))
106+
if VERSION >= v"1.8.0-DEV.1465"
107+
@test compare((A,R)->extrema!(R, A), AT, rand(range, sz), fill((typemax(ET),typemin(ET)), red))
108+
end
101109
end
102110
end
103111
end

0 commit comments

Comments
 (0)