From ced5c52cfaf74f8effb2b85fbaa61d623cf4e9a4 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 2 Mar 2022 21:16:09 +0800 Subject: [PATCH 1/7] Fix `Transpose`/`Adjoint`'s Style --- src/broadcast.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index b5a69b63..ff926de9 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -11,8 +11,8 @@ using Base.Broadcast: _bcsm struct StaticArrayStyle{N} <: AbstractArrayStyle{N} end StaticArrayStyle{M}(::Val{N}) where {M,N} = StaticArrayStyle{N}() BroadcastStyle(::Type{<:StaticArray{<:Tuple, <:Any, N}}) where {N} = StaticArrayStyle{N}() -BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}() -BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}() +BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray}}) = StaticArrayStyle{2}() +BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray}}) = StaticArrayStyle{2}() BroadcastStyle(::Type{<:Diagonal{<:Any, <:StaticArray{<:Tuple, <:Any, 1}}}) = StaticArrayStyle{2}() # Precedence rules BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} = From f04ca916a16ef801425cc7f00e399461b8e18270 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 2 Mar 2022 21:22:47 +0800 Subject: [PATCH 2/7] Make `instantiate` generate static axes add `static_combine_axes` to handle Tuple correctly --- src/broadcast.jl | 47 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index ff926de9..91846342 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -4,6 +4,7 @@ import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize! 
+import Base.Broadcast: combine_axes, instantiate, _broadcast_getindex, broadcast_shape, Style import Base.Broadcast: _bcs1 # for SOneTo axis information using Base.Broadcast: _bcsm # Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle @@ -19,6 +20,42 @@ BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} = DefaultArrayStyle(Val(max(M, N))) BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{0}) where {M} = StaticArrayStyle{M}() + +# combine_axes overload (for Tuple) +@inline static_combine_axes(A, B...) = broadcast_shape(static_axes(A), static_combine_axes(B...)) +static_combine_axes(A) = static_axes(A) +static_axes(A) = axes(A) +static_axes(x::Tuple) = (SOneTo{length(x)}(),) +static_axes(bc::Broadcasted{Style{Tuple}}) = static_combine_axes(bc.args...) +Broadcast._axes(bc::Broadcasted{<:StaticArrayStyle}, ::Nothing) = static_combine_axes(bc.args...) + +# instantiate overload +@inline function instantiate(B::Broadcasted{StaticArrayStyle{M}}) where M + if B.axes isa Tuple{Vararg{SOneTo}} || B.axes isa Tuple && length(B.axes) > M + return invoke(instantiate, Tuple{Broadcasted}, B) + elseif B.axes isa Nothing + ax = static_combine_axes(B.args...) + return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax) + else + # We need to update B.axes for `broadcast!` if it's not static and `ndims(dest) < M`. + ax = static_check_broadcast_shape(B.axes, static_combine_axes(B.args...)) + return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax) + end +end +@inline function static_check_broadcast_shape(shp::Tuple, Ashp::Tuple{Vararg{SOneTo}}) + ax1 = if length(Ashp[1]) == 1 + shp[1] + elseif Ashp[1] == shp[1] + Ashp[1] + else + throw(DimensionMismatch("array could not be broadcast to match destination")) + end + (ax1, static_check_broadcast_shape(Base.tail(shp), Base.tail(Ashp))...) 
+end +static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo,Vararg{SOneTo}}) = + throw(DimensionMismatch("cannot broadcast array to have fewer non-singleton dimensions")) +static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo{1},Vararg{SOneTo{1}}}) = () +static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = () # copy overload @inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M flat = Broadcast.flatten(B); as = flat.args; f = flat.f @@ -42,8 +79,14 @@ end # Resolving priority between dynamic and static axes _bcs1(a::SOneTo, b::SOneTo) = _bcsm(b, a) ? b : (_bcsm(a, b) ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size"))) -_bcs1(a::SOneTo, b::Base.OneTo) = _bcs1(Base.OneTo(a), b) -_bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(a, Base.OneTo(b)) +function _bcs1(a::SOneTo, b::Base.OneTo) + length(a) == 1 && return b + if length(b) != length(a) && length(b) != 1 + throw(DimensionMismatch("arrays could not be broadcast to a common size")) + end + return a +end +_bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(b, a) ################################################### ## Internal broadcast machinery for StaticArrays ## From 722627e94daca9d587c14d748d919635b4735aa4 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 2 Mar 2022 21:26:13 +0800 Subject: [PATCH 3/7] Drop `combine_sizes` 1. `axes` should be static. 2. The implementations of `broadcast` and `broadcast!` are unified. 3. Simplify `broadcasted_index` --- src/abstractarray.jl | 1 + src/broadcast.jl | 83 ++++++++++++-------------------------------- src/precompile.jl | 6 +++- 3 files changed, 28 insertions(+), 62 deletions(-) diff --git a/src/abstractarray.jl b/src/abstractarray.jl index 07ab65c1..8e27032f 100644 --- a/src/abstractarray.jl +++ b/src/abstractarray.jl @@ -14,6 +14,7 @@ Base.axes(s::StaticArray) = _axes(Size(s)) end Base.axes(rv::Adjoint{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...) Base.axes(rv::Transpose{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...)
+Base.axes(d::Diagonal{<:Any,<:StaticVector}) = (ax = axes(d.diag, 1); (ax, ax)) Base.eachindex(::IndexLinear, a::StaticArray) = SOneTo(length(a)) diff --git a/src/broadcast.jl b/src/broadcast.jl index 91846342..0066276a 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -60,8 +60,11 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = () @inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M flat = Broadcast.flatten(B); as = flat.args; f = flat.f argsizes = broadcast_sizes(as...) - destsize = combine_sizes(argsizes) - _broadcast(f, destsize, argsizes, as...) + ax = axes(B) + if ax isa Tuple{Vararg{SOneTo}} + return _broadcast(f, Size(map(length, ax)), argsizes, as...) + end + return copy(convert(Broadcasted{DefaultArrayStyle{M}}, B)) end # copyto! overloads @inline Base.copyto!(dest, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B) @@ -69,12 +72,13 @@ end @inline function _copyto!(dest, B::Broadcasted{StaticArrayStyle{M}}) where M flat = Broadcast.flatten(B); as = flat.args; f = flat.f argsizes = broadcast_sizes(as...) - destsize = combine_sizes((Size(dest), argsizes...)) - if Length(destsize) === Length{Dynamic()}() - # destination dimension cannot be determined statically; fall back to generic broadcast! - return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B)) + ax = axes(B) + if ax isa Tuple{Vararg{SOneTo}} + @boundscheck axes(dest) == ax || Broadcast.throwdm(axes(dest), ax) + return _broadcast!(f, Size(map(length, ax)), dest, argsizes, as...) end - _broadcast!(f, destsize, dest, argsizes, as...) + # destination dimension cannot be determined statically; fall back to generic broadcast! 
+ return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B)) end # Resolving priority between dynamic and static axes @@ -101,45 +105,13 @@ broadcast_indices(A::StaticArray) = indices(A) @inline broadcast_size(a::AbstractArray) = Size(a) @inline broadcast_size(a::Tuple) = Size(length(a)) -function broadcasted_index(oldsize, newindex) - index = ones(Int, length(oldsize)) - for i = 1:length(oldsize) - if oldsize[i] != 1 - index[i] = newindex[i] - end - end - return LinearIndices(oldsize)[index...] -end - -# similar to Base.Broadcast.combine_indices: -@generated function combine_sizes(s::Tuple{Vararg{Size}}) - sizes = [sz.parameters[1] for sz ∈ s.parameters] - ndims = 0 - for i = 1:length(sizes) - ndims = max(ndims, length(sizes[i])) - end - newsize = StaticDimension[Dynamic() for _ = 1 : ndims] - for i = 1:length(sizes) - s = sizes[i] - for j = 1:length(s) - if s[j] isa Dynamic - continue - elseif newsize[j] isa Dynamic || newsize[j] == 1 - newsize[j] = s[j] - elseif newsize[j] ≠ s[j] && s[j] ≠ 1 - throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes")) - end - end - end - quote - @_inline_meta - Size($(tuple(newsize...))) - end +broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I)) +function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex) + li = LinearIndices(oldsize) + ind = _broadcast_getindex(li, newindex) + return :(a[$i][$ind]) end -scalar_getindex(x) = x -scalar_getindex(x::Ref) = x[] - isstatic(::StaticArray) = true isstatic(::Transpose{<:Any, <:StaticArray}) = true isstatic(::Adjoint{<:Any, <:StaticArray}) = true @@ -163,13 +135,11 @@ end @generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize sizes = [sz.parameters[1] for sz ∈ s.parameters] + indices = CartesianIndices(newsize) exprs = similar(indices, Expr) for (j, current_ind) ∈ enumerate(indices) - exprs_vals = [ - (!(a[i] <: AbstractArray || a[i] <: Tuple) ? 
:(scalar_getindex(a[$i])) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))])) - for i = 1:length(sizes) - ] + exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes)) exprs[j] = :(f($(exprs_vals...))) end @@ -183,27 +153,18 @@ end ## Internal broadcast! machinery for StaticArrays ## #################################################### -@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, as...) where {newsize} - sizes = [sz.parameters[1] for sz ∈ s.parameters] - sizes = tuple(sizes...) - - # TODO: this could also be done outside the generated function: - sizematch(Size{newsize}(), Size(dest)) || - throw(DimensionMismatch("Tried to broadcast to destination sized $newsize from inputs sized $sizes")) +@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, a...) where {newsize} + sizes = [sz.parameters[1] for sz in s.parameters] indices = CartesianIndices(newsize) exprs = similar(indices, Expr) for (j, current_ind) ∈ enumerate(indices) - exprs_vals = [ - (!(as[i] <: AbstractArray || as[i] <: Tuple) ? 
:(as[$i][]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))])) - for i = 1:length(sizes) - ] + exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes)) exprs[j] = :(dest[$j] = f($(exprs_vals...))) end return quote - @_propagate_inbounds_meta - @boundscheck sizematch($(Size{newsize}()), dest) || throw(DimensionMismatch("array could not be broadcast to match destination")) + @_inline_meta @inbounds $(Expr(:block, exprs...)) return dest end diff --git a/src/precompile.jl b/src/precompile.jl index 90e0c2af..63305326 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -24,6 +24,10 @@ function _precompile_() # Some expensive generators @assert precompile(Tuple{typeof(which(__broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any}) @assert precompile(Tuple{typeof(which(_zeros,(Size,Type{<:StaticArray},)).generator.gen),Any,Any,Any,Type,Any}) - @assert precompile(Tuple{typeof(which(combine_sizes,(Tuple{Vararg{Size}},)).generator.gen),Any,Any}) @assert precompile(Tuple{typeof(which(_mapfoldl,(Any,Any,Colon,Any,Size,Vararg{StaticArray},)).generator.gen),Any,Any,Any,Any,Any,Any,Any,Any}) + + # broadcast_getindex + for m = 0:5, n = m:5 + @assert precompile(Tuple{typeof(broadcast_getindex),NTuple{m,Int},Int,CartesianIndex{n}}) + end end From b72a89c17ab01f5306e8a153509d7fbaed84992c Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Tue, 1 Mar 2022 19:50:00 +0800 Subject: [PATCH 4/7] Add test --- test/broadcast.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/broadcast.jl b/test/broadcast.jl index 1baa4f28..937aa723 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -315,3 +315,23 @@ end end end end + +@testset "instantiate with axes updated" begin + f(a; ax = nothing) = Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{ndims(a)}}(+,(a,),ax) + a = @SArray zeros(2,2,2) + ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2) + @test 
@inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{3,SOneTo} + ax = (ax..., Base.OneTo(2)) + @test @inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{4,Base.OneTo} + ax = setindex(ax, Base.OneTo(1), 4) + @test @inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{4,Base.OneTo} + a = @SArray zeros(2,1,2) + ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2) + @test @inferred(Broadcast.instantiate(f(a; ax))).axes isa Tuple{SOneTo,Base.OneTo,SOneTo} + @test_throws DimensionMismatch Broadcast.instantiate(f(a; ax = ax[1:2])) + + a = @SArray zeros(2,2,1) + ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2) + @test @inferred(Broadcast.instantiate(f(a; ax))).axes isa Tuple{SOneTo,SOneTo,Base.OneTo} + @test @inferred(Broadcast.instantiate(f(a; ax = ax[1:2]))).axes isa NTuple{2,SOneTo} +end From 8af70587a36d64955456bbf68e6bf379070e58e1 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Mon, 14 Mar 2022 22:39:17 +0800 Subject: [PATCH 5/7] add `return` Co-Authored-By: Thomas Christensen --- src/broadcast.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index 0066276a..7565ed9a 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -50,7 +50,7 @@ end else throw(DimensionMismatch("array could not be broadcast to match destination")) end - (ax1, static_check_broadcast_shape(Base.tail(shp), Base.tail(Ashp))...) + return (ax1, static_check_broadcast_shape(Base.tail(shp), Base.tail(Ashp))...) 
end static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo,Vararg{SOneTo}}) = throw(DimensionMismatch("cannot broadcast array to have fewer non-singleton dimensions")) From 4126cd138f7fd9518f94639f9ccc0d5b6a93382f Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Tue, 15 Mar 2022 17:17:48 +0800 Subject: [PATCH 6/7] import clean --- src/broadcast.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index 7565ed9a..62778f63 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -2,9 +2,9 @@ ## broadcast! ## ################ -import Base.Broadcast: -BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize! -import Base.Broadcast: combine_axes, instantiate, _broadcast_getindex, broadcast_shape, Style +using Base.Broadcast: AbstractArrayStyle, DefaultArrayStyle, Style, Broadcasted +using Base.Broadcast: broadcast_shape, _broadcast_getindex, combine_axes +import Base.Broadcast: BroadcastStyle, materialize!, instantiate import Base.Broadcast: _bcs1 # for SOneTo axis information using Base.Broadcast: _bcsm # Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle From 7fedea6f2ea569aae541ed267c38e49378484766 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Tue, 15 Mar 2022 18:22:25 +0800 Subject: [PATCH 7/7] Remove broken branch. --- src/broadcast.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index 62778f63..9f0cdc1b 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -61,10 +61,8 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = () flat = Broadcast.flatten(B); as = flat.args; f = flat.f argsizes = broadcast_sizes(as...) ax = axes(B) - if ax isa Tuple{Vararg{SOneTo}} - return _broadcast(f, Size(map(length, ax)), argsizes, as...) - end - return copy(convert(Broadcasted{DefaultArrayStyle{M}}, B)) + ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. 
Please file a bug.") + return _broadcast(f, Size(map(length, ax)), argsizes, as...) end # copyto! overloads @inline Base.copyto!(dest, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)