@@ -164,64 +164,67 @@ end
164
164
# Work around for https://github.com/JuliaLang/julia/issues/27988
165
165
# The following code is borrowed from https://github.com/JuliaLang/julia/pull/43322
166
166
# with some modification to make it also works on 1.6.
167
- # TODO : make `broadcast_flatten` call `Broadcast.flatten` once julia#43322 is merged.
168
167
module StableFlatten
169
168
170
169
export broadcast_flatten
171
170
172
- using Base: tail
173
- using Base. Broadcast: isflat, Broadcasted
174
-
175
- maybeconstructor (f) = f
176
- maybeconstructor (:: Type{F} ) where {F} = (args... ; kwargs... ) -> F (args... ; kwargs... )
171
+ if VERSION >= v " 1.11.0-DEV.103"
172
+ const broadcast_flatten = Broadcast. flatten
173
+ else
174
+ using Base: tail
175
+ using Base. Broadcast: isflat, Broadcasted
176
+
177
+ maybeconstructor (f) = f
178
+ maybeconstructor (:: Type{F} ) where {F} = (args... ; kwargs... ) -> F (args... ; kwargs... )
179
+
180
+ function broadcast_flatten (bc:: Broadcasted{Style} ) where {Style}
181
+ isflat (bc) && return bc
182
+ args = cat_nested (bc)
183
+ len = Val {length(args)} ()
184
+ makeargs = make_makeargs (bc. args, len, ntuple (_-> true , len))
185
+ f = maybeconstructor (bc. f)
186
+ @inline newf (args... ) = f (prepare_args (makeargs, args)... )
187
+ return Broadcasted {Style} (newf, args, bc. axes)
188
+ end
177
189
178
- function broadcast_flatten (bc:: Broadcasted{Style} ) where {Style}
179
- isflat (bc) && return bc
180
- args = cat_nested (bc)
181
- len = Val {length(args)} ()
182
- makeargs = make_makeargs (bc. args, len, ntuple (_-> true , len))
183
- f = maybeconstructor (bc. f)
184
- @inline newf (args... ) = f (prepare_args (makeargs, args)... )
185
- return Broadcasted {Style} (newf, args, bc. axes)
186
- end
190
+ cat_nested (bc:: Broadcasted ) = cat_nested_args (bc. args)
191
+ cat_nested_args (:: Tuple{} ) = ()
192
+ cat_nested_args (t:: Tuple ) = (cat_nested (t[1 ])... , cat_nested_args (tail (t))... )
193
+ cat_nested (@nospecialize (a)) = (a,)
187
194
188
- cat_nested (bc:: Broadcasted ) = cat_nested_args (bc. args)
189
- cat_nested_args (:: Tuple{} ) = ()
190
- cat_nested_args (t:: Tuple ) = (cat_nested (t[1 ])... , cat_nested_args (tail (t))... )
191
- cat_nested (@nospecialize (a)) = (a,)
195
+ function make_makeargs (args:: Tuple , len, flags)
196
+ makeargs, r = _make_makeargs (args, len, flags)
197
+ r isa Tuple{} || error (" Internal error. Please file a bug" )
198
+ return makeargs
199
+ end
192
200
193
- function make_makeargs (args:: Tuple , len, flags)
194
- makeargs, r = _make_makeargs (args, len, flags)
195
- r isa Tuple{} || error (" Internal error. Please file a bug" )
196
- return makeargs
197
- end
201
+ # We build `makeargs` by traversing the broadcast nodes recursively.
202
+ # note: `len` isa `Val` indicates the length of whole flattened argument list.
203
+ # `flags` is a tuple of `Bool` with the same length of the rest arguments.
204
+ @inline function _make_makeargs (args:: Tuple , len:: Val , flags:: Tuple )
205
+ head, flags′ = _make_makeargs1 (args[1 ], len, flags)
206
+ rest, flags″ = _make_makeargs (tail (args), len, flags′)
207
+ (head, rest... ), flags″
208
+ end
209
+ _make_makeargs (:: Tuple{} , :: Val , x:: Tuple ) = (), x
198
210
199
- # We build `makeargs` by traversing the broadcast nodes recursively.
200
- # note: `len` isa `Val` indicates the length of whole flattened argument list.
201
- # `flags` is a tuple of `Bool` with the same length of the rest arguments.
202
- @inline function _make_makeargs (args:: Tuple , len:: Val , flags:: Tuple )
203
- head, flags′ = _make_makeargs1 (args[1 ], len, flags)
204
- rest, flags″ = _make_makeargs (tail (args), len, flags′)
205
- (head, rest... ), flags″
206
- end
207
- _make_makeargs (:: Tuple{} , :: Val , x:: Tuple ) = (), x
211
+ # For flat nodes:
212
+ # 1. we just consume one argument, and return the "pick" function
213
+ @inline function _make_makeargs1 (@nospecialize (a), :: Val{N} , flags:: Tuple ) where {N}
214
+ pickargs (:: Val{N} ) where {N} = (@nospecialize (x:: Tuple )) -> x[N]
215
+ return pickargs (Val {N-length(flags)+1} ()), tail (flags)
216
+ end
208
217
209
- # For flat nodes:
210
- # 1. we just consume one argument, and return the "pick" function
211
- @inline function _make_makeargs1 (@nospecialize (a), :: Val{N} , flags:: Tuple ) where {N}
212
- pickargs (:: Val{N} ) where {N} = (@nospecialize (x:: Tuple )) -> x[N]
213
- return pickargs (Val {N-length(flags)+1} ()), tail (flags)
214
- end
218
+ # For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
219
+ @inline function _make_makeargs1 (bc:: Broadcasted , len:: Val , flags:: Tuple )
220
+ makeargs, flags′ = _make_makeargs (bc. args, len, flags)
221
+ f = maybeconstructor (bc. f)
222
+ @inline makeargs1 (@nospecialize (args:: Tuple )) = f (prepare_args (makeargs, args)... )
223
+ makeargs1, flags′
224
+ end
215
225
216
- # For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
217
- @inline function _make_makeargs1 (bc:: Broadcasted , len:: Val , flags:: Tuple )
218
- makeargs, flags′ = _make_makeargs (bc. args, len, flags)
219
- f = maybeconstructor (bc. f)
220
- @inline makeargs1 (@nospecialize (args:: Tuple )) = f (prepare_args (makeargs, args)... )
221
- makeargs1, flags′
226
+ prepare_args (:: Tuple{} , @nospecialize (:: Tuple )) = ()
227
+ @inline prepare_args (makeargs:: Tuple , @nospecialize (x:: Tuple )) = (makeargs[1 ](x), prepare_args (tail (makeargs), x)... )
222
228
end
223
-
224
- prepare_args (:: Tuple{} , @nospecialize (:: Tuple )) = ()
225
- @inline prepare_args (makeargs:: Tuple , @nospecialize (x:: Tuple )) = (makeargs[1 ](x), prepare_args (tail (makeargs), x)... )
226
229
end
227
230
using . StableFlatten
0 commit comments