Skip to content

Commit 7a78efd

Browse files
tkfc42f
andauthored
Implement accumulate and friends (#702)
* Implement accumulate and friends * Run tests for accumulate * Skip inference tests in Julia 1.1 * Update src/mapreduce.jl Co-Authored-By: Chris Foster <chris42f@gmail.com> * Rename: _maybeval -> _maybe_val * Explain how `_map` is used from `_accumulate` * Revert: (push(ys, y), y) This reverts commit 4ca0144. * Comment on why we use `vcat` * Use inference to determine element types * Use reduce_empty in cumsum/cumprod for Array-compatibility * Use reduce_first instead of reduce_empty Co-authored-by: Chris Foster <chris42f@gmail.com>
1 parent e48d2f0 commit 7a78efd

File tree

3 files changed

+117
-0
lines changed

3 files changed

+117
-0
lines changed

src/mapreduce.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,53 @@ end
285285
@inbounds return similar_type(a, T, Size($Snew))(tuple($(exprs...)))
286286
end
287287
end
288+
289+
struct _InitialValue end
290+
291+
_maybe_val(dims::Integer) = Val(Int(dims))
292+
_maybe_val(dims) = dims
293+
_valof(::Val{D}) where D = D
294+
295+
@inline Base.accumulate(op::F, a::StaticVector; dims = :, init = _InitialValue()) where {F} =
296+
_accumulate(op, a, _maybe_val(dims), init)
297+
298+
@inline Base.accumulate(op::F, a::StaticArray; dims, init = _InitialValue()) where {F} =
299+
_accumulate(op, a, _maybe_val(dims), init)
300+
301+
@inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F}
302+
# Adjoin the initial value to `op`:
303+
rf(x, y) = x isa _InitialValue ? Base.reduce_first(op, y) : op(x, y)
304+
305+
if isempty(a)
306+
T = return_type(rf, Tuple{typeof(init), eltype(a)})
307+
return similar_type(a, T)()
308+
end
309+
310+
# StaticArrays' `reduce` is `foldl`:
311+
results = _reduce(
312+
a,
313+
dims,
314+
(init = (similar_type(a, Union{}, Size(0))(), init),),
315+
) do (ys, acc), x
316+
y = rf(acc, x)
317+
# Not using `push(ys, y)` here since we need to widen element type as
318+
# we iterate.
319+
(vcat(ys, SA[y]), y)
320+
end
321+
dims === (:) && return first(results)
322+
323+
ys = map(first, results)
324+
# Now map over all indices of `a`. Since `_map` needs at least
325+
# one `StaticArray` to be passed, we pass `a` here, even though
326+
# the values of `a` are not used.
327+
data = _map(a, CartesianIndices(a)) do _, CI
328+
D = _valof(dims)
329+
I = Tuple(CI)
330+
J = setindex(I, 1, D)
331+
ys[J...][I[D]]
332+
end
333+
return similar_type(a, eltype(data))(data)
334+
end
335+
336+
@inline Base.cumsum(a::StaticArray; kw...) = accumulate(Base.add_sum, a; kw...)
337+
@inline Base.cumprod(a::StaticArray; kw...) = accumulate(Base.mul_prod, a; kw...)

test/accumulate.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
using StaticArrays, Test
2+
3+
@testset "accumulate" begin
4+
@testset "cumsum(::$label)" for (label, T) in [
5+
# label, T
6+
("SVector", SVector),
7+
("MVector", MVector),
8+
("SizedVector", SizedVector),
9+
]
10+
@testset "$label" for (label, a) in [
11+
("[1, 2, 3]", T{3}(SA[1, 2, 3])),
12+
("[]", T{0,Int}(())),
13+
]
14+
@test cumsum(a) == cumsum(collect(a))
15+
@test cumsum(a) isa similar_type(a)
16+
@inferred cumsum(a)
17+
end
18+
@test eltype(cumsum(T{0,Int8}(()))) == eltype(cumsum(Int8[]))
19+
@test eltype(cumsum(T{1,Int8}((1)))) == eltype(cumsum(Int8[1]))
20+
@test eltype(cumsum(T{2,Int8}((1, 2)))) == eltype(cumsum(Int8[1, 2]))
21+
end
22+
23+
@testset "cumsum(::$label; dims=2)" for (label, T) in [
24+
# label, T
25+
("SMatrix", SMatrix),
26+
("MMatrix", MMatrix),
27+
("SizedMatrix", SizedMatrix),
28+
]
29+
@testset "$label" for (label, a) in [
30+
("[1 2; 3 4; 5 6]", T{3,2}(SA[1 2; 3 4; 5 6])),
31+
("0 x 2 matrix", T{0,2,Float64}()),
32+
("2 x 0 matrix", T{2,0,Float64}()),
33+
]
34+
@test cumsum(a; dims = 2) == cumsum(collect(a); dims = 2)
35+
@test cumsum(a; dims = 2) isa similar_type(a)
36+
v"1.1" <= VERSION < v"1.2" && continue
37+
@inferred cumsum(a; dims = Val(2))
38+
end
39+
end
40+
41+
@testset "cumsum(a::SArray; dims=$i); ndims(a) = $d" for d in 1:4, i in 1:d
42+
shape = Tuple(1:d)
43+
a = similar_type(SArray, Int, Size(shape))(1:prod(shape))
44+
@test cumsum(a; dims = i) == cumsum(collect(a); dims = i)
45+
@test cumsum(a; dims = i) isa SArray
46+
v"1.1" <= VERSION < v"1.2" && continue
47+
@inferred cumsum(a; dims = Val(i))
48+
end
49+
50+
@testset "cumprod" begin
51+
a = SA[1, 2, 3]
52+
@test cumprod(a)::SArray == cumprod(collect(a))
53+
@inferred cumprod(a)
54+
55+
@test eltype(cumsum(SA{Int8}[])) == eltype(cumsum(Int8[]))
56+
@test eltype(cumsum(SA{Int8}[1])) == eltype(cumsum(Int8[1]))
57+
@test eltype(cumsum(SA{Int8}[1, 2])) == eltype(cumsum(Int8[1, 2]))
58+
end
59+
60+
@testset "empty vector with init" begin
61+
a = SA{Int}[]
62+
right(_, x) = x
63+
@test accumulate(right, a; init = Val(1)) === SA{Int}[]
64+
@inferred accumulate(right, a; init = Val(1))
65+
end
66+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ include("abstractarray.jl")
3333
include("indexing.jl")
3434
include("initializers.jl")
3535
Random.seed!(42); include("mapreduce.jl")
36+
Random.seed!(42); include("accumulate.jl")
3637
Random.seed!(42); include("arraymath.jl")
3738
include("broadcast.jl")
3839
include("linalg.jl")

0 commit comments

Comments
 (0)