@@ -58,7 +58,7 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo{1},Vararg{SOneTo{1}}}) =
58
58
static_check_broadcast_shape (:: Tuple{} , :: Tuple{} ) = ()
59
59
# copy overload
60
60
@inline function Base. copy (B:: Broadcasted{StaticArrayStyle{M}} ) where M
61
- flat = Broadcast . flatten (B); as = flat. args; f = flat. f
61
+ flat = broadcast_flatten (B); as = flat. args; f = flat. f
62
62
argsizes = broadcast_sizes (as... )
63
63
ax = axes (B)
64
64
ax isa Tuple{Vararg{SOneTo}} || error (" Dimension is not static. Please file a bug." )
68
68
@inline Base. copyto! (dest, B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
69
69
@inline Base. copyto! (dest:: AbstractArray , B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
70
70
@inline function _copyto! (dest, B:: Broadcasted{StaticArrayStyle{M}} ) where M
71
- flat = Broadcast . flatten (B); as = flat. args; f = flat. f
71
+ flat = broadcast_flatten (B); as = flat. args; f = flat. f
72
72
argsizes = broadcast_sizes (as... )
73
73
ax = axes (B)
74
74
if ax isa Tuple{Vararg{SOneTo}}
165
165
return dest
166
166
end
167
167
end
168
+
169
+ # Work around for https://github.com/JuliaLang/julia/issues/27988
170
+ # The following code is borrowed from https://github.com/JuliaLang/julia/pull/43322
171
+ # with some modification to make it also works on 1.6.
172
+ # TODO : make `broadcast_flatten` call `Broadcast.flatten` once julia#43322 is merged.
173
+ module StableFlatten
174
+
175
+ export broadcast_flatten
176
+
177
+ using Base: tail
178
+ using Base. Broadcast: isflat, Broadcasted
179
+
180
+ maybeconstructor (f) = f
181
+ maybeconstructor (:: Type{F} ) where {F} = (args... ; kwargs... ) -> F (args... ; kwargs... )
182
+
183
+ function broadcast_flatten (bc:: Broadcasted{Style} ) where {Style}
184
+ isflat (bc) && return bc
185
+ args = cat_nested (bc)
186
+ len = Val {length(args)} ()
187
+ makeargs = make_makeargs (bc. args, len, ntuple (_-> true , len))
188
+ f = maybeconstructor (bc. f)
189
+ @inline newf (args... ) = f (prepare_args (makeargs, args)... )
190
+ return Broadcasted {Style} (newf, args, bc. axes)
191
+ end
192
+
193
+ cat_nested (bc:: Broadcasted ) = cat_nested_args (bc. args)
194
+ cat_nested_args (:: Tuple{} ) = ()
195
+ cat_nested_args (t:: Tuple ) = (cat_nested (t[1 ])... , cat_nested_args (tail (t))... )
196
+ cat_nested (@nospecialize (a)) = (a,)
197
+
198
+ function make_makeargs (args:: Tuple , len, flags)
199
+ makeargs, r = _make_makeargs (args, len, flags)
200
+ r isa Tuple{} || error (" Internal error. Please file a bug" )
201
+ return makeargs
202
+ end
203
+
204
+ # We build `makeargs` by traversing the broadcast nodes recursively.
205
+ # note: `len` isa `Val` indicates the length of whole flattened argument list.
206
+ # `flags` is a tuple of `Bool` with the same length of the rest arguments.
207
+ @inline function _make_makeargs (args:: Tuple , len:: Val , flags:: Tuple )
208
+ head, flags′ = _make_makeargs1 (args[1 ], len, flags)
209
+ rest, flags″ = _make_makeargs (tail (args), len, flags′)
210
+ (head, rest... ), flags″
211
+ end
212
+ _make_makeargs (:: Tuple{} , :: Val , x:: Tuple ) = (), x
213
+
214
+ # For flat nodes:
215
+ # 1. we just consume one argument, and return the "pick" function
216
+ @inline function _make_makeargs1 (@nospecialize (a), :: Val{N} , flags:: Tuple ) where {N}
217
+ pickargs (:: Val{N} ) where {N} = (@nospecialize (x:: Tuple )) -> x[N]
218
+ return pickargs (Val {N-length(flags)+1} ()), tail (flags)
219
+ end
220
+
221
+ # For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
222
+ @inline function _make_makeargs1 (bc:: Broadcasted , len:: Val , flags:: Tuple )
223
+ makeargs, flags′ = _make_makeargs (bc. args, len, flags)
224
+ f = maybeconstructor (bc. f)
225
+ @inline makeargs1 (@nospecialize (args:: Tuple )) = f (prepare_args (makeargs, args)... )
226
+ makeargs1, flags′
227
+ end
228
+
229
+ prepare_args (:: Tuple{} , @nospecialize (:: Tuple )) = ()
230
+ @inline prepare_args (makeargs:: Tuple , @nospecialize (x:: Tuple )) = (makeargs[1 ](x), prepare_args (tail (makeargs), x)... )
231
+ end
232
+ using . StableFlatten
0 commit comments