Skip to content

Commit fc401ba

Browse files
N5N3 and tkf authored
Some mapreduce improvement (#63)
* fully drop `_extrema_dims`, `_extrema_itr`
* Generalize sparse `mapreduce`
* fix empty reduction test
* Add NaN test

Co-Authored-By: Takafumi Arakaki <takafumi.a@gmail.com>
1 parent cb6a670 commit fc401ba

File tree

6 files changed

+73
-108
lines changed

6 files changed

+73
-108
lines changed

src/higherorderfns.jl

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module HigherOrderFns
44

55
# This module provides higher order functions specialized for sparse arrays,
66
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
7-
import Base: map, map!, broadcast, copy, copyto!, _extrema_dims, _extrema_itr
7+
import Base: map, map!, broadcast, copy, copyto!
88

99
using Base: front, tail, to_shape
1010
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector, AbstractSparseMatrixCSC,
@@ -29,7 +29,6 @@ using LinearAlgebra
2929
# (11) Define broadcast[!] methods handling combinations of scalars, sparse vectors/matrices,
3030
# structured matrices, and one- and two-dimensional Arrays.
3131
# (12) Define map[!] methods handling combinations of sparse and structured matrices.
32-
# (13) Define extrema methods optimized for sparse vectors/matrices.
3332

3433

3534
# (0) BroadcastStyle rules and convenience types for dispatch
@@ -1163,60 +1162,4 @@ map(f::Tf, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N})
11631162
map!(f::Tf, C::AbstractSparseMatrixCSC, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} =
11641163
(_checksameshape(C, A, Bs...); _noshapecheck_map!(f, C, _sparsifystructured(A), map(_sparsifystructured, Bs)...))
11651164

1166-
1167-
# (13) extrema methods optimized for sparse vectors/matrices.
1168-
function _extrema_itr(f, A::SparseVecOrMat)
1169-
M = length(A)
1170-
iszero(M) && throw(ArgumentError("Sparse array must have at least one element."))
1171-
N = nnz(A)
1172-
iszero(N) && return f(zero(eltype(A))), f(zero(eltype(A)))
1173-
vmin, vmax = _extrema_itr(f, nonzeros(A))
1174-
if N != M
1175-
f0 = f(zero(eltype(A)))
1176-
vmin = min(f0, vmin)
1177-
vmax = max(f0, vmax)
1178-
end
1179-
vmin, vmax
1180-
end
1181-
1182-
function _extrema_dims(f, x::SparseVector, dims)
1183-
sz = zeros(1)
1184-
for d in dims
1185-
sz[d] = 1
1186-
end
1187-
if sz == [1] && !iszero(length(x))
1188-
return [_extrema_itr(f, x)]
1189-
end
1190-
invoke(_extrema_dims, Tuple{Any, AbstractArray, Any}, f, x, dims)
1191-
end
1192-
1193-
function _extrema_dims(f, A::AbstractSparseMatrix, dims)
1194-
sz = zeros(2)
1195-
for d in dims
1196-
sz[d] = 1
1197-
end
1198-
if sz == [1, 0] && !iszero(length(A))
1199-
T = eltype(A)
1200-
B = Array{Tuple{T,T}}(undef, 1, size(A, 2))
1201-
@inbounds for col_idx in 1:size(A, 2)
1202-
col = @view A[:,col_idx]
1203-
fx = (nnz(col) == size(A, 1)) ? f(A[1,col_idx]) : f(zero(T))
1204-
B[col_idx] = (fx, fx)
1205-
for x in nonzeros(col)
1206-
fx = f(x)
1207-
if fx < B[col_idx][1]
1208-
B[col_idx] = (fx, B[col_idx][2])
1209-
elseif fx > B[col_idx][2]
1210-
B[col_idx] = (B[col_idx][1], fx)
1211-
end
1212-
end
1213-
end
1214-
return B
1215-
end
1216-
invoke(_extrema_dims, Tuple{Any, AbstractArray, Any}, f, A, dims)
1217-
end
1218-
1219-
_extrema_dims(f, A::SparseVector, ::Colon) = _extrema_itr(f, A)
1220-
_extrema_dims(f, A::AbstractSparseMatrix, ::Colon) = _extrema_itr(f, A)
1221-
12221165
end

