diff --git a/src/abstractarray.jl b/src/abstractarray.jl index 07ab65c1..8e27032f 100644 --- a/src/abstractarray.jl +++ b/src/abstractarray.jl @@ -14,6 +14,7 @@ Base.axes(s::StaticArray) = _axes(Size(s)) end Base.axes(rv::Adjoint{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...) Base.axes(rv::Transpose{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...) +Base.axes(d::Diagonal{<:Any,<:StaticVector}) = (ax = axes(d.diag, 1); (ax, ax)) Base.eachindex(::IndexLinear, a::StaticArray) = SOneTo(length(a)) diff --git a/src/broadcast.jl b/src/broadcast.jl index b5a69b63..9f0cdc1b 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -2,8 +2,9 @@ ## broadcast! ## ################ -import Base.Broadcast: -BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize! +using Base.Broadcast: AbstractArrayStyle, DefaultArrayStyle, Style, Broadcasted +using Base.Broadcast: broadcast_shape, _broadcast_getindex, combine_axes +import Base.Broadcast: BroadcastStyle, materialize!, instantiate import Base.Broadcast: _bcs1 # for SOneTo axis information using Base.Broadcast: _bcsm # Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle @@ -11,20 +12,57 @@ using Base.Broadcast: _bcsm struct StaticArrayStyle{N} <: AbstractArrayStyle{N} end StaticArrayStyle{M}(::Val{N}) where {M,N} = StaticArrayStyle{N}() BroadcastStyle(::Type{<:StaticArray{<:Tuple, <:Any, N}}) where {N} = StaticArrayStyle{N}() -BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}() -BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}() +BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray}}) = StaticArrayStyle{2}() +BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray}}) = StaticArrayStyle{2}() BroadcastStyle(::Type{<:Diagonal{<:Any, <:StaticArray{<:Tuple, <:Any, 1}}}) = StaticArrayStyle{2}() # Precedence rules BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} = DefaultArrayStyle(Val(max(M, N))) BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{0}) where {M} = StaticArrayStyle{M}() + +# combine_axes overload (for Tuple) +@inline static_combine_axes(A, B...) = broadcast_shape(static_axes(A), static_combine_axes(B...)) +static_combine_axes(A) = static_axes(A) +static_axes(A) = axes(A) +static_axes(x::Tuple) = (SOneTo{length(x)}(),) +static_axes(bc::Broadcasted{Style{Tuple}}) = static_combine_axes(bc.args...) +Broadcast._axes(bc::Broadcasted{<:StaticArrayStyle}, ::Nothing) = static_combine_axes(bc.args...) + +# instantiate overload +@inline function instantiate(B::Broadcasted{StaticArrayStyle{M}}) where M + if B.axes isa Tuple{Vararg{SOneTo}} || B.axes isa Tuple && length(B.axes) > M + return invoke(instantiate, Tuple{Broadcasted}, B) + elseif B.axes isa Nothing + ax = static_combine_axes(B.args...) + return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax) + else + # We need to update B.axes for `broadcast!` if it's not static and `ndims(dest) < M`. + ax = static_check_broadcast_shape(B.axes, static_combine_axes(B.args...)) + return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax) + end +end +@inline function static_check_broadcast_shape(shp::Tuple, Ashp::Tuple{Vararg{SOneTo}}) + ax1 = if length(Ashp[1]) == 1 + shp[1] + elseif Ashp[1] == shp[1] + Ashp[1] + else + throw(DimensionMismatch("array could not be broadcast to match destination")) + end + return (ax1, static_check_broadcast_shape(Base.tail(shp), Base.tail(Ashp))...) +end +static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo,Vararg{SOneTo}}) = + throw(DimensionMismatch("cannot broadcast array to have fewer non-singleton dimensions")) +static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo{1},Vararg{SOneTo{1}}}) = () +static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = () # copy overload @inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M flat = Broadcast.flatten(B); as = flat.args; f = flat.f argsizes = broadcast_sizes(as...) - destsize = combine_sizes(argsizes) - _broadcast(f, destsize, argsizes, as...) + ax = axes(B) + ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug.") + return _broadcast(f, Size(map(length, ax)), argsizes, as...) end # copyto! overloads @inline Base.copyto!(dest, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B) @@ -32,18 +70,25 @@ end @inline function _copyto!(dest, B::Broadcasted{StaticArrayStyle{M}}) where M flat = Broadcast.flatten(B); as = flat.args; f = flat.f argsizes = broadcast_sizes(as...) - destsize = combine_sizes((Size(dest), argsizes...)) - if Length(destsize) === Length{Dynamic()}() - # destination dimension cannot be determined statically; fall back to generic broadcast! - return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B)) + ax = axes(B) + if ax isa Tuple{Vararg{SOneTo}} + @boundscheck axes(dest) == ax || Broadcast.throwdm(axes(dest), ax) + return _broadcast!(f, Size(map(length, ax)), dest, argsizes, as...) end - _broadcast!(f, destsize, dest, argsizes, as...) + # destination dimension cannot be determined statically; fall back to generic broadcast! + return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B)) end # Resolving priority between dynamic and static axes _bcs1(a::SOneTo, b::SOneTo) = _bcsm(b, a) ? b : (_bcsm(a, b) ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size"))) -_bcs1(a::SOneTo, b::Base.OneTo) = _bcs1(Base.OneTo(a), b) -_bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(a, Base.OneTo(b)) +function _bcs1(a::SOneTo, b::Base.OneTo) + length(a) == 1 && return b + if length(b) != length(a) && length(b) != 1 + throw(DimensionMismatch("arrays could not be broadcast to a common size")) + end + return a +end +_bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(b, a) ################################################### ## Internal broadcast machinery for StaticArrays ## @@ -58,45 +103,13 @@ broadcast_indices(A::StaticArray) = indices(A) @inline broadcast_size(a::AbstractArray) = Size(a) @inline broadcast_size(a::Tuple) = Size(length(a)) -function broadcasted_index(oldsize, newindex) - index = ones(Int, length(oldsize)) - for i = 1:length(oldsize) - if oldsize[i] != 1 - index[i] = newindex[i] - end - end - return LinearIndices(oldsize)[index...] -end - -# similar to Base.Broadcast.combine_indices: -@generated function combine_sizes(s::Tuple{Vararg{Size}}) - sizes = [sz.parameters[1] for sz ∈ s.parameters] - ndims = 0 - for i = 1:length(sizes) - ndims = max(ndims, length(sizes[i])) - end - newsize = StaticDimension[Dynamic() for _ = 1 : ndims] - for i = 1:length(sizes) - s = sizes[i] - for j = 1:length(s) - if s[j] isa Dynamic - continue - elseif newsize[j] isa Dynamic || newsize[j] == 1 - newsize[j] = s[j] - elseif newsize[j] ≠ s[j] && s[j] ≠ 1 - throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes")) - end - end - end - quote - @_inline_meta - Size($(tuple(newsize...))) - end +broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I)) +function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex) + li = LinearIndices(oldsize) + ind = _broadcast_getindex(li, newindex) + return :(a[$i][$ind]) end -scalar_getindex(x) = x -scalar_getindex(x::Ref) = x[] - isstatic(::StaticArray) = true isstatic(::Transpose{<:Any, <:StaticArray}) = true isstatic(::Adjoint{<:Any, <:StaticArray}) = true @@ -120,13 +133,11 @@ end @generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize sizes = [sz.parameters[1] for sz ∈ s.parameters] + indices = CartesianIndices(newsize) exprs = similar(indices, Expr) for (j, current_ind) ∈ enumerate(indices) - exprs_vals = [ - (!(a[i] <: AbstractArray || a[i] <: Tuple) ? :(scalar_getindex(a[$i])) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))])) - for i = 1:length(sizes) - ] + exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes)) exprs[j] = :(f($(exprs_vals...))) end @@ -140,27 +151,18 @@ end ## Internal broadcast! machinery for StaticArrays ## #################################################### -@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, as...) where {newsize} - sizes = [sz.parameters[1] for sz ∈ s.parameters] - sizes = tuple(sizes...) - - # TODO: this could also be done outside the generated function: - sizematch(Size{newsize}(), Size(dest)) || - throw(DimensionMismatch("Tried to broadcast to destination sized $newsize from inputs sized $sizes")) +@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, a...) where {newsize} + sizes = [sz.parameters[1] for sz in s.parameters] indices = CartesianIndices(newsize) exprs = similar(indices, Expr) for (j, current_ind) ∈ enumerate(indices) - exprs_vals = [ - (!(as[i] <: AbstractArray || as[i] <: Tuple) ? :(as[$i][]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))])) - for i = 1:length(sizes) - ] + exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes)) exprs[j] = :(dest[$j] = f($(exprs_vals...))) end return quote - @_propagate_inbounds_meta - @boundscheck sizematch($(Size{newsize}()), dest) || throw(DimensionMismatch("array could not be broadcast to match destination")) + @_inline_meta @inbounds $(Expr(:block, exprs...)) return dest end diff --git a/src/precompile.jl b/src/precompile.jl index 90e0c2af..63305326 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -24,6 +24,10 @@ function _precompile_() # Some expensive generators @assert precompile(Tuple{typeof(which(__broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any}) @assert precompile(Tuple{typeof(which(_zeros,(Size,Type{<:StaticArray},)).generator.gen),Any,Any,Any,Type,Any}) - @assert precompile(Tuple{typeof(which(combine_sizes,(Tuple{Vararg{Size}},)).generator.gen),Any,Any}) @assert precompile(Tuple{typeof(which(_mapfoldl,(Any,Any,Colon,Any,Size,Vararg{StaticArray},)).generator.gen),Any,Any,Any,Any,Any,Any,Any,Any}) + + # broadcast_getindex + for m = 0:5, n = m:5 + @assert precompile(Tuple{typeof(broadcast_getindex),NTuple{m,Int},Int,CartesianIndex{n}}) + end end diff --git a/test/broadcast.jl b/test/broadcast.jl index 1baa4f28..937aa723 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -315,3 +315,23 @@ end end end end + +@testset "instantiate with axes updated" begin + f(a; ax = nothing) = Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{ndims(a)}}(+,(a,),ax) + a = @SArray zeros(2,2,2) + ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2) + @test @inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{3,SOneTo} + ax = (ax..., Base.OneTo(2)) + @test @inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{4,Base.OneTo} + ax = setindex(ax, Base.OneTo(1), 4) + @test @inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{4,Base.OneTo} + a = @SArray zeros(2,1,2) + ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2) + @test @inferred(Broadcast.instantiate(f(a; ax))).axes isa Tuple{SOneTo,Base.OneTo,SOneTo} + @test_throws DimensionMismatch Broadcast.instantiate(f(a; ax = ax[1:2])) + + a = @SArray zeros(2,2,1) + ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2) + @test @inferred(Broadcast.instantiate(f(a; ax))).axes isa Tuple{SOneTo,SOneTo,Base.OneTo} + @test @inferred(Broadcast.instantiate(f(a; ax = ax[1:2]))).axes isa NTuple{2,SOneTo} +end