Skip to content

Commit eb408b9

Browse files
authored
Merge pull request #241 from JuliaArrays/non-numeric-broadcast
Integrate StaticArrays with julia 0.6 broadcast
2 parents 3e7c71e + 3865c6f commit eb408b9

File tree

3 files changed

+48
-13
lines changed

3 files changed

+48
-13
lines changed

src/broadcast.jl

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,38 @@
22
## broadcast! ##
33
################
44

5-
# TODO: bad codegen for `broadcast(-, SVector(1,2,3))`
6-
7-
@propagate_inbounds function broadcast(f, a::Union{Number, StaticArray}...)
8-
_broadcast(f, broadcast_sizes(a...), a...)
9-
end
10-
11-
@propagate_inbounds function broadcast{T}(f::Function, a::Type{T}, x::StaticArray)
12-
_broadcast(f, (Size(), Size(x)), T, x)
5+
import Base.Broadcast:
6+
_containertype, promote_containertype, broadcast_indices,
7+
broadcast_c, broadcast_c!
8+
9+
# Add StaticArray as a new output type in Base.Broadcast promotion machinery.
10+
# This isn't the precise output type, just a placeholder to return from
11+
# promote_containertype, which will control dispatch to our broadcast_c.
12+
_containertype(::Type{<:StaticArray}) = StaticArray
13+
14+
# With the above, the default promote_containertype gives reasonable defaults:
15+
# StaticArray, StaticArray -> StaticArray
16+
# Array, StaticArray -> Array
17+
#
18+
# We could be more precise about the latter, but this isn't really possible
19+
# without using Array{N} rather than Array in Base's promote_containertype.
20+
#
21+
# Base also has broadcast with tuple + Array, but while implementing this would
22+
# be consistent with Base, it's not exactly clear it's a good idea when you can
23+
# just use an SVector instead?
24+
promote_containertype(::Type{StaticArray}, ::Type{Any}) = StaticArray
25+
promote_containertype(::Type{Any}, ::Type{StaticArray}) = StaticArray
26+
27+
broadcast_indices(::Type{StaticArray}, A) = indices(A)
28+
29+
30+
# Override for when output type is deduced to be a StaticArray.
31+
@inline function broadcast_c(f, ::Type{StaticArray}, as...)
32+
_broadcast(f, broadcast_sizes(as...), as...)
1333
end
1434

1535
@inline broadcast_sizes(a::StaticArray, as...) = (Size(a), broadcast_sizes(as...)...)
16-
@inline broadcast_sizes(a::Number, as...) = (Size(), broadcast_sizes(as...)...)
36+
@inline broadcast_sizes(a, as...) = (Size(), broadcast_sizes(as...)...)
1737
@inline broadcast_sizes() = ()
1838

1939
function broadcasted_index(oldsize, newindex)
@@ -94,12 +114,12 @@ end
94114
## broadcast! ##
95115
################
96116

97-
@propagate_inbounds function broadcast!(f, dest::StaticArray, a::Union{Number, StaticArray}...)
98-
_broadcast!(f, Size(dest), dest, broadcast_sizes(a...), a...)
117+
@inline function broadcast_c!(f, ::Type{StaticArray}, ::Type, dest, as...)
118+
_broadcast!(f, Size(dest), dest, broadcast_sizes(as...), as...)
99119
end
100120

101121

102-
@generated function _broadcast!(f, ::Size{newsize}, dest::StaticArray, s::Tuple{Vararg{Size}}, a::Union{Number, StaticArray}...) where {newsize}
122+
@generated function _broadcast!(f, ::Size{newsize}, dest::StaticArray, s::Tuple{Vararg{Size}}, as...) where {newsize}
103123
sizes = [sz.parameters[1] for sz s.parameters]
104124
sizes = tuple(sizes...)
105125

@@ -122,7 +142,7 @@ end
122142
more = newsize[1] != 0
123143
current_ind = ones(Int, max(length(newsize), length.(sizes)...))
124144
while more
125-
exprs_vals = [(a[i] <: Number ? :(a[$i]) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)]
145+
exprs_vals = [(!(as[i] <: AbstractArray) ? :(as[$i]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)]
126146
exprs[current_ind...] = :(dest[$j] = f($(exprs_vals...)))
127147

128148
# increment current_ind (maybe use CartesianRange?)

test/Scalar.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using StaticArrays, Base.Test
2+
13
@testset "Scalar" begin
24
@test Scalar(2) .* [1, 2, 3] == [2, 4, 6]
35
@test Scalar([1 2; 3 4]) .+ [[1 1; 1 1], [2 2; 2 2]] == [[2 3; 4 5], [3 4; 5 6]]

test/broadcast.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
using StaticArrays, Base.Test
2+
3+
include("testutil.jl")
4+
15
@testset "Broadcast sizes" begin
26
@test @inferred(StaticArrays.broadcast_sizes(1, 1, 1)) === (Size(), Size(), Size())
37
for t in (SVector{2}, MVector{2}, SMatrix{2, 2}, MMatrix{2, 2})
@@ -139,4 +143,13 @@ end
139143
@test eltype(a) == Float32
140144
end
141145
end
146+
147+
@testset "broadcast general scalars" begin
148+
# Issue #239 - broadcast with non-numeric element types
149+
@eval @enum Axis aX aY aZ
150+
@testinf (SVector(aX,aY,aZ) .== aX) == SVector(true,false,false)
151+
mv = MVector(aX,aY,aZ)
152+
@testinf broadcast!(identity, mv, aX) == MVector(aX,aX,aX)
153+
@test mv == SVector(aX,aX,aX)
154+
end
142155
end

0 commit comments

Comments
 (0)