2
2
# # broadcast! ##
3
3
# ###############
4
4
5
- import Base. Broadcast:
6
- BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize!
5
+ using Base. Broadcast: AbstractArrayStyle, DefaultArrayStyle, Style, Broadcasted
6
+ using Base. Broadcast: broadcast_shape, _broadcast_getindex, combine_axes
7
+ import Base. Broadcast: BroadcastStyle, materialize!, instantiate
7
8
import Base. Broadcast: _bcs1 # for SOneTo axis information
8
9
using Base. Broadcast: _bcsm
9
10
# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
10
11
# A constructor that changes the style parameter N (array dimension) is also required
11
12
struct StaticArrayStyle{N} <: AbstractArrayStyle{N} end
12
13
StaticArrayStyle {M} (:: Val{N} ) where {M,N} = StaticArrayStyle {N} ()
13
14
BroadcastStyle (:: Type{<:StaticArray{<:Tuple, <:Any, N}} ) where {N} = StaticArrayStyle {N} ()
14
- BroadcastStyle (:: Type{<:Transpose{<:Any, <:StaticArray{<:Tuple, <:Any, N}}} ) where {N} = StaticArrayStyle {N } ()
15
- BroadcastStyle (:: Type{<:Adjoint{<:Any, <:StaticArray{<:Tuple, <:Any, N}}} ) where {N} = StaticArrayStyle {N } ()
15
+ BroadcastStyle (:: Type{<:Transpose{<:Any, <:StaticArray}} ) = StaticArrayStyle {2 } ()
16
+ BroadcastStyle (:: Type{<:Adjoint{<:Any, <:StaticArray}} ) = StaticArrayStyle {2 } ()
16
17
BroadcastStyle (:: Type{<:Diagonal{<:Any, <:StaticArray{<:Tuple, <:Any, 1}}} ) = StaticArrayStyle {2} ()
17
18
# Precedence rules
18
19
BroadcastStyle (:: StaticArrayStyle{M} , :: DefaultArrayStyle{N} ) where {M,N} =
19
20
DefaultArrayStyle (Val (max (M, N)))
20
21
BroadcastStyle (:: StaticArrayStyle{M} , :: DefaultArrayStyle{0} ) where {M} =
21
22
StaticArrayStyle {M} ()
23
+
24
+ # combine_axes overload (for Tuple)
25
+ @inline static_combine_axes (A, B... ) = broadcast_shape (static_axes (A), static_combine_axes (B... ))
26
+ static_combine_axes (A) = static_axes (A)
27
+ static_axes (A) = axes (A)
28
+ static_axes (x:: Tuple ) = (SOneTo {length(x)} (),)
29
+ static_axes (bc:: Broadcasted{Style{Tuple}} ) = static_combine_axes (bc. args... )
30
+ Broadcast. _axes (bc:: Broadcasted{<:StaticArrayStyle} , :: Nothing ) = static_combine_axes (bc. args... )
31
+
32
+ # instantiate overload
33
+ @inline function instantiate (B:: Broadcasted{StaticArrayStyle{M}} ) where M
34
+ if B. axes isa Tuple{Vararg{SOneTo}} || B. axes isa Tuple && length (B. axes) > M
35
+ return invoke (instantiate, Tuple{Broadcasted}, B)
36
+ elseif B. axes isa Nothing
37
+ ax = static_combine_axes (B. args... )
38
+ return Broadcasted {StaticArrayStyle{M}} (B. f, B. args, ax)
39
+ else
40
+ # We need to update B.axes for `broadcast!` if it's not static and `ndims(dest) < M`.
41
+ ax = static_check_broadcast_shape (B. axes, static_combine_axes (B. args... ))
42
+ return Broadcasted {StaticArrayStyle{M}} (B. f, B. args, ax)
43
+ end
44
+ end
45
+ @inline function static_check_broadcast_shape (shp:: Tuple , Ashp:: Tuple{Vararg{SOneTo}} )
46
+ ax1 = if length (Ashp[1 ]) == 1
47
+ shp[1 ]
48
+ elseif Ashp[1 ] == shp[1 ]
49
+ Ashp[1 ]
50
+ else
51
+ throw (DimensionMismatch (" array could not be broadcast to match destination" ))
52
+ end
53
+ return (ax1, static_check_broadcast_shape (Base. tail (shp), Base. tail (Ashp))... )
54
+ end
55
+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{SOneTo,Vararg{SOneTo}} ) =
56
+ throw (DimensionMismatch (" cannot broadcast array to have fewer non-singleton dimensions" ))
57
+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{SOneTo{1},Vararg{SOneTo{1}}} ) = ()
58
+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{} ) = ()
22
59
# copy overload
23
60
@inline function Base. copy (B:: Broadcasted{StaticArrayStyle{M}} ) where M
24
61
flat = Broadcast. flatten (B); as = flat. args; f = flat. f
25
62
argsizes = broadcast_sizes (as... )
26
- destsize = combine_sizes (argsizes)
27
- _broadcast (f, destsize, argsizes, as... )
63
+ ax = axes (B)
64
+ ax isa Tuple{Vararg{SOneTo}} || error (" Dimension is not static. Please file a bug." )
65
+ return _broadcast (f, Size (map (length, ax)), argsizes, as... )
28
66
end
29
67
# copyto! overloads
30
68
@inline Base. copyto! (dest, B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
31
69
@inline Base. copyto! (dest:: AbstractArray , B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
32
70
@inline function _copyto! (dest, B:: Broadcasted{StaticArrayStyle{M}} ) where M
33
71
flat = Broadcast. flatten (B); as = flat. args; f = flat. f
34
72
argsizes = broadcast_sizes (as... )
35
- destsize = combine_sizes (( Size (dest), argsizes ... ) )
36
- if Length (destsize) === Length {Dynamic()} ()
37
- # destination dimension cannot be determined statically; fall back to generic broadcast!
38
- return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B) )
73
+ ax = axes (B )
74
+ if ax isa Tuple{Vararg{SOneTo}}
75
+ @boundscheck axes (dest) == ax || Broadcast . throwdm ( axes (dest), ax)
76
+ return _broadcast! (f, Size ( map (length, ax)), dest, argsizes, as ... )
39
77
end
40
- _broadcast! (f, destsize, dest, argsizes, as... )
78
+ # destination dimension cannot be determined statically; fall back to generic broadcast!
79
+ return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B))
41
80
end
42
81
43
82
# Resolving priority between dynamic and static axes
44
83
_bcs1 (a:: SOneTo , b:: SOneTo ) = _bcsm (b, a) ? b : (_bcsm (a, b) ? a : throw (DimensionMismatch (" arrays could not be broadcast to a common size" )))
45
- _bcs1 (a:: SOneTo , b:: Base.OneTo ) = _bcs1 (Base. OneTo (a), b)
46
- _bcs1 (a:: Base.OneTo , b:: SOneTo ) = _bcs1 (a, Base. OneTo (b))
84
+ function _bcs1 (a:: SOneTo , b:: Base.OneTo )
85
+ length (a) == 1 && return b
86
+ if length (b) != length (a) && length (b) != 1
87
+ throw (DimensionMismatch (" arrays could not be broadcast to a common size" ))
88
+ end
89
+ return a
90
+ end
91
+ _bcs1 (a:: Base.OneTo , b:: SOneTo ) = _bcs1 (b, a)
47
92
48
93
# ##################################################
49
94
# # Internal broadcast machinery for StaticArrays ##
@@ -56,45 +101,13 @@ _bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(a, Base.OneTo(b))
56
101
@inline broadcast_size (a:: AbstractArray ) = Size (a)
57
102
@inline broadcast_size (a:: Tuple ) = Size (length (a))
58
103
59
- function broadcasted_index (oldsize, newindex)
60
- index = ones (Int, length (oldsize))
61
- for i = 1 : length (oldsize)
62
- if oldsize[i] != 1
63
- index[i] = newindex[i]
64
- end
65
- end
66
- return LinearIndices (oldsize)[index... ]
67
- end
68
-
69
- # similar to Base.Broadcast.combine_indices:
70
- @generated function combine_sizes (s:: Tuple{Vararg{Size}} )
71
- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
72
- ndims = 0
73
- for i = 1 : length (sizes)
74
- ndims = max (ndims, length (sizes[i]))
75
- end
76
- newsize = StaticDimension[Dynamic () for _ = 1 : ndims]
77
- for i = 1 : length (sizes)
78
- s = sizes[i]
79
- for j = 1 : length (s)
80
- if s[j] isa Dynamic
81
- continue
82
- elseif newsize[j] isa Dynamic || newsize[j] == 1
83
- newsize[j] = s[j]
84
- elseif newsize[j] ≠ s[j] && s[j] ≠ 1
85
- throw (DimensionMismatch (" Tried to broadcast on inputs sized $sizes " ))
86
- end
87
- end
88
- end
89
- quote
90
- @_inline_meta
91
- Size ($ (tuple (newsize... )))
92
- end
104
+ broadcast_getindex (:: Tuple{} , i:: Int , I:: CartesianIndex ) = return :(_broadcast_getindex (a[$ i], $ I))
105
+ function broadcast_getindex (oldsize:: Tuple , i:: Int , newindex:: CartesianIndex )
106
+ li = LinearIndices (oldsize)
107
+ ind = _broadcast_getindex (li, newindex)
108
+ return :(a[$ i][$ ind])
93
109
end
94
110
95
- scalar_getindex (x) = x
96
- scalar_getindex (x:: Ref ) = x[]
97
-
98
111
isstatic (:: StaticArray ) = true
99
112
isstatic (:: Transpose{<:Any, <:StaticArray} ) = true
100
113
isstatic (:: Adjoint{<:Any, <:StaticArray} ) = true
@@ -118,13 +131,11 @@ end
118
131
119
132
@generated function __broadcast (f, :: Size{newsize} , s:: Tuple{Vararg{Size}} , a... ) where newsize
120
133
sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
134
+
121
135
indices = CartesianIndices (newsize)
122
136
exprs = similar (indices, Expr)
123
137
for (j, current_ind) ∈ enumerate (indices)
124
- exprs_vals = [
125
- (! (a[i] <: AbstractArray || a[i] <: Tuple ) ? :(scalar_getindex (a[$ i])) : :(a[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
126
- for i = 1 : length (sizes)
127
- ]
138
+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
128
139
exprs[j] = :(f ($ (exprs_vals... )))
129
140
end
130
141
@@ -138,27 +149,18 @@ end
138
149
# # Internal broadcast! machinery for StaticArrays ##
139
150
# ###################################################
140
151
141
- @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , as... ) where {newsize}
142
- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
143
- sizes = tuple (sizes... )
144
-
145
- # TODO : this could also be done outside the generated function:
146
- sizematch (Size {newsize} (), Size (dest)) ||
147
- throw (DimensionMismatch (" Tried to broadcast to destination sized $newsize from inputs sized $sizes " ))
152
+ @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , a... ) where {newsize}
153
+ sizes = [sz. parameters[1 ] for sz in s. parameters]
148
154
149
155
indices = CartesianIndices (newsize)
150
156
exprs = similar (indices, Expr)
151
157
for (j, current_ind) ∈ enumerate (indices)
152
- exprs_vals = [
153
- (! (as[i] <: AbstractArray || as[i] <: Tuple ) ? :(as[$ i][]) : :(as[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
154
- for i = 1 : length (sizes)
155
- ]
158
+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
156
159
exprs[j] = :(dest[$ j] = f ($ (exprs_vals... )))
157
160
end
158
161
159
162
return quote
160
- @_propagate_inbounds_meta
161
- @boundscheck sizematch ($ (Size {newsize} ()), dest) || throw (DimensionMismatch (" array could not be broadcast to match destination" ))
163
+ @_inline_meta
162
164
@inbounds $ (Expr (:block , exprs... ))
163
165
return dest
164
166
end
0 commit comments