Skip to content

Commit 5bdbf6c

Browse files
author
Andy Ferris
committed
Better broadcasting/reduction operations
* More coverage of common operators like .== * Following the Julia v0.5 convention of array broadcasting for !, &, | and $ * Various @inbounds fixes (I wonder if these were causing my earlier worries abound @BoundsCheck not being eliminated in certain cases)
1 parent b9ea6ba commit 5bdbf6c

File tree

3 files changed

+73
-24
lines changed

3 files changed

+73
-24
lines changed

src/StaticArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import Base: @pure, @propagate_inbounds, getindex, setindex!, size, similar,
66
length, convert, promote_op, map, map!, reduce, mapreduce,
77
broadcast, broadcast!, conj, transpose, ctranspose, hcat, vcat,
88
ones, zeros, eye, cross, vecdot, reshape, fill, fill!, det, inv,
9-
eig, trace, vecnorm, norm, dot, diagm
9+
eig, trace, vecnorm, norm, dot, diagm, sum, prod, count, sumabs,
10+
sumabs2, minimum, maximum, extrema, mean
1011

1112
export StaticScalar, StaticArray, StaticVector, StaticMatrix
1213
export Scalar, SArray, SVector, SMatrix

src/arraymath.jl

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,74 @@
1-
import Base: .+, .-, .*, ./, .\, .%, .^
1+
import Base: .+, .-, .*, ./, .//, .\, .%, .^, .<<, .>>, .==, .<, .>, .<=, .>=, !, &, |, $
22

33
# Support for elementwise ops on AbstractArray{S<:StaticArray} with Number
44
Base.promote_op{Op,A<:StaticArray,T<:Number}(op::Op, ::Type{A}, ::Type{T}) = similar_type(A, promote_op(op, eltype(A), T))
55
Base.promote_op{Op,T<:Number,A<:StaticArray}(op::Op, ::Type{T}, ::Type{A}) = similar_type(A, promote_op(op, T, eltype(A)))
66

7-
8-
# TODO any more operators?
9-
107
@inline .-(a1::StaticArray) = broadcast(-, a1)
118

129
@inline .+(a1::StaticArray, a2::StaticArray) = broadcast(+, a1, a2)
1310
@inline .-(a1::StaticArray, a2::StaticArray) = broadcast(-, a1, a2)
1411
@inline .*(a1::StaticArray, a2::StaticArray) = broadcast(*, a1, a2)
1512
@inline ./(a1::StaticArray, a2::StaticArray) = broadcast(/, a1, a2)
13+
@inline .//(a1::StaticArray, a2::StaticArray) = broadcast(//, a1, a2)
1614
@inline .\(a1::StaticArray, a2::StaticArray) = broadcast(\, a1, a2)
1715
@inline .%(a1::StaticArray, a2::StaticArray) = broadcast(%, a1, a2)
1816
@inline .^(a1::StaticArray, a2::StaticArray) = broadcast(^, a1, a2)
17+
@inline .<<(a1::StaticArray, a2::StaticArray) = broadcast(<<, a1, a2)
18+
@inline .>>(a1::StaticArray, a2::StaticArray) = broadcast(<<, a1, a2)
19+
@inline .==(a1::StaticArray, a2::StaticArray) = broadcast(==, a1, a2)
20+
@inline .<(a1::StaticArray, a2::StaticArray) = broadcast(<, a1, a2)
21+
@inline .>(a1::StaticArray, a2::StaticArray) = broadcast(>, a1, a2)
22+
@inline .<=(a1::StaticArray, a2::StaticArray) = broadcast(<=, a1, a2)
23+
@inline .>=(a1::StaticArray, a2::StaticArray) = broadcast(>=, a1, a2)
1924