src/sparsematrix.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,13 +1914,19 @@ function Base._mapreduce(f, op, ::Base.IndexCartesian, A::AbstractSparseMatrixCS
19141914
end
19151915
end
19161916

1917-
# Specialized mapreduce for +/*
1918-
_mapreducezeros(f, ::typeof(+), ::Type{T}, nzeros::Integer, v0) where {T} =
1919-
nzeros == 0 ? v0 : f(zero(T))*nzeros + v0
1920-
_mapreducezeros(f, ::typeof(*), ::Type{T}, nzeros::Integer, v0) where {T} =
1921-
nzeros == 0 ? v0 : f(zero(T))^nzeros * v0
1922-
1923-
function Base._mapreduce(f, op::typeof(*), A::AbstractSparseMatrixCSC{T}) where T
1917+
# Specialized mapreduce for +/*/min/max/_extrema_rf
1918+
_mapreducezeros(f, op::Union{typeof(Base.add_sum),typeof(+)}, ::Type{T}, nzeros::Integer, v0) where {T} =
1919+
nzeros == 0 ? op(zero(v0), v0) : op(f(zero(T))*nzeros, v0)
1920+
_mapreducezeros(f, op::Union{typeof(Base.mul_prod),typeof(*)},::Type{T}, nzeros::Integer, v0) where {T} =
1921+
nzeros == 0 ? op(one(v0), v0) : op(f(zero(T))^nzeros, v0)
1922+
_mapreducezeros(f, op::Union{typeof(min),typeof(max)}, ::Type{T}, nzeros::Integer, v0) where {T} =
1923+
nzeros == 0 ? v0 : op(v0, f(zero(T)))
1924+
if isdefined(Base, :_extrema_rf)
1925+
_mapreducezeros(f::Base.ExtremaMap, op::typeof(Base._extrema_rf), ::Type{T}, nzeros::Integer, v0) where {T} =
1926+
nzeros == 0 ? v0 : op(v0, f(zero(T)))
1927+
end
1928+
1929+
function Base._mapreduce(f, op::typeof(*), ::Base.IndexCartesian, A::AbstractSparseMatrixCSC{T}) where T
19241930
nzeros = widelength(A)-nnz(A)
19251931
if nzeros == 0
19261932
# No zeros, so don't compute f(0) since it might throw

src/sparsevector.jl

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,46 +1392,32 @@ for (fun, mode) in [(:+, 1), (:-, 1), (:*, 0), (:min, 2), (:max, 2)]
13921392
end
13931393

13941394
### Reduction
1395-
1396-
function sum(f, x::AbstractSparseVector)
1397-
n = length(x)
1398-
n > 0 || return sum(f, nonzeros(x)) # return zero() of proper type
1399-
m = nnz(x)
1400-
(m == 0 ? n * f(zero(eltype(x))) :
1401-
m == n ? sum(f, nonzeros(x)) :
1402-
Base.add_sum((n - m) * f(zero(eltype(x))), sum(f, nonzeros(x))))
1403-
end
1404-
1405-
sum(x::AbstractSparseVector) = sum(nonzeros(x))
1406-
1407-
function maximum(f, x::AbstractSparseVector)
1408-
n = length(x)
1409-
if n == 0
1410-
if f === abs || f === abs2
1411-
return zero(eltype(x)) # preserving maximum(abs/abs2, x) behaviour in 1.0.x
1412-
else
1413-
throw(ArgumentError("maximum over an empty array is not allowed."))
1414-
end
1395+
Base.reducedim_initarray(A::AbstractSparseVector, region, v0, ::Type{R}) where {R} =
1396+
fill!(Array{R}(undef, Base.to_shape(Base.reduced_indices(A, region))), v0)
1397+
1398+
function Base._mapreduce(f, op, ::IndexCartesian, A::AbstractSparseVector{T}) where {T}
1399+
isempty(A) && return Base.mapreduce_empty(f, op, T)
1400+
z = nnz(A)
1401+
rest, ini = if z == 0
1402+
length(A)-z-1, f(zero(T))
1403+
else
1404+
length(A)-z, Base.mapreduce_impl(f, op, nonzeros(A), 1, z)
14151405
end
1416-
m = nnz(x)
1417-
(m == 0 ? f(zero(eltype(x))) :
1418-
m == n ? maximum(f, nonzeros(x)) :
1419-
max(f(zero(eltype(x))), maximum(f, nonzeros(x))))
1406+
_mapreducezeros(f, op, T, rest, ini)
14201407
end
14211408

1422-
maximum(x::AbstractSparseVector) = maximum(identity, x)
1423-
1424-
function minimum(f, x::AbstractSparseVector)
1425-
n = length(x)
1426-
n > 0 || throw(ArgumentError("minimum over an empty array is not allowed."))
1427-
m = nnz(x)
1428-
(m == 0 ? f(zero(eltype(x))) :
1429-
m == n ? minimum(f, nonzeros(x)) :
1430-
min(f(zero(eltype(x))), minimum(f, nonzeros(x))))
1409+
function Base.mapreducedim!(f, op, R::AbstractVector, A::AbstractSparseVector)
1410+
# dim1 reduction could be safely replaced with a mapreduce
1411+
if length(R) == 1
1412+
I = firstindex(R)
1413+
v = Base._mapreduce(f, op, IndexCartesian(), A)
1414+
R[I] = op(R[I], v)
1415+
return R
1416+
end
1417+
# otherwise there's no reduction
1418+
map!((x, y) -> op(x, f(y)), R, R, A)
14311419
end
14321420

1433-
minimum(x::AbstractSparseVector) = minimum(identity, x)
1434-
14351421
for (fun, comp, word) in ((:findmin, :(<), "minimum"), (:findmax, :(>), "maximum"))
14361422
@eval function $fun(f, x::AbstractSparseVector{T}) where {T}
14371423
n = length(x)

test/higherorderfns.jl

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -709,24 +709,54 @@ end
709709
@test extrema(f, x) == extrema(f, y)
710710
@test extrema(spzeros(n, n)) == (0.0, 0.0)
711711
@test extrema(spzeros(n)) == (0.0, 0.0)
712-
@test_throws ArgumentError extrema(spzeros(0, 0))
713-
@test_throws ArgumentError extrema(spzeros(0))
712+
# TODO: Remove the temporary skip once https://github.com/JuliaLang/julia/pull/43604 is merged
713+
if isdefined(Base, :_extrema_rf)
714+
@test_throws "reducing over an empty" extrema(spzeros(0, 0))
715+
@test_throws "reducing over an empty" extrema(spzeros(0))
716+
end
714717
@test extrema(sparse(ones(n, n))) == (1.0, 1.0)
715718
@test extrema(sparse(ones(n))) == (1.0, 1.0)
716719
@test extrema(A; dims=:) == extrema(B; dims=:)
717720
@test extrema(A; dims=1) == extrema(B; dims=1)
718721
@test extrema(A; dims=2) == extrema(B; dims=2)
719722
@test extrema(A; dims=(1,2)) == extrema(B; dims=(1,2))
720723
@test extrema(f, A; dims=1) == extrema(f, B; dims=1)
721-
@test extrema(sparse(C); dims=1) == extrema(C; dims=1)
724+
# TODO: Remove the temporary skip once https://github.com/JuliaLang/julia/pull/43604 is merged
725+
if isdefined(Base, :_extrema_rf)
726+
@test_throws "reducing over an empty" extrema(sparse(C); dims=1) == extrema(C; dims=1)
727+
end
722728
@test extrema(A; dims=[]) == extrema(B; dims=[])
723729
@test extrema(x; dims=:) == extrema(y; dims=:)
724730
@test extrema(x; dims=1) == extrema(y; dims=1)
725731
@test extrema(f, x; dims=1) == extrema(f, y; dims=1)
726-
@test_throws BoundsError extrema(sparse(z); dims=1)
732+
# TODO: Remove the temporary skip once https://github.com/JuliaLang/julia/pull/43604 is merged
733+
if isdefined(Base, :_extrema_rf)
734+
@test_throws "reducing over an empty" extrema(sparse(z); dims=1)
735+
end
727736
@test extrema(x; dims=[]) == extrema(y; dims=[])
728737
end
729738

739+
# TODO: Remove the temporary skip once https://github.com/JuliaLang/julia/pull/43604 is merged
740+
if isdefined(Base, :_extrema_rf)
741+
function test_extrema(a; dims_test = ((), 1, 2, (1,2), 3))
742+
for dims in dims_test
743+
vext = extrema(a; dims)
744+
vmin, vmax = minimum(a; dims), maximum(a; dims)
745+
@test all(x -> isequal(x[1], x[2:3]), zip(vext,vmin,vmax))
746+
end
747+
end
748+
@testset "NaN test for sparse extrema" begin
749+
for sz = (3, 10, 100)
750+
A = sprand(sz, sz, 0.3)
751+
A[rand(1:sz^2,sz)] .= NaN
752+
test_extrema(A)
753+
A = sprand(sz*sz, 0.3)
754+
A[rand(1:sz^2,sz)] .= NaN
755+
test_extrema(A; dims_test = ((), 1, 2))
756+
end
757+
end
758+
end
759+
730760
@testset "issue #42670 - error in sparsevec outer product" begin
731761
A = spzeros(Int, 4)
732762
B = copy(A)

test/sparse.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -871,8 +871,8 @@ end
871871
occursin("collection slices must be non-empty", str)
872872
@test sum(sparse(Int[])) === 0
873873
@test prod(sparse(Int[])) === 1
874-
@test_throws ArgumentError minimum(sparse(Int[]))
875-
@test_throws ArgumentError maximum(sparse(Int[]))
874+
@test_throws "reducing over an empty" minimum(sparse(Int[]))
875+
@test_throws "reducing over an empty" maximum(sparse(Int[]))
876876

877877
for f in (sum, prod)
878878
@test isequal(f(spzeros(0, 1), dims=1), f(Matrix{Int}(I, 0, 1), dims=1))

test/sparsevector.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -893,8 +893,8 @@ end
893893
end
894894

895895
let x = spzeros(Float64, 0)
896-
@test_throws ArgumentError minimum(t -> true, x)
897-
@test_throws ArgumentError maximum(t -> true, x)
896+
@test_throws "reducing over an empty" minimum(t -> true, x)
897+
@test_throws "reducing over an empty" maximum(t -> true, x)
898898
@test_throws ArgumentError findmin(x)
899899
@test_throws ArgumentError findmax(x)
900900
end

0 commit comments

Comments
 (0)