Skip to content

Commit ff8ed9b

Browse files
authored
Use Broadcast.flatten on master (#1186)
As julia#43322 has been merged.
1 parent 48dec2c commit ff8ed9b

File tree

1 file changed

+51
-48
lines changed

1 file changed

+51
-48
lines changed

src/broadcast.jl

Lines changed: 51 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -164,64 +164,67 @@ end
164164
# Work around for https://github.com/JuliaLang/julia/issues/27988
165165
# The following code is borrowed from https://github.com/JuliaLang/julia/pull/43322
166166
# with some modification to make it also works on 1.6.
167-
# TODO: make `broadcast_flatten` call `Broadcast.flatten` once julia#43322 is merged.
168167
module StableFlatten
169168

170169
export broadcast_flatten
171170

172-
using Base: tail
173-
using Base.Broadcast: isflat, Broadcasted
174-
175-
maybeconstructor(f) = f
176-
maybeconstructor(::Type{F}) where {F} = (args...; kwargs...) -> F(args...; kwargs...)
171+
if VERSION >= v"1.11.0-DEV.103"
172+
const broadcast_flatten = Broadcast.flatten
173+
else
174+
using Base: tail
175+
using Base.Broadcast: isflat, Broadcasted
176+
177+
maybeconstructor(f) = f
178+
maybeconstructor(::Type{F}) where {F} = (args...; kwargs...) -> F(args...; kwargs...)
179+
180+
function broadcast_flatten(bc::Broadcasted{Style}) where {Style}
181+
isflat(bc) && return bc
182+
args = cat_nested(bc)
183+
len = Val{length(args)}()
184+
makeargs = make_makeargs(bc.args, len, ntuple(_->true, len))
185+
f = maybeconstructor(bc.f)
186+
@inline newf(args...) = f(prepare_args(makeargs, args)...)
187+
return Broadcasted{Style}(newf, args, bc.axes)
188+
end
177189

178-
function broadcast_flatten(bc::Broadcasted{Style}) where {Style}
179-
isflat(bc) && return bc
180-
args = cat_nested(bc)
181-
len = Val{length(args)}()
182-
makeargs = make_makeargs(bc.args, len, ntuple(_->true, len))
183-
f = maybeconstructor(bc.f)
184-
@inline newf(args...) = f(prepare_args(makeargs, args)...)
185-
return Broadcasted{Style}(newf, args, bc.axes)
186-
end
190+
cat_nested(bc::Broadcasted) = cat_nested_args(bc.args)
191+
cat_nested_args(::Tuple{}) = ()
192+
cat_nested_args(t::Tuple) = (cat_nested(t[1])..., cat_nested_args(tail(t))...)
193+
cat_nested(@nospecialize(a)) = (a,)
187194

188-
cat_nested(bc::Broadcasted) = cat_nested_args(bc.args)
189-
cat_nested_args(::Tuple{}) = ()
190-
cat_nested_args(t::Tuple) = (cat_nested(t[1])..., cat_nested_args(tail(t))...)
191-
cat_nested(@nospecialize(a)) = (a,)
195+
function make_makeargs(args::Tuple, len, flags)
196+
makeargs, r = _make_makeargs(args, len, flags)
197+
r isa Tuple{} || error("Internal error. Please file a bug")
198+
return makeargs
199+
end
192200

193-
function make_makeargs(args::Tuple, len, flags)
194-
makeargs, r = _make_makeargs(args, len, flags)
195-
r isa Tuple{} || error("Internal error. Please file a bug")
196-
return makeargs
197-
end
201+
# We build `makeargs` by traversing the broadcast nodes recursively.
202+
# note: `len` isa `Val` indicates the length of whole flattened argument list.
203+
# `flags` is a tuple of `Bool` with the same length of the rest arguments.
204+
@inline function _make_makeargs(args::Tuple, len::Val, flags::Tuple)
205+
head, flags′ = _make_makeargs1(args[1], len, flags)
206+
rest, flags″ = _make_makeargs(tail(args), len, flags′)
207+
(head, rest...), flags″
208+
end
209+
_make_makeargs(::Tuple{}, ::Val, x::Tuple) = (), x
198210

199-
# We build `makeargs` by traversing the broadcast nodes recursively.
200-
# note: `len` isa `Val` indicates the length of whole flattened argument list.
201-
# `flags` is a tuple of `Bool` with the same length of the rest arguments.
202-
@inline function _make_makeargs(args::Tuple, len::Val, flags::Tuple)
203-
head, flags′ = _make_makeargs1(args[1], len, flags)
204-
rest, flags″ = _make_makeargs(tail(args), len, flags′)
205-
(head, rest...), flags″
206-
end
207-
_make_makeargs(::Tuple{}, ::Val, x::Tuple) = (), x
211+
# For flat nodes:
212+
# 1. we just consume one argument, and return the "pick" function
213+
@inline function _make_makeargs1(@nospecialize(a), ::Val{N}, flags::Tuple) where {N}
214+
pickargs(::Val{N}) where {N} = (@nospecialize(x::Tuple)) -> x[N]
215+
return pickargs(Val{N-length(flags)+1}()), tail(flags)
216+
end
208217

209-
# For flat nodes:
210-
# 1. we just consume one argument, and return the "pick" function
211-
@inline function _make_makeargs1(@nospecialize(a), ::Val{N}, flags::Tuple) where {N}
212-
pickargs(::Val{N}) where {N} = (@nospecialize(x::Tuple)) -> x[N]
213-
return pickargs(Val{N-length(flags)+1}()), tail(flags)
214-
end
218+
# For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
219+
@inline function _make_makeargs1(bc::Broadcasted, len::Val, flags::Tuple)
220+
makeargs, flags′ = _make_makeargs(bc.args, len, flags)
221+
f = maybeconstructor(bc.f)
222+
@inline makeargs1(@nospecialize(args::Tuple)) = f(prepare_args(makeargs, args)...)
223+
makeargs1, flags′
224+
end
215225

216-
# For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
217-
@inline function _make_makeargs1(bc::Broadcasted, len::Val, flags::Tuple)
218-
makeargs, flags′ = _make_makeargs(bc.args, len, flags)
219-
f = maybeconstructor(bc.f)
220-
@inline makeargs1(@nospecialize(args::Tuple)) = f(prepare_args(makeargs, args)...)
221-
makeargs1, flags′
226+
prepare_args(::Tuple{}, @nospecialize(::Tuple)) = ()
227+
@inline prepare_args(makeargs::Tuple, @nospecialize(x::Tuple)) = (makeargs[1](x), prepare_args(tail(makeargs), x)...)
222228
end
223-
224-
prepare_args(::Tuple{}, @nospecialize(::Tuple)) = ()
225-
@inline prepare_args(makeargs::Tuple, @nospecialize(x::Tuple)) = (makeargs[1](x), prepare_args(tail(makeargs), x)...)
226229
end
227230
using .StableFlatten

0 commit comments

Comments
 (0)