Skip to content

Commit d5c3c96

Browse files
authored
Improve broadcast's stability (#1079)
* Improve `broadcast`'s stability * Fix coverage. * bump
1 parent 5cdc0d0 commit d5c3c96

File tree

3 files changed

+85
-3
lines changed

3 files changed

+85
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.5.4"
3+
version = "1.5.5"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/broadcast.jl

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo{1},Vararg{SOneTo{1}}}) =
5858
static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
5959
# copy overload
6060
@inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M
61-
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
61+
flat = broadcast_flatten(B); as = flat.args; f = flat.f
6262
argsizes = broadcast_sizes(as...)
6363
ax = axes(B)
6464
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug.")
@@ -68,7 +68,7 @@ end
6868
@inline Base.copyto!(dest, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)
6969
@inline Base.copyto!(dest::AbstractArray, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)
7070
@inline function _copyto!(dest, B::Broadcasted{StaticArrayStyle{M}}) where M
71-
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
71+
flat = broadcast_flatten(B); as = flat.args; f = flat.f
7272
argsizes = broadcast_sizes(as...)
7373
ax = axes(B)
7474
if ax isa Tuple{Vararg{SOneTo}}
@@ -165,3 +165,68 @@ end
165165
return dest
166166
end
167167
end
168+
169+
# Work around for https://github.com/JuliaLang/julia/issues/27988
170+
# The following code is borrowed from https://github.com/JuliaLang/julia/pull/43322
171+
# with some modification to make it also works on 1.6.
172+
# TODO: make `broadcast_flatten` call `Broadcast.flatten` once julia#43322 is merged.
173+
module StableFlatten
174+
175+
export broadcast_flatten
176+
177+
using Base: tail
178+
using Base.Broadcast: isflat, Broadcasted
179+
180+
maybeconstructor(f) = f
181+
maybeconstructor(::Type{F}) where {F} = (args...; kwargs...) -> F(args...; kwargs...)
182+
183+
function broadcast_flatten(bc::Broadcasted{Style}) where {Style}
184+
isflat(bc) && return bc
185+
args = cat_nested(bc)
186+
len = Val{length(args)}()
187+
makeargs = make_makeargs(bc.args, len, ntuple(_->true, len))
188+
f = maybeconstructor(bc.f)
189+
@inline newf(args...) = f(prepare_args(makeargs, args)...)
190+
return Broadcasted{Style}(newf, args, bc.axes)
191+
end
192+
193+
cat_nested(bc::Broadcasted) = cat_nested_args(bc.args)
194+
cat_nested_args(::Tuple{}) = ()
195+
cat_nested_args(t::Tuple) = (cat_nested(t[1])..., cat_nested_args(tail(t))...)
196+
cat_nested(@nospecialize(a)) = (a,)
197+
198+
function make_makeargs(args::Tuple, len, flags)
199+
makeargs, r = _make_makeargs(args, len, flags)
200+
r isa Tuple{} || error("Internal error. Please file a bug")
201+
return makeargs
202+
end
203+
204+
# We build `makeargs` by traversing the broadcast nodes recursively.
205+
# note: `len` isa `Val` indicates the length of whole flattened argument list.
206+
# `flags` is a tuple of `Bool` with the same length of the rest arguments.
207+
@inline function _make_makeargs(args::Tuple, len::Val, flags::Tuple)
208+
head, flags′ = _make_makeargs1(args[1], len, flags)
209+
rest, flags″ = _make_makeargs(tail(args), len, flags′)
210+
(head, rest...), flags″
211+
end
212+
_make_makeargs(::Tuple{}, ::Val, x::Tuple) = (), x
213+
214+
# For flat nodes:
215+
# 1. we just consume one argument, and return the "pick" function
216+
@inline function _make_makeargs1(@nospecialize(a), ::Val{N}, flags::Tuple) where {N}
217+
pickargs(::Val{N}) where {N} = (@nospecialize(x::Tuple)) -> x[N]
218+
return pickargs(Val{N-length(flags)+1}()), tail(flags)
219+
end
220+
221+
# For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
222+
@inline function _make_makeargs1(bc::Broadcasted, len::Val, flags::Tuple)
223+
makeargs, flags′ = _make_makeargs(bc.args, len, flags)
224+
f = maybeconstructor(bc.f)
225+
@inline makeargs1(@nospecialize(args::Tuple)) = f(prepare_args(makeargs, args)...)
226+
makeargs1, flags′
227+
end
228+
229+
prepare_args(::Tuple{}, @nospecialize(::Tuple)) = ()
230+
@inline prepare_args(makeargs::Tuple, @nospecialize(x::Tuple)) = (makeargs[1](x), prepare_args(tail(makeargs), x)...)
231+
end
232+
using .StableFlatten

test/broadcast.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,20 @@ end
335335
@test @inferred(Broadcast.instantiate(f(a; ax))).axes isa Tuple{SOneTo,SOneTo,Base.OneTo}
336336
@test @inferred(Broadcast.instantiate(f(a; ax = ax[1:2]))).axes isa NTuple{2,SOneTo}
337337
end
338+
339+
@testset "`broadcast`'s stability" begin
340+
issue1078(t) = t ./ (1 .- t .^ 2)
341+
a = @SVector rand(3)
342+
@test @inferred(issue1078(a)) == issue1078(Vector(a))
343+
issue560(ũ, u₀, u₁, ρ) = ũ ./ (1e-6 .+ max.(abs.(u₀), abs.(u₁)) .* ρ)
344+
issue797(a, b, c, d) = @. a + 5 * b + 3 * c - d
345+
manual(a, b, c, d) = @. 0.1a^2 + 0.2b^3 * 0.4c^1 + 0.5d
346+
manual2(a, b, c, d) = @. Float32(a) * Float32(b) + Float32(c) * Float32(d)
347+
args = rand(3), rand(3), rand(3), rand(3)
348+
@test @inferred(issue560(map(SVector{3}, args)...)) == issue560(args...)
349+
@test @inferred(issue797(map(SVector{3}, args)...)) == issue797(args...)
350+
@test @inferred(manual(map(SVector{3}, args)...)) == manual(args...)
351+
@test @inferred(manual2(map(SVector{3}, args)...)) == manual2(args...)
352+
issue609(s, c::Integer) = (s .- s.^2) ./ c
353+
@test @inferred(issue609(SA[1.], 2)) == issue609([1.], 2)
354+
end

0 commit comments

Comments
 (0)