@@ -60,21 +60,25 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
60
60
@inline function Base. copy (B:: Broadcasted{StaticArrayStyle{M}} ) where M
61
61
flat = Broadcast. flatten (B); as = flat. args; f = flat. f
62
62
argsizes = broadcast_sizes (as... )
63
- destsize = combine_sizes (argsizes)
64
- _broadcast (f, destsize, argsizes, as... )
63
+ ax = axes (B)
64
+ if ax isa Tuple{Vararg{SOneTo}}
65
+ return _broadcast (f, Size (map (length, ax)), argsizes, as... )
66
+ end
67
+ return copy (convert (Broadcasted{DefaultArrayStyle{M}}, B))
65
68
end
66
69
# copyto! overloads
67
70
@inline Base. copyto! (dest, B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
68
71
@inline Base. copyto! (dest:: AbstractArray , B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
69
72
@inline function _copyto! (dest, B:: Broadcasted{StaticArrayStyle{M}} ) where M
70
73
flat = Broadcast. flatten (B); as = flat. args; f = flat. f
71
74
argsizes = broadcast_sizes (as... )
72
- destsize = combine_sizes (( Size (dest), argsizes ... ) )
73
- if Length (destsize) === Length {Dynamic()} ()
74
- # destination dimension cannot be determined statically; fall back to generic broadcast!
75
- return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B) )
75
+ ax = axes (B )
76
+ if ax isa Tuple{Vararg{SOneTo}}
77
+ @boundscheck axes (dest) == ax || Broadcast . throwdm ( axes (dest), ax)
78
+ return _broadcast! (f, Size ( map (length, ax)), dest, argsizes, as ... )
76
79
end
77
- _broadcast! (f, destsize, dest, argsizes, as... )
80
+ # destination dimension cannot be determined statically; fall back to generic broadcast!
81
+ return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B))
78
82
end
79
83
80
84
# Resolving priority between dynamic and static axes
@@ -99,45 +103,13 @@ _bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(b, a)
99
103
@inline broadcast_size (a:: AbstractArray ) = Size (a)
100
104
@inline broadcast_size (a:: Tuple ) = Size (length (a))
101
105
102
- function broadcasted_index (oldsize, newindex)
103
- index = ones (Int, length (oldsize))
104
- for i = 1 : length (oldsize)
105
- if oldsize[i] != 1
106
- index[i] = newindex[i]
107
- end
108
- end
109
- return LinearIndices (oldsize)[index... ]
110
- end
111
-
112
- # similar to Base.Broadcast.combine_indices:
113
- @generated function combine_sizes (s:: Tuple{Vararg{Size}} )
114
- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
115
- ndims = 0
116
- for i = 1 : length (sizes)
117
- ndims = max (ndims, length (sizes[i]))
118
- end
119
- newsize = StaticDimension[Dynamic () for _ = 1 : ndims]
120
- for i = 1 : length (sizes)
121
- s = sizes[i]
122
- for j = 1 : length (s)
123
- if s[j] isa Dynamic
124
- continue
125
- elseif newsize[j] isa Dynamic || newsize[j] == 1
126
- newsize[j] = s[j]
127
- elseif newsize[j] ≠ s[j] && s[j] ≠ 1
128
- throw (DimensionMismatch (" Tried to broadcast on inputs sized $sizes " ))
129
- end
130
- end
131
- end
132
- quote
133
- @_inline_meta
134
- Size ($ (tuple (newsize... )))
135
- end
106
+ broadcast_getindex (:: Tuple{} , i:: Int , I:: CartesianIndex ) = return :(_broadcast_getindex (a[$ i], $ I))
107
+ function broadcast_getindex (oldsize:: Tuple , i:: Int , newindex:: CartesianIndex )
108
+ li = LinearIndices (oldsize)
109
+ ind = _broadcast_getindex (li, newindex)
110
+ return :(a[$ i][$ ind])
136
111
end
137
112
138
- scalar_getindex (x) = x
139
- scalar_getindex (x:: Ref ) = x[]
140
-
141
113
isstatic (:: StaticArray ) = true
142
114
isstatic (:: Transpose{<:Any, <:StaticArray} ) = true
143
115
isstatic (:: Adjoint{<:Any, <:StaticArray} ) = true
@@ -161,13 +133,11 @@ end
161
133
162
134
@generated function __broadcast (f, :: Size{newsize} , s:: Tuple{Vararg{Size}} , a... ) where newsize
163
135
sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
136
+
164
137
indices = CartesianIndices (newsize)
165
138
exprs = similar (indices, Expr)
166
139
for (j, current_ind) ∈ enumerate (indices)
167
- exprs_vals = [
168
- (! (a[i] <: AbstractArray || a[i] <: Tuple ) ? :(scalar_getindex (a[$ i])) : :(a[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
169
- for i = 1 : length (sizes)
170
- ]
140
+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
171
141
exprs[j] = :(f ($ (exprs_vals... )))
172
142
end
173
143
@@ -181,27 +151,18 @@ end
181
151
# # Internal broadcast! machinery for StaticArrays ##
182
152
# ###################################################
183
153
184
- @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , as... ) where {newsize}
185
- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
186
- sizes = tuple (sizes... )
187
-
188
- # TODO : this could also be done outside the generated function:
189
- sizematch (Size {newsize} (), Size (dest)) ||
190
- throw (DimensionMismatch (" Tried to broadcast to destination sized $newsize from inputs sized $sizes " ))
154
+ @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , a... ) where {newsize}
155
+ sizes = [sz. parameters[1 ] for sz in s. parameters]
191
156
192
157
indices = CartesianIndices (newsize)
193
158
exprs = similar (indices, Expr)
194
159
for (j, current_ind) ∈ enumerate (indices)
195
- exprs_vals = [
196
- (! (as[i] <: AbstractArray || as[i] <: Tuple ) ? :(as[$ i][]) : :(as[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
197
- for i = 1 : length (sizes)
198
- ]
160
+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
199
161
exprs[j] = :(dest[$ j] = f ($ (exprs_vals... )))
200
162
end
201
163
202
164
return quote
203
- @_propagate_inbounds_meta
204
- @boundscheck sizematch ($ (Size {newsize} ()), dest) || throw (DimensionMismatch (" array could not be broadcast to match destination" ))
165
+ @_inline_meta
205
166
@inbounds $ (Expr (:block , exprs... ))
206
167
return dest
207
168
end
0 commit comments