Skip to content

Commit b2f2fb0

Browse files
thchrmateuszbaran
andauthored
Let count, sum, and prod take an init kwarg (#1281)
* let `count` take an `init` kwarg - `Base.count` allows the `init` kwarg since v1.6 - bump StaticArrays to 1.9.8 * add `init` kwarg to `sum` and `prod` as well (fix #1119) * run invalidations action on julia lts * Move `solve.jl` tests to group B to hopefully OOM issues on Julia 1.6 CI --------- Co-authored-by: Mateusz Baran <mateuszbaran89@gmail.com>
1 parent 899fb0a commit b2f2fb0

File tree

5 files changed

+17
-11
lines changed

5 files changed

+17
-11
lines changed

.github/workflows/Invalidations.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
steps:
1919
- uses: julia-actions/setup-julia@v2
2020
with:
21-
version: '1'
21+
version: 'lts'
2222
- uses: actions/checkout@v4
2323
- uses: julia-actions/julia-buildpkg@v1
2424
- uses: julia-actions/julia-invalidations@v1

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.9.7"
3+
version = "1.9.8"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/mapreduce.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -284,16 +284,16 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
284284
# TODO: change to use Base.reduce_empty/Base.reduce_first
285285
@inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true)
286286

287-
@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(+, a, dims)
288-
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a)
289-
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) # avoid ambiguity
287+
@inline sum(a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _reduce(+, a, dims, init)
288+
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, +, dims, init, Size(a), a)
289+
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, +, dims, init, Size(a), a) # avoid ambiguity
290290

291-
@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(*, a, dims)
292-
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, _InitialValue(), Size(a), a)
293-
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, _InitialValue(), Size(a), a)
291+
@inline prod(a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _reduce(*, a, dims, init)
292+
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, *, dims, init, Size(a), a)
293+
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, *, dims, init, Size(a), a)
294294

295-
@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(+, a, dims)
296-
@inline count(f, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, +, dims, _InitialValue(), Size(a), a)
295+
@inline count(a::StaticArray{<:Tuple,Bool}; dims=:, init=0) = _reduce(+, a, dims, init)
296+
@inline count(f, a::StaticArray; dims=:, init=0) = _mapreduce(x->f(x)::Bool, +, dims, init, Size(a), a)
297297

298298
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, true) # non-branching versions
299299
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, true, Size(a), a)

test/mapreduce.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,22 @@ using Statistics: mean
130130
@test sum(sa, dims=Val(2)) === RSArray2(sum(a, dims=2))
131131
@test sum(abs2, sa; dims=2) === RSArray2(sum(abs2, a, dims=2))
132132
@test sum(abs2, sa; dims=Val(2)) === RSArray2(sum(abs2, a, dims=2))
133+
@test sum(sa, init=2) == sum(a, init=2) sum(sa) + 2 # Float64 is non-associative
134+
@test sum(sb, init=2) == sum(b, init=2) == sum(sb) + 2
133135

134136
@test prod(sa) === prod(a)
135137
@test prod(abs2, sa) === prod(abs2, a)
136138
@test prod(sa, dims=Val(2)) === RSArray2(prod(a, dims=2))
137139
@test prod(abs2, sa, dims=Val(2)) === RSArray2(prod(abs2, a, dims=2))
140+
@test prod(sa, init=2) == prod(a, init=2) 2*prod(sa) # Float64 is non-associative
141+
@test prod(sb, init=2) == prod(b, init=2) == 2*prod(sb)
138142

139143
@test count(sb) === count(b)
140144
@test count(x->x>0, sa) === count(x->x>0, a)
141145
@test count(sb, dims=Val(2)) === RSArray2(reshape([count(b[i,:,k]) for i = 1:I, k = 1:K], (I,1,K)))
142146
@test count(x->x>0, sa, dims=Val(2)) === RSArray2(reshape([count(x->x>0, a[i,:,k]) for i = 1:I, k = 1:K], (I,1,K)))
147+
@test count(sb, init=3) == count(b, init=3) == count(sb) + 3
148+
@test count(x->x>0, sa, init=-2) == count(x->x>0, a, init=-2) == count(x->x>0, sa) - 2
143149

144150
@test all(sb) === all(b)
145151
@test all(x->x>0, sa) === all(x->x>0, a)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ if TEST_GROUP ∈ ["", "all", "group-A"]
6767
addtests("det.jl")
6868
addtests("inv.jl")
6969
addtests("pinv.jl")
70-
addtests("solve.jl")
7170

7271
# special logic required since we need to start a new
7372
# Julia process for these tests
@@ -78,6 +77,7 @@ if TEST_GROUP ∈ ["", "all", "group-A"]
7877
end
7978

8079
if TEST_GROUP ["", "all", "group-B"]
80+
addtests("solve.jl")
8181
addtests("eigen.jl")
8282
addtests("expm.jl")
8383
addtests("sqrtm.jl")

0 commit comments

Comments
 (0)