Skip to content

Commit 20275db

Browse files
authored
Merge pull request #86 from wsshin/reducer-with-dim
Implement versions of maximum(), minimum(), diff() with Val{dim} as second argument
2 parents 06689d4 + 4f024a7 commit 20275db

File tree

3 files changed

+75
-7
lines changed

3 files changed

+75
-7
lines changed

src/StaticArrays.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ __precompile__()
33
module StaticArrays
44

55
import Base: @pure, @propagate_inbounds, getindex, setindex!, size, similar,
6-
length, convert, promote_op, map, map!, reduce, mapreduce,
7-
broadcast, broadcast!, conj, transpose, ctranspose, hcat, vcat,
8-
ones, zeros, eye, one, cross, vecdot, reshape, fill, fill!, det,
9-
inv, eig, eigvals, trace, vecnorm, norm, dot, diagm, sum, prod,
10-
count, any, all, sumabs, sumabs2, minimum, maximum, extrema, mean,
11-
copy
6+
length, convert, promote_op, map, map!, reduce, reducedim,
7+
mapreduce, broadcast, broadcast!, conj, transpose, ctranspose,
8+
hcat, vcat, ones, zeros, eye, one, cross, vecdot, reshape, fill,
9+
fill!, det, inv, eig, eigvals, trace, vecnorm, norm, dot, diagm,
10+
sum, diff, prod, count, any, all, sumabs, sumabs2, minimum,
11+
maximum, extrema, mean, copy
1212

1313
export StaticScalar, StaticArray, StaticVector, StaticMatrix
1414
export Scalar, SArray, SVector, SMatrix

src/mapreduce.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,36 @@ end
9090
end
9191
end
9292

93+
@generated function reducedim{D}(op, a::StaticArray, ::Type{Val{D}})
94+
S = size(a)
95+
if S[D] == 1
96+
return :(return a)
97+
else
98+
N = ndims(a)
99+
Snew = ([n==D ? 1 : S[n] for n = 1:N]...)
100+
newtype = similar_type(a, Snew)
101+
102+
exprs = Array{Expr}(Snew)
103+
itr = [1:n for n = Snew]
104+
for i = Base.product(itr...)
105+
ik = copy([i...])
106+
ik[D] = 2
107+
expr = :(op(a[$(i...)], a[$(ik...)]))
108+
for k = 3:S[D]
109+
ik[D] = k
110+
expr = :(op($expr, a[$(ik...)]))
111+
end
112+
113+
exprs[i...] = expr
114+
end
115+
116+
return quote
117+
$(Expr(:meta,:inline))
118+
@inbounds return $(Expr(:call, newtype, Expr(:tuple, exprs...)))
119+
end
120+
end
121+
end
122+
93123
# These are all similar in Base but not @inline'd
94124
@inline sum{T}(a::StaticArray{T}) = reduce(+, zero(T), a)
95125
@inline prod{T}(a::StaticArray{T}) = reduce(*, one(T), a)
@@ -101,8 +131,29 @@ end
101131
@inline sumabs2{T}(a::StaticArray{T}) = mapreduce(abs2, +, zero(T), a)
102132
@inline minimum(a::StaticArray) = reduce(min, a) # base has mapreduce(idenity, scalarmin, a)
103133
@inline maximum(a::StaticArray) = reduce(max, a) # base has mapreduce(idenity, scalarmax, a)
134+
@inline minimum{D}(a::StaticArray, dim::Type{Val{D}}) = reducedim(min, a, dim)
135+
@inline maximum{D}(a::StaticArray, dim::Type{Val{D}}) = reducedim(max, a, dim)
136+
137+
@generated function diff{D}(a::StaticArray, ::Type{Val{D}}=Val{1})
138+
S = size(a)
139+
N = ndims(a)
140+
Snew = ([n==D ? S[n]-1:S[n] for n = 1:N]...)
141+
newtype = similar_type(a, Snew)
142+
143+
exprs = Array{Expr}(Snew)
144+
itr = [1:n for n = Snew]
145+
146+
for i1 = Base.product(itr...)
147+
i2 = copy([i1...])
148+
i2[D] = i1[D] + 1
149+
exprs[i1...] = :(a[$(i2...)] - a[$(i1...)])
150+
end
104151

105-
152+
return quote
153+
$(Expr(:meta,:inline))
154+
@inbounds return $(Expr(:call, newtype, Expr(:tuple, exprs...)))
155+
end
156+
end
106157

107158
###############
108159
## mapreduce ##

test/mapreduce.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,23 @@
2424
@test prod(v1) === 384
2525
end
2626

27+
@testset "reduce in dim" begin
28+
a = @SArray rand(4,3,2)
29+
@test maximum(a, Val{1}) == maximum(a, 1)
30+
@test maximum(a, Val{2}) == maximum(a, 2)
31+
@test maximum(a, Val{3}) == maximum(a, 3)
32+
@test minimum(a, Val{1}) == minimum(a, 1)
33+
@test minimum(a, Val{2}) == minimum(a, 2)
34+
@test minimum(a, Val{3}) == minimum(a, 3)
35+
@test diff(a) == diff(a, Val{1}) == a[2:end,:,:] - a[1:end-1,:,:]
36+
@test diff(a, Val{2}) == a[:,2:end,:] - a[:,1:end-1,:]
37+
@test diff(a, Val{3}) == a[:,:,2:end] - a[:,:,1:end-1]
38+
39+
a = @SArray rand(4,3) # as of Julia v0.5, diff() for regular Array is defined only for vectors and matrices
40+
@test diff(a) == diff(a, Val{1}) == diff(a, 1)
41+
@test diff(a, Val{2}) == diff(a, 2)
42+
end
43+
2744
@testset "mapreduce" begin
2845
v1 = @SVector [2,4,6,8]
2946
v2 = @SVector [4,3,2,1]

0 commit comments

Comments
 (0)