@@ -60,21 +60,25 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
60
60
@inline function Base. copy (B:: Broadcasted{StaticArrayStyle{M}} ) where M
61
61
flat = Broadcast. flatten (B); as = flat. args; f = flat. f
62
62
argsizes = broadcast_sizes (as... )
63
- destsize = combine_sizes (argsizes)
64
- _broadcast (f, destsize, argsizes, as... )
63
+ ax = axes (B)
64
+ if ax isa Tuple{Vararg{SOneTo}}
65
+ return _broadcast (f, Size (map (length, ax)), argsizes, as... )
66
+ end
67
+ return copy (convert (Broadcasted{DefaultArrayStyle{M}}, B))
65
68
end
66
69
# copyto! overloads
67
70
@inline Base. copyto! (dest, B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
68
71
@inline Base. copyto! (dest:: AbstractArray , B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
69
72
@inline function _copyto! (dest, B:: Broadcasted{StaticArrayStyle{M}} ) where M
70
73
flat = Broadcast. flatten (B); as = flat. args; f = flat. f
71
74
argsizes = broadcast_sizes (as... )
72
- destsize = combine_sizes (( Size (dest), argsizes ... ) )
73
- if Length (destsize) === Length {Dynamic()} ()
74
- # destination dimension cannot be determined statically; fall back to generic broadcast!
75
- return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B) )
75
+ ax = axes (B )
76
+ if ax isa Tuple{Vararg{SOneTo}}
77
+ @boundscheck axes (dest) == ax || Broadcast . throwdm ( axes (dest), ax)
78
+ return _broadcast! (f, Size ( map (length, ax)), dest, argsizes, as ... )
76
79
end
77
- _broadcast! (f, destsize, dest, argsizes, as... )
80
+ # destination dimension cannot be determined statically; fall back to generic broadcast!
81
+ return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B))
78
82
end
79
83
80
84
# Resolving priority between dynamic and static axes
@@ -101,45 +105,13 @@ broadcast_indices(A::StaticArray) = indices(A)
101
105
@inline broadcast_size (a:: AbstractArray ) = Size (a)
102
106
@inline broadcast_size (a:: Tuple ) = Size (length (a))
103
107
104
- function broadcasted_index (oldsize, newindex)
105
- index = ones (Int, length (oldsize))
106
- for i = 1 : length (oldsize)
107
- if oldsize[i] != 1
108
- index[i] = newindex[i]
109
- end
110
- end
111
- return LinearIndices (oldsize)[index... ]
112
- end
113
-
114
- # similar to Base.Broadcast.combine_indices:
115
- @generated function combine_sizes (s:: Tuple{Vararg{Size}} )
116
- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
117
- ndims = 0
118
- for i = 1 : length (sizes)
119
- ndims = max (ndims, length (sizes[i]))
120
- end
121
- newsize = StaticDimension[Dynamic () for _ = 1 : ndims]
122
- for i = 1 : length (sizes)
123
- s = sizes[i]
124
- for j = 1 : length (s)
125
- if s[j] isa Dynamic
126
- continue
127
- elseif newsize[j] isa Dynamic || newsize[j] == 1
128
- newsize[j] = s[j]
129
- elseif newsize[j] ≠ s[j] && s[j] ≠ 1
130
- throw (DimensionMismatch (" Tried to broadcast on inputs sized $sizes " ))
131
- end
132
- end
133
- end
134
- quote
135
- @_inline_meta
136
- Size ($ (tuple (newsize... )))
137
- end
108
+ broadcast_getindex (:: Tuple{} , i:: Int , I:: CartesianIndex ) = return :(_broadcast_getindex (a[$ i], $ I))
109
+ function broadcast_getindex (oldsize:: Tuple , i:: Int , newindex:: CartesianIndex )
110
+ li = LinearIndices (oldsize)
111
+ ind = _broadcast_getindex (li, newindex)
112
+ return :(a[$ i][$ ind])
138
113
end
139
114
140
- scalar_getindex (x) = x
141
- scalar_getindex (x:: Ref ) = x[]
142
-
143
115
isstatic (:: StaticArray ) = true
144
116
isstatic (:: Transpose{<:Any, <:StaticArray} ) = true
145
117
isstatic (:: Adjoint{<:Any, <:StaticArray} ) = true
@@ -163,13 +135,11 @@ end
163
135
164
136
@generated function __broadcast (f, :: Size{newsize} , s:: Tuple{Vararg{Size}} , a... ) where newsize
165
137
sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
138
+
166
139
indices = CartesianIndices (newsize)
167
140
exprs = similar (indices, Expr)
168
141
for (j, current_ind) ∈ enumerate (indices)
169
- exprs_vals = [
170
- (! (a[i] <: AbstractArray || a[i] <: Tuple ) ? :(scalar_getindex (a[$ i])) : :(a[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
171
- for i = 1 : length (sizes)
172
- ]
142
+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
173
143
exprs[j] = :(f ($ (exprs_vals... )))
174
144
end
175
145
@@ -183,27 +153,18 @@ end
183
153
# # Internal broadcast! machinery for StaticArrays ##
184
154
# ###################################################
185
155
186
- @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , as... ) where {newsize}
187
- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
188
- sizes = tuple (sizes... )
189
-
190
- # TODO : this could also be done outside the generated function:
191
- sizematch (Size {newsize} (), Size (dest)) ||
192
- throw (DimensionMismatch (" Tried to broadcast to destination sized $newsize from inputs sized $sizes " ))
156
+ @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , a... ) where {newsize}
157
+ sizes = [sz. parameters[1 ] for sz in s. parameters]
193
158
194
159
indices = CartesianIndices (newsize)
195
160
exprs = similar (indices, Expr)
196
161
for (j, current_ind) ∈ enumerate (indices)
197
- exprs_vals = [
198
- (! (as[i] <: AbstractArray || as[i] <: Tuple ) ? :(as[$ i][]) : :(as[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
199
- for i = 1 : length (sizes)
200
- ]
162
+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
201
163
exprs[j] = :(dest[$ j] = f ($ (exprs_vals... )))
202
164
end
203
165
204
166
return quote
205
- @_propagate_inbounds_meta
206
- @boundscheck sizematch ($ (Size {newsize} ()), dest) || throw (DimensionMismatch (" array could not be broadcast to match destination" ))
167
+ @_inline_meta
207
168
@inbounds $ (Expr (:block , exprs... ))
208
169
return dest
209
170
end
0 commit comments