|
2 | 2 | ## broadcast! ##
|
3 | 3 | ################
|
4 | 4 |
|
5 |
| -# TODO: bad codegen for `broadcast(-, SVector(1,2,3))` |
6 |
| - |
7 |
| -@propagate_inbounds function broadcast(f, a::Union{Number, StaticArray}...) |
8 |
| - _broadcast(f, broadcast_sizes(a...), a...) |
9 |
| -end |
10 |
| - |
11 |
| -@propagate_inbounds function broadcast{T}(f::Function, a::Type{T}, x::StaticArray) |
12 |
| - _broadcast(f, (Size(), Size(x)), T, x) |
| 5 | +import Base.Broadcast: |
| 6 | + _containertype, promote_containertype, broadcast_indices, |
| 7 | + broadcast_c, broadcast_c! |
| 8 | + |
| 9 | +# Add StaticArray as a new output type in Base.Broadcast promotion machinery. |
| 10 | +# This isn't the precise output type, just a placeholder to return from |
| 11 | +# promote_containertype, which will control dispatch to our broadcast_c. |
| 12 | +_containertype(::Type{<:StaticArray}) = StaticArray |
| 13 | + |
| 14 | +# With the above, the default promote_containertype gives reasonable defaults: |
| 15 | +# StaticArray, StaticArray -> StaticArray |
| 16 | +# Array, StaticArray -> Array |
| 17 | +# |
| 18 | +# We could be more precise about the latter, but this isn't really possible |
| 19 | +# without using Array{N} rather than Array in Base's promote_containertype. |
| 20 | +# |
| 21 | +# Base also has broadcast with tuple + Array, but while implementing this would |
| 22 | +# be consistent with Base, it's not exactly clear it's a good idea when you can |
| 23 | +# just use an SVector instead? |
| 24 | +promote_containertype(::Type{StaticArray}, ::Type{Any}) = StaticArray |
| 25 | +promote_containertype(::Type{Any}, ::Type{StaticArray}) = StaticArray |
| 26 | + |
| 27 | +broadcast_indices(::Type{StaticArray}, A) = indices(A) |
| 28 | + |
| 29 | + |
| 30 | +# Override for when output type is deduced to be a StaticArray. |
| 31 | +@inline function broadcast_c(f, ::Type{StaticArray}, as...) |
| 32 | + _broadcast(f, broadcast_sizes(as...), as...) |
13 | 33 | end
|
14 | 34 |
|
15 | 35 | @inline broadcast_sizes(a::StaticArray, as...) = (Size(a), broadcast_sizes(as...)...)
|
16 |
| -@inline broadcast_sizes(a::Number, as...) = (Size(), broadcast_sizes(as...)...) |
| 36 | +@inline broadcast_sizes(a, as...) = (Size(), broadcast_sizes(as...)...) |
17 | 37 | @inline broadcast_sizes() = ()
|
18 | 38 |
|
19 | 39 | function broadcasted_index(oldsize, newindex)
|
|
94 | 114 | ## broadcast! ##
|
95 | 115 | ################
|
96 | 116 |
|
97 |
| -@propagate_inbounds function broadcast!(f, dest::StaticArray, a::Union{Number, StaticArray}...) |
98 |
| - _broadcast!(f, Size(dest), dest, broadcast_sizes(a...), a...) |
| 117 | +@inline function broadcast_c!(f, ::Type{StaticArray}, ::Type, dest, as...) |
| 118 | + _broadcast!(f, Size(dest), dest, broadcast_sizes(as...), as...) |
99 | 119 | end
|
100 | 120 |
|
101 | 121 |
|
102 |
| -@generated function _broadcast!(f, ::Size{newsize}, dest::StaticArray, s::Tuple{Vararg{Size}}, a::Union{Number, StaticArray}...) where {newsize} |
| 122 | +@generated function _broadcast!(f, ::Size{newsize}, dest::StaticArray, s::Tuple{Vararg{Size}}, as...) where {newsize} |
103 | 123 | sizes = [sz.parameters[1] for sz ∈ s.parameters]
|
104 | 124 | sizes = tuple(sizes...)
|
105 | 125 |
|
|
122 | 142 | more = newsize[1] != 0
|
123 | 143 | current_ind = ones(Int, max(length(newsize), length.(sizes)...))
|
124 | 144 | while more
|
125 |
| - exprs_vals = [(a[i] <: Number ? :(a[$i]) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)] |
| 145 | + exprs_vals = [(!(as[i] <: AbstractArray) ? :(as[$i]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)] |
126 | 146 | exprs[current_ind...] = :(dest[$j] = f($(exprs_vals...)))
|
127 | 147 |
|
128 | 148 | # increment current_ind (maybe use CartesianRange?)
|
|
0 commit comments