Skip to content

Commit 7d711ab

Browse files
wsshinandyferris
authored andcommitted
Add more functions implemented with [map]reduce and [map]reducedim (#263)
* Use StaticArrays' reduce and mapreduce for iszero and count(f, a) * Add more functions implemented with [map]reduce and [map]reducedim * Handle possibility of eltype change in mapreducedim and diff
1 parent c5d7515 commit 7d711ab

File tree

3 files changed

+142
-60
lines changed

3 files changed

+142
-60
lines changed

src/StaticArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import Base: getindex, setindex!, size, similar, vec, show,
1010
hcat, vcat, ones, zeros, eye, one, cross, vecdot, reshape, fill,
1111
fill!, det, logdet, inv, eig, eigvals, expm, logm, sqrtm, lyap, trace, kron, diag, vecnorm, norm, dot, diagm, diag,
1212
lu, svd, svdvals, svdfact, factorize, ishermitian, issymmetric, isposdef,
13-
sum, diff, prod, count, any, all, minimum,
13+
iszero, sum, diff, prod, count, any, all, minimum,
1414
maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!,
1515
randexp!, normalize, normalize!, read, read!, write
1616

src/mapreduce.jl

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,11 @@ end
101101
_mapreducedim(f, op, Size(a), a, Val{D}, v0)
102102
end
103103

104-
@generated function _mapreducedim(f, op, ::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S, D}
104+
@generated function _mapreducedim(f, op, ::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S,D}
105105
N = length(S)
106106
Snew = ([n==D ? 1 : S[n] for n = 1:N]...)
107+
T0 = eltype(a)
108+
T = :((T1 = Base.promote_op(f, $T0); Base.promote_op(op, T1, T1)))
107109

108110
exprs = Array{Expr}(Snew)
109111
itr = [1:n for n Snew]
@@ -118,14 +120,13 @@ end
118120
exprs[i...] = expr
119121
end
120122

121-
# TODO element type might change
122123
return quote
123124
@_inline_meta
124-
@inbounds return similar_type(a, Size($Snew))(tuple($(exprs...)))
125+
@inbounds return similar_type(a, $T, Size($Snew))(tuple($(exprs...)))
125126
end
126127
end
127128

128-
@generated function _mapreducedim(f, op, ::Size{S}, a::StaticArray, ::Type{Val{D}}, v0) where {S, D}
129+
@generated function _mapreducedim(f, op, ::Size{S}, a::StaticArray, ::Type{Val{D}}, v0::T) where {S,D,T}
129130
N = length(S)
130131
Snew = ([n==D ? 1 : S[n] for n = 1:N]...)
131132

@@ -142,10 +143,9 @@ end
142143
exprs[i...] = expr
143144
end
144145

145-
# TODO element type might change
146146
return quote
147147
@_inline_meta
148-
@inbounds return similar_type(a, Size($Snew))(tuple($(exprs...)))
148+
@inbounds return similar_type(a, T, Size($Snew))(tuple($(exprs...)))
149149
end
150150
end
151151

@@ -160,33 +160,82 @@ end
160160
## reducedim ##
161161
###############
162162

163-
@inline reducedim(op, a::StaticArray, ::Val{D}) where {D} = mapreducedim(identity, op, a, Val{D})
164-
@inline reducedim(op, a::StaticArray, ::Val{D}, v0) where {D} = mapreducedim(identity, op, a, Val{D}, v0)
163+
@inline reducedim(op, a::StaticArray, ::Type{Val{D}}) where {D} = mapreducedim(identity, op, a, Val{D})
164+
@inline reducedim(op, a::StaticArray, ::Type{Val{D}}, v0) where {D} = mapreducedim(identity, op, a, Val{D}, v0)
165165

166166
#######################
167167
## related functions ##
168168
#######################
169169

170170
# These are all similar in Base but not @inline'd
171-
@inline sum(a::StaticArray{<:Any, T}) where {T} = reduce(+, zero(T), a)
172-
@inline sum(f::Base.Callable, a::StaticArray) = mapreduce(f, +, a)
173-
@inline prod(a::StaticArray{<:Any, T}) where {T} = reduce(*, one(T), a)
174-
@inline count(a::StaticArray{<:Any, Bool}) = reduce(+, 0, a)
175-
@inline all(a::StaticArray{<:Any, Bool}) = reduce(&, true, a) # non-branching versions
176-
@inline any(a::StaticArray{<:Any, Bool}) = reduce(|, false, a) # (benchmarking needed)
171+
#
172+
# Implementation notes:
173+
#
174+
# 1. When providing an initial value v0, note that its location is different in reduce and
175+
# reducedim: v0 comes earlier than collection in reduce, whereas it is the last argument in
176+
# reducedim. The same difference exists between mapreduce and mapreducedim.
177+
#
178+
# 2. mapreduce and mapreducedim usually do not take initial value v0, because we don't
179+
# always know the return type of an arbitrary mapping function f. (We usually want to use
180+
# some initial value such as one(T) or zero(T) as v0, where T is the return type of f, but
181+
# if users provide type-unstable f, its return type cannot be known.) Therefore, mapped
182+
# versions of the functions implemented below usually require the collection to have at
183+
# least two entries.
184+
#
185+
# 3. Exceptions are the ones that require Boolean mapping functions. For example, f in
186+
# all and any must return Bool, so we know the appropriate v0 is true and false,
187+
# respectively. Therefore, all(f, ...) and any(f, ...) are implemented by mapreduce(f, ...)
188+
# with an initial value v0 = true and false.
189+
@inline iszero(a::StaticArray{<:Any,T}) where {T} = reduce((x,y) -> x && (y==zero(T)), true, a)
190+
191+
@inline sum(a::StaticArray{<:Any,T}) where {T} = reduce(+, zero(T), a)
192+
@inline sum(f::Function, a::StaticArray) = mapreduce(f, +, a)
193+
@inline sum(a::StaticArray{<:Any,T}, ::Type{Val{D}}) where {T,D} = reducedim(+, a, Val{D}, zero(T))
194+
@inline sum(f::Function, a::StaticArray, ::Type{Val{D}}) where D = mapreducedim(f, +, a, Val{D})
195+
196+
@inline prod(a::StaticArray{<:Any,T}) where {T} = reduce(*, one(T), a)
197+
@inline prod(f::Function, a::StaticArray{<:Any,T}) where {T} = mapreduce(f, *, a)
198+
@inline prod(a::StaticArray{<:Any,T}, ::Type{Val{D}}) where {T,D} = reducedim(*, a, Val{D}, one(T))
199+
@inline prod(f::Function, a::StaticArray{<:Any,T}, ::Type{Val{D}}) where {T,D} = mapreducedim(f, *, a, Val{D})
200+
201+
@inline count(a::StaticArray{<:Any,Bool}) = reduce(+, 0, a)
202+
@inline count(f::Function, a::StaticArray) = mapreduce(x->f(x)::Bool, +, 0, a)
203+
@inline count(a::StaticArray{<:Any,Bool}, ::Type{Val{D}}) where {D} = reducedim(+, a, Val{D}, 0)
204+
@inline count(f::Function, a::StaticArray, ::Type{Val{D}}) where {D} = mapreducedim(x->f(x)::Bool, +, a, Val{D}, 0)
205+
206+
@inline all(a::StaticArray{<:Any,Bool}) = reduce(&, true, a) # non-branching versions
207+
@inline all(f::Function, a::StaticArray) = mapreduce(x->f(x)::Bool, &, true, a)
208+
@inline all(a::StaticArray{<:Any,Bool}, ::Type{Val{D}}) where {D} = reducedim(&, a, Val{D}, true)
209+
@inline all(f::Function, a::StaticArray, ::Type{Val{D}}) where {D} = mapreducedim(x->f(x)::Bool, &, a, Val{D}, true)
210+
211+
@inline any(a::StaticArray{<:Any,Bool}) = reduce(|, false, a) # (benchmarking needed)
212+
@inline any(f::Function, a::StaticArray) = mapreduce(x->f(x)::Bool, |, false, a) # (benchmarking needed)
213+
@inline any(a::StaticArray{<:Any,Bool}, ::Type{Val{D}}) where {D} = reducedim(|, a, Val{D}, false)
214+
@inline any(f::Function, a::StaticArray, ::Type{Val{D}}) where {D} = mapreducedim(x->f(x)::Bool, |, a, Val{D}, false)
215+
177216
@inline mean(a::StaticArray) = sum(a) / length(a)
217+
@inline mean(f::Function, a::StaticArray) = sum(f, a) / length(a)
218+
@inline mean(a::StaticArray, ::Type{Val{D}}) where {D} = sum(a, Val{D}) / size(a, D)
219+
@inline mean(f::Function, a::StaticArray, ::Type{Val{D}}) where {D} = sum(f, a, Val{D}) / size(a, D)
220+
178221
@inline minimum(a::StaticArray) = reduce(min, a) # base has mapreduce(idenity, scalarmin, a)
222+
@inline minimum(f::Function, a::StaticArray) = mapreduce(f, min, a)
223+
@inline minimum(a::StaticArray, ::Type{Val{D}}) where {D} = reducedim(min, a, Val{D})
224+
@inline minimum(f::Function, a::StaticArray, ::Type{Val{D}}) where {D} = mapreducedim(f, min, a, Val{D})
225+
179226
@inline maximum(a::StaticArray) = reduce(max, a) # base has mapreduce(idenity, scalarmax, a)
180-
@inline minimum(a::StaticArray, dim::Type{Val{D}}) where {D} = reducedim(min, a, dim)
181-
@inline maximum(a::StaticArray, dim::Type{Val{D}}) where {D} = reducedim(max, a, dim)
227+
@inline maximum(f::Function, a::StaticArray) = mapreduce(f, max, a)
228+
@inline maximum(a::StaticArray, ::Type{Val{D}}) where {D} = reducedim(max, a, Val{D})
229+
@inline maximum(f::Function, a::StaticArray, ::Type{Val{D}}) where {D} = mapreducedim(f, max, a, Val{D})
182230

183231
# Diff is slightly different
184232
@inline diff(a::StaticArray) = diff(a, Val{1})
185233
@inline diff(a::StaticArray, ::Type{Val{D}}) where {D} = _diff(Size(a), a, Val{D})
186234

187-
@generated function _diff(::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S, D}
235+
@generated function _diff(::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S,D}
188236
N = length(S)
189237
Snew = ([n==D ? S[n]-1 : S[n] for n = 1:N]...)
238+
T = Base.promote_op(-, eltype(a), eltype(a))
190239

191240
exprs = Array{Expr}(Snew)
192241
itr = [1:n for n = Snew]
@@ -197,9 +246,8 @@ end
197246
exprs[i1...] = :(a[$(i2...)] - a[$(i1...)])
198247
end
199248

200-
# TODO element type might change
201249
return quote
202250
@_inline_meta
203-
@inbounds return similar_type(a, Size($Snew))(tuple($(exprs...)))
251+
@inbounds return similar_type(a, $T, Size($Snew))(tuple($(exprs...)))
204252
end
205253
end

test/mapreduce.jl

Lines changed: 74 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,48 +23,82 @@
2323
@test mv3 == @MVector [7, 9, 11, 13]
2424
end
2525

26-
@testset "reduce" begin
27-
v1 = @SVector [2,4,6,8]
28-
@test reduce(+, v1) === 20
29-
@test reduce(+, 0, v1) === 20
30-
@test sum(v1) === 20
31-
@test sum(abs2, v1) === 120
32-
@test prod(v1) === 384
33-
@test mean(v1) === 5.
34-
@test maximum(v1) === 8
35-
@test minimum(v1) === 2
36-
vb = @SVector [true, false, true, false]
37-
@test count(vb) === 2
38-
@test any(vb)
39-
end
40-
41-
@testset "reduce in dim" begin
42-
a = @SArray rand(4,3,2)
43-
@test maximum(a, Val{1}) == maximum(a, 1)
44-
@test maximum(a, Val{2}) == maximum(a, 2)
45-
@test maximum(a, Val{3}) == maximum(a, 3)
46-
@test minimum(a, Val{1}) == minimum(a, 1)
47-
@test minimum(a, Val{2}) == minimum(a, 2)
48-
@test minimum(a, Val{3}) == minimum(a, 3)
49-
@test diff(a) == diff(a, Val{1}) == a[2:end,:,:] - a[1:end-1,:,:]
50-
@test diff(a, Val{2}) == a[:,2:end,:] - a[:,1:end-1,:]
51-
@test diff(a, Val{3}) == a[:,:,2:end] - a[:,:,1:end-1]
52-
53-
a = @SArray rand(4,3) # as of Julia v0.5, diff() for regular Array is defined only for vectors and matrices
54-
@test diff(a) == diff(a, Val{1}) == diff(a, 1)
55-
@test diff(a, Val{2}) == diff(a, 2)
56-
57-
@test reducedim(max, a, Val{1}, -1.) == reducedim(max, a, 1, -1.)
58-
@test reducedim(max, a, Val{2}, -1.) == reducedim(max, a, 2, -1.)
26+
@testset "[map]reduce and [map]reducedim" begin
27+
a = rand(4,3); sa = SMatrix{4,3}(a); (I,J) = size(a)
28+
v1 = [2,4,6,8]; sv1 = SVector{4}(v1)
29+
v2 = [4,3,2,1]; sv2 = SVector{4}(v2)
30+
@test reduce(+, sv1) === reduce(+, v1)
31+
@test reduce(+, 0, sv1) === reduce(+, 0, v1)
32+
@test reducedim(max, sa, Val{1}, -1.) === SMatrix{1,J}(reducedim(max, a, 1, -1.))
33+
@test reducedim(max, sa, Val{2}, -1.) === SMatrix{I,1}(reducedim(max, a, 2, -1.))
34+
@test mapreduce(-, +, sv1) === mapreduce(-, +, v1)
35+
@test mapreduce(-, +, 0, sv1) === mapreduce(-, +, 0, v1)
36+
@test mapreduce(*, +, sv1, sv2) === 40
37+
@test mapreduce(*, +, 0, sv1, sv2) === 40
38+
@test mapreducedim(x->x^2, max, sa, Val{1}, -1.) == SMatrix{1,J}(mapreducedim(x->x^2, max, a, 1, -1.))
39+
@test mapreducedim(x->x^2, max, sa, Val{2}, -1.) == SMatrix{I,1}(mapreducedim(x->x^2, max, a, 2, -1.))
5940
end
6041

61-
@testset "mapreduce" begin
62-
v1 = @SVector [2,4,6,8]
63-
v2 = @SVector [4,3,2,1]
64-
@test mapreduce(-, +, v1) === -20
65-
@test mapreduce(-, +, 0, v1) === -20
66-
@test mapreduce(*, +, v1, v2) === 40
67-
@test mapreduce(*, +, 0, v1, v2) === 40
42+
@testset "implemented by [map]reduce and [map]reducedim" begin
43+
I, J, K = 2, 2, 2
44+
OSArray = SArray{Tuple{I,J,K}} # original
45+
RSArray1 = SArray{Tuple{1,J,K}} # reduced in dimension 1
46+
RSArray2 = SArray{Tuple{I,1,K}} # reduced in dimension 2
47+
RSArray3 = SArray{Tuple{I,J,1}} # reduced in dimension 3
48+
a = randn(I,J,K); sa = OSArray(a)
49+
b = rand(Bool,I,J,K); sb = OSArray(b)
50+
z = zeros(I,J,K); sz = OSArray(z)
51+
52+
@test iszero(sz) == iszero(z)
53+
54+
@test sum(sa) === sum(a)
55+
@test sum(abs2, sa) === sum(abs2, a)
56+
@test sum(sa, Val{2}) === RSArray2(sum(a, 2))
57+
@test sum(abs2, sa, Val{2}) === RSArray2(sum(abs2, a, 2))
58+
59+
@test prod(sa) === prod(a)
60+
@test prod(abs2, sa) === prod(abs2, a)
61+
@test prod(sa, Val{2}) === RSArray2(prod(a, 2))
62+
@test prod(abs2, sa, Val{2}) === RSArray2(prod(abs2, a, 2))
63+
64+
@test count(sb) === count(b)
65+
@test count(x->x>0, sa) === count(x->x>0, a)
66+
@test count(sb, Val{2}) === RSArray2(reshape([count(b[i,:,k]) for i = 1:I, k = 1:K], (I,1,K)))
67+
@test count(x->x>0, sa, Val{2}) === RSArray2(reshape([count(x->x>0, a[i,:,k]) for i = 1:I, k = 1:K], (I,1,K)))
68+
69+
@test all(sb) === all(b)
70+
@test all(x->x>0, sa) === all(x->x>0, a)
71+
@test all(sb, Val{2}) === RSArray2(all(b, 2))
72+
@test all(x->x>0, sa, Val{2}) === RSArray2(all(x->x>0, a, 2))
73+
74+
@test any(sb) === any(b)
75+
@test any(x->x>0, sa) === any(x->x>0, a)
76+
@test any(sb, Val{2}) === RSArray2(any(b, 2))
77+
@test any(x->x>0, sa, Val{2}) === RSArray2(any(x->x>0, a, 2))
78+
79+
@test mean(sa) === mean(a)
80+
@test mean(abs2, sa) === mean(abs2, a)
81+
@test mean(sa, Val{2}) === RSArray2(mean(a, 2))
82+
@test mean(abs2, sa, Val{2}) === RSArray2(mean(abs2.(a), 2))
83+
84+
@test minimum(sa) === minimum(a)
85+
@test minimum(abs2, sa) === minimum(abs2, a)
86+
@test minimum(sa, Val{2}) === RSArray2(minimum(a, 2))
87+
@test minimum(abs2, sa, Val{2}) === RSArray2(minimum(abs2, a, 2))
88+
89+
@test maximum(sa) === maximum(a)
90+
@test maximum(abs2, sa) === maximum(abs2, a)
91+
@test maximum(sa, Val{2}) === RSArray2(maximum(a, 2))
92+
@test maximum(abs2, sa, Val{2}) === RSArray2(maximum(abs2, a, 2))
93+
94+
@test diff(sa, Val{1}) === RSArray1(a[2:end,:,:] - a[1:end-1,:,:])
95+
@test diff(sa, Val{2}) === RSArray2(a[:,2:end,:] - a[:,1:end-1,:])
96+
@test diff(sa, Val{3}) === RSArray3(a[:,:,2:end] - a[:,:,1:end-1])
97+
98+
# as of Julia v0.6, diff() for regular Array is defined only for vectors and matrices
99+
m = randn(4,3); sm = SMatrix{4,3}(m)
100+
@test diff(sm, Val{1}) == diff(m, 1) == diff(sm) == diff(m)
101+
@test diff(sm, Val{2}) == diff(m, 2)
68102
end
69103

70104
@testset "broadcast and broadcast!" begin

0 commit comments

Comments
 (0)