diff --git a/stdlib/SparseArrays/src/higherorderfns.jl b/stdlib/SparseArrays/src/higherorderfns.jl index e01be3697097c..ec2bd31d8c5e9 100644 --- a/stdlib/SparseArrays/src/higherorderfns.jl +++ b/stdlib/SparseArrays/src/higherorderfns.jl @@ -4,7 +4,7 @@ module HigherOrderFns # This module provides higher order functions specialized for sparse arrays, # particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present. -import Base: map, map!, broadcast, copy, copyto! +import Base: map, map!, broadcast, copy, copyto!, _extrema_dims, _extrema_itr using Base: front, tail, to_shape using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector, AbstractSparseMatrixCSC, @@ -29,6 +29,7 @@ using LinearAlgebra # (11) Define broadcast[!] methods handling combinations of scalars, sparse vectors/matrices, # structured matrices, and one- and two-dimensional Arrays. # (12) Define map[!] methods handling combinations of sparse and structured matrices. +# (13) Define extrema methods optimized for sparse vectors/matrices. # (0) BroadcastStyle rules and convenience types for dispatch @@ -1154,4 +1155,60 @@ map(f::Tf, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) map!(f::Tf, C::AbstractSparseMatrixCSC, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} = (_checksameshape(C, A, Bs...); _noshapecheck_map!(f, C, _sparsifystructured(A), map(_sparsifystructured, Bs)...)) + +# (13) extrema methods optimized for sparse vectors/matrices. +function _extrema_itr(f, A::SparseVecOrMat) + M = length(A) + iszero(M) && throw(ArgumentError("Sparse array must have at least one element.")) + N = nnz(A) + iszero(N) && return f(zero(eltype(A))), f(zero(eltype(A))) + vmin, vmax = _extrema_itr(f, nonzeros(A)) + if N != M + f0 = f(zero(eltype(A))) + vmin = min(f0, vmin) + vmax = max(f0, vmax) + end + vmin, vmax +end + +function _extrema_dims(f, x::SparseVector, dims) + sz = zeros(1) + for d in dims + sz[d] = 1 + end + if sz == [1] && !iszero(length(x)) + return [_extrema_itr(f, x)] + end + invoke(_extrema_dims, Tuple{Any, AbstractArray, Any}, f, x, dims) +end + +function _extrema_dims(f, A::AbstractSparseMatrix, dims) + sz = zeros(2) + for d in dims + sz[d] = 1 + end + if sz == [1, 0] && !iszero(length(A)) + T = eltype(A) + B = Array{Tuple{T,T}}(undef, 1, size(A, 2)) + @inbounds for col_idx in 1:size(A, 2) + col = @view A[:,col_idx] + fx = (nnz(col) == size(A, 1)) ? f(A[1,col_idx]) : f(zero(T)) + B[col_idx] = (fx, fx) + for x in nonzeros(col) + fx = f(x) + if fx < B[col_idx][1] + B[col_idx] = (fx, B[col_idx][2]) + elseif fx > B[col_idx][2] + B[col_idx] = (B[col_idx][1], fx) + end + end + end + return B + end + invoke(_extrema_dims, Tuple{Any, AbstractArray, Any}, f, A, dims) +end + +_extrema_dims(f, A::SparseVector, ::Colon) = _extrema_itr(f, A) +_extrema_dims(f, A::AbstractSparseMatrix, ::Colon) = _extrema_itr(f, A) + end diff --git a/stdlib/SparseArrays/test/higherorderfns.jl b/stdlib/SparseArrays/test/higherorderfns.jl index cd77fca6951a5..69a319ccc1340 100644 --- a/stdlib/SparseArrays/test/higherorderfns.jl +++ b/stdlib/SparseArrays/test/higherorderfns.jl @@ -687,4 +687,37 @@ end @test SparseMatStyle(Val(3)) == Broadcast.DefaultArrayStyle{3}() end +@testset "extrema" begin + n = 10 + A = sprand(n, n, 0.2) + B = Array(A) + C = Array{Real}(undef, 0, 0) + x = sprand(n, 0.2) + y = Array(x) + z = Array{Real}(undef, 0) + f(x) = x^3 + @test extrema(A) == extrema(B) + @test extrema(x) == extrema(y) + @test extrema(f, A) == extrema(f, B) + @test extrema(f, x) == extrema(f, y) + @test extrema(spzeros(n, n)) == (0.0, 0.0) + @test extrema(spzeros(n)) == (0.0, 0.0) + @test_throws ArgumentError extrema(spzeros(0, 0)) + @test_throws ArgumentError extrema(spzeros(0)) + @test extrema(sparse(ones(n, n))) == (1.0, 1.0) + @test extrema(sparse(ones(n))) == (1.0, 1.0) + @test extrema(A; dims=:) == extrema(B; dims=:) + @test extrema(A; dims=1) == extrema(B; dims=1) + @test extrema(A; dims=2) == extrema(B; dims=2) + @test extrema(A; dims=(1,2)) == extrema(B; dims=(1,2)) + @test extrema(f, A; dims=1) == extrema(f, B; dims=1) + @test extrema(sparse(C); dims=1) == extrema(C; dims=1) + @test extrema(A; dims=[]) == extrema(B; dims=[]) + @test extrema(x; dims=:) == extrema(y; dims=:) + @test extrema(x; dims=1) == extrema(y; dims=1) + @test extrema(f, x; dims=1) == extrema(f, y; dims=1) + @test_throws BoundsError extrema(sparse(z); dims=1) + @test extrema(x; dims=[]) == extrema(y; dims=[]) +end + end # module