2025
@inline .+(a1::StaticArray, a2::Number) = broadcast(+, a1, a2)
2126
@inline .-(a1::StaticArray, a2::Number) = broadcast(-, a1, a2)
2227
@inline .*(a1::StaticArray, a2::Number) = broadcast(*, a1, a2)
2328
@inline ./(a1::StaticArray, a2::Number) = broadcast(/, a1, a2)
29+
@inline .//(a1::StaticArray, a2::Number) = broadcast(//, a1, a2)
2430
@inline .\(a1::StaticArray, a2::Number) = broadcast(\, a1, a2)
2531
@inline .%(a1::StaticArray, a2::Number) = broadcast(%, a1, a2)
2632
@inline .^(a1::StaticArray, a2::Number) = broadcast(^, a1, a2)
33+
@inline .<<(a1::StaticArray, a2::Number) = broadcast(<<, a1, a2)
34+
@inline .>>(a1::StaticArray, a2::Number) = broadcast(<<, a1, a2)
35+
@inline .==(a1::StaticArray, a2) = broadcast(==, a1, a2)
36+
@inline .<(a1::StaticArray, a2) = broadcast(<, a1, a2)
37+
@inline .>(a1::StaticArray, a2) = broadcast(>, a1, a2)
38+
@inline .<=(a1::StaticArray, a2) = broadcast(<=, a1, a2)
39+
@inline .>=(a1::StaticArray, a2) = broadcast(>=, a1, a2)
2740

2841
@inline .+(a1::Number, a2::StaticArray) = broadcast(+, a1, a2)
2942
@inline .-(a1::Number, a2::StaticArray) = broadcast(-, a1, a2)
3043
@inline .*(a1::Number, a2::StaticArray) = broadcast(*, a1, a2)
3144
@inline ./(a1::Number, a2::StaticArray) = broadcast(/, a1, a2)
45+
@inline .//(a1::Number, a2::StaticArray) = broadcast(//, a1, a2)
3246
@inline .\(a1::Number, a2::StaticArray) = broadcast(\, a1, a2)
3347
@inline .%(a1::Number, a2::StaticArray) = broadcast(%, a1, a2)
3448
@inline .^(a1::Number, a2::StaticArray) = broadcast(^, a1, a2)
49+
@inline .<<(a1::Number, a2::StaticArray) = broadcast(<<, a1, a2)
50+
@inline .>>(a1::Number, a2::StaticArray) = broadcast(<<, a1, a2)
51+
@inline .==(a1, a2::StaticArray) = broadcast(==, a1, a2)
52+
@inline .<(a1, a2::StaticArray) = broadcast(<, a1, a2)
53+
@inline .>(a1, a2::StaticArray) = broadcast(>, a1, a2)
54+
@inline .<=(a1, a2::StaticArray) = broadcast(<=, a1, a2)
55+
@inline .>=(a1, a2::StaticArray) = broadcast(>=, a1, a2)
56+
57+
# The remaining auto-vectorized boolean operators
58+
@inline !(a::StaticArray{Bool}) = broadcast(!, a)
59+
60+
@inline (&){T1,T2}(a1::StaticArray{T1}, a2::StaticArray{T2}) = broadcast(&, a1, a2)
61+
@inline (|){T1,T2}(a1::StaticArray{T1}, a2::StaticArray{T2}) = broadcast(|, a1, a2)
62+
@inline ($){T1,T2}(a1::StaticArray{T1}, a2::StaticArray{T2}) = broadcast($, a1, a2)
63+
64+
@inline (&){T}(a1::StaticArray{T}, a2::Number) = broadcast(&, a1, a2)
65+
@inline (|){T}(a1::StaticArray{T}, a2::Number) = broadcast(|, a1, a2)
66+
@inline ($){T}(a1::StaticArray{T}, a2::Number) = broadcast($, a1, a2)
67+
68+
@inline (&){T}(a1::Number, a2::StaticArray{T}) = broadcast(&, a1, a2)
69+
@inline (|){T}(a1::Number, a2::StaticArray{T}) = broadcast(|, a1, a2)
70+
@inline ($){T}(a1::Number, a2::StaticArray{T}) = broadcast($, a1, a2)
71+
3572

3673
@generated function Base.zeros{SA <: StaticArray}(::Union{SA,Type{SA}})
3774
s = size(SA)

src/mapreduce.jl

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
exprs = [:(f(a1[$j])) for j = 1:length(a1)]
99
return quote
1010
$(Expr(:meta, :inline))
11-
$(Expr(:call, newtype, Expr(:tuple, exprs...)))
11+
@inbounds return $(Expr(:call, newtype, Expr(:tuple, exprs...)))
1212
end
1313
end
1414

@@ -22,7 +22,7 @@ end
2222
exprs = [:(f(a1[$j], a2[$j])) for j = 1:length(a1)]
2323
return quote
2424
$(Expr(:meta, :inline))
25-
$(Expr(:call, newtype, Expr(:tuple, exprs...)))
25+
@inbounds return $(Expr(:call, newtype, Expr(:tuple, exprs...)))
2626
end
2727
end
2828

@@ -60,36 +60,47 @@ end
6060
############
6161
## reduce ##
6262
############
63-
@generated function reduce(op, a1::StaticArray)
64-
if length(a1) == 1
65-
return :(a1[1])
63+
@generated function reduce(op, a::StaticArray)
64+
if length(a) == 1
65+
return :(@inbounds return a[1])
6666
else
67-
expr = :(op(a1[1], a1[2]))
68-
for j = 3:length(a1)
69-
expr = :(op($expr, a1[$j]))
67+
expr = :(op(a[1], a[2]))
68+
for j = 3:length(a)
69+
expr = :(op($expr, a[$j]))
7070
end
7171
return quote
7272
$(Expr(:meta, :inline))
73-
$expr
73+
@inbounds return $expr
7474
end
7575
end
7676
end
7777

78-
@generated function reduce(op, v0, a1::StaticArray)
79-
if length(a1) == 0
78+
@generated function reduce(op, v0, a::StaticArray)
79+
if length(a) == 0
8080
return :(v0)
8181
else
82-
expr = :(op(v0, a1[1]))
83-
for j = 2:length(a1)
84-
expr = :(op($expr, a1[$j]))
82+
expr = :(op(v0, a[1]))
83+
for j = 2:length(a)
84+
expr = :(op($expr, a[$j]))
8585
end
8686
return quote
8787
$(Expr(:meta, :inline))
88-
$expr
88+
@inbounds return $expr
8989
end
9090
end
9191
end
9292

93+
# These are all similar in Base but not @inline'd
94+
@inline sum{T}(a::StaticArray{T}) = reduce(+, zero(T), a)
95+
@inline prod{T}(a::StaticArray{T}) = reduce(+, zero(T), a)
96+
@inline count(a::StaticArray{Bool}) = reduce(+, 0, a)
97+
@inline mean(a::StaticArray) = sum(a) / length(a)
98+
@inline sumabs{T}(a::StaticArray{T}) = mapreduce(abs, +, zero(T), a)
99+
@inline sumabs2{T}(a::StaticArray{T}) = mapreduce(abs2, +, zero(T), a)
100+
@inline minimum(a::StaticArray) = reduce(min, a) # base has mapreduce(idenity, scalarmin, a)
101+
@inline maximum(a::StaticArray) = reduce(max, a) # base has mapreduce(idenity, scalarmax, a)
102+
103+
93104
###############
94105
## mapreduce ##
95106
###############
@@ -105,7 +116,7 @@ end
105116
end
106117
return quote
107118
$(Expr(:meta, :inline))
108-
$expr
119+
@inbounds return $expr
109120
end
110121
end
111122
end
@@ -120,7 +131,7 @@ end
120131
end
121132
return quote
122133
$(Expr(:meta, :inline))
123-
$expr
134+
@inbounds return $expr
124135
end
125136
end
126137
end
@@ -140,7 +151,7 @@ end
140151
end
141152
return quote
142153
$(Expr(:meta, :inline))
143-
$expr
154+
@inbounds return $expr
144155
end
145156
end
146157
end
@@ -159,7 +170,7 @@ end
159170
end
160171
return quote
161172
$(Expr(:meta, :inline))
162-
$expr
173+
@inbounds return $expr
163174
end
164175
end
165176
end

0 commit comments

Comments
 (0)