Skip to content

Drop combine_size #1008

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Base.axes(s::StaticArray) = _axes(Size(s))
end
Base.axes(rv::Adjoint{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...)
Base.axes(rv::Transpose{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...)
Base.axes(d::Diagonal{<:Any,<:StaticVector}) = (ax = axes(d.diag, 1); (ax, ax))

Base.eachindex(::IndexLinear, a::StaticArray) = SOneTo(length(a))

Expand Down
136 changes: 69 additions & 67 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,93 @@
## broadcast! ##
################

import Base.Broadcast:
BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize!
using Base.Broadcast: AbstractArrayStyle, DefaultArrayStyle, Style, Broadcasted
using Base.Broadcast: broadcast_shape, _broadcast_getindex, combine_axes
import Base.Broadcast: BroadcastStyle, materialize!, instantiate
import Base.Broadcast: _bcs1 # for SOneTo axis information
using Base.Broadcast: _bcsm
# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
# A constructor that changes the style parameter N (array dimension) is also required
struct StaticArrayStyle{N} <: AbstractArrayStyle{N} end
StaticArrayStyle{M}(::Val{N}) where {M,N} = StaticArrayStyle{N}()
BroadcastStyle(::Type{<:StaticArray{<:Tuple, <:Any, N}}) where {N} = StaticArrayStyle{N}()
BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}()
BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}()
BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray}}) = StaticArrayStyle{2}()
BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray}}) = StaticArrayStyle{2}()
BroadcastStyle(::Type{<:Diagonal{<:Any, <:StaticArray{<:Tuple, <:Any, 1}}}) = StaticArrayStyle{2}()
# Precedence rules
BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
DefaultArrayStyle(Val(max(M, N)))
BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{0}) where {M} =
StaticArrayStyle{M}()

# combine_axes overload (for Tuple)
@inline static_combine_axes(A, B...) = broadcast_shape(static_axes(A), static_combine_axes(B...))
static_combine_axes(A) = static_axes(A)
static_axes(A) = axes(A)
static_axes(x::Tuple) = (SOneTo{length(x)}(),)
static_axes(bc::Broadcasted{Style{Tuple}}) = static_combine_axes(bc.args...)
Broadcast._axes(bc::Broadcasted{<:StaticArrayStyle}, ::Nothing) = static_combine_axes(bc.args...)

# instantiate overload
@inline function instantiate(B::Broadcasted{StaticArrayStyle{M}}) where M
if B.axes isa Tuple{Vararg{SOneTo}} || B.axes isa Tuple && length(B.axes) > M
return invoke(instantiate, Tuple{Broadcasted}, B)
elseif B.axes isa Nothing
ax = static_combine_axes(B.args...)
return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax)
else
# We need to update B.axes for `broadcast!` if it's not static and `ndims(dest) < M`.
ax = static_check_broadcast_shape(B.axes, static_combine_axes(B.args...))
return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax)
end
end
@inline function static_check_broadcast_shape(shp::Tuple, Ashp::Tuple{Vararg{SOneTo}})
ax1 = if length(Ashp[1]) == 1
shp[1]
elseif Ashp[1] == shp[1]
Ashp[1]
else
throw(DimensionMismatch("array could not be broadcast to match destination"))
end
return (ax1, static_check_broadcast_shape(Base.tail(shp), Base.tail(Ashp))...)
end
static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo,Vararg{SOneTo}}) =
throw(DimensionMismatch("cannot broadcast array to have fewer non-singleton dimensions"))
static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo{1},Vararg{SOneTo{1}}}) = ()
static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
# copy overload
@inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
argsizes = broadcast_sizes(as...)
destsize = combine_sizes(argsizes)
_broadcast(f, destsize, argsizes, as...)
ax = axes(B)
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug.")
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
end
# copyto! overloads
@inline Base.copyto!(dest, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)
@inline Base.copyto!(dest::AbstractArray, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)
@inline function _copyto!(dest, B::Broadcasted{StaticArrayStyle{M}}) where M
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
argsizes = broadcast_sizes(as...)
destsize = combine_sizes((Size(dest), argsizes...))
if Length(destsize) === Length{Dynamic()}()
# destination dimension cannot be determined statically; fall back to generic broadcast!
return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
ax = axes(B)
if ax isa Tuple{Vararg{SOneTo}}
@boundscheck axes(dest) == ax || Broadcast.throwdm(axes(dest), ax)
return _broadcast!(f, Size(map(length, ax)), dest, argsizes, as...)
end
_broadcast!(f, destsize, dest, argsizes, as...)
# destination dimension cannot be determined statically; fall back to generic broadcast!
return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
end

# Resolving priority between dynamic and static axes
_bcs1(a::SOneTo, b::SOneTo) = _bcsm(b, a) ? b : (_bcsm(a, b) ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
_bcs1(a::SOneTo, b::Base.OneTo) = _bcs1(Base.OneTo(a), b)
_bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(a, Base.OneTo(b))
function _bcs1(a::SOneTo, b::Base.OneTo)
length(a) == 1 && return b
if length(b) != length(a) && length(b) != 1
throw(DimensionMismatch("arrays could not be broadcast to a common size"))
end
return a
end
_bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(b, a)

###################################################
## Internal broadcast machinery for StaticArrays ##
Expand All @@ -58,45 +103,13 @@ broadcast_indices(A::StaticArray) = indices(A)
@inline broadcast_size(a::AbstractArray) = Size(a)
@inline broadcast_size(a::Tuple) = Size(length(a))

function broadcasted_index(oldsize, newindex)
index = ones(Int, length(oldsize))
for i = 1:length(oldsize)
if oldsize[i] != 1
index[i] = newindex[i]
end
end
return LinearIndices(oldsize)[index...]
end

# similar to Base.Broadcast.combine_indices:
@generated function combine_sizes(s::Tuple{Vararg{Size}})
sizes = [sz.parameters[1] for sz ∈ s.parameters]
ndims = 0
for i = 1:length(sizes)
ndims = max(ndims, length(sizes[i]))
end
newsize = StaticDimension[Dynamic() for _ = 1 : ndims]
for i = 1:length(sizes)
s = sizes[i]
for j = 1:length(s)
if s[j] isa Dynamic
continue
elseif newsize[j] isa Dynamic || newsize[j] == 1
newsize[j] = s[j]
elseif newsize[j] ≠ s[j] && s[j] ≠ 1
throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes"))
end
end
end
quote
@_inline_meta
Size($(tuple(newsize...)))
end
broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I))
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
li = LinearIndices(oldsize)
ind = _broadcast_getindex(li, newindex)
return :(a[$i][$ind])
end

scalar_getindex(x) = x
scalar_getindex(x::Ref) = x[]

isstatic(::StaticArray) = true
isstatic(::Transpose{<:Any, <:StaticArray}) = true
isstatic(::Adjoint{<:Any, <:StaticArray}) = true
Expand All @@ -120,13 +133,11 @@ end

@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
sizes = [sz.parameters[1] for sz ∈ s.parameters]

indices = CartesianIndices(newsize)
exprs = similar(indices, Expr)
for (j, current_ind) ∈ enumerate(indices)
exprs_vals = [
(!(a[i] <: AbstractArray || a[i] <: Tuple) ? :(scalar_getindex(a[$i])) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))]))
for i = 1:length(sizes)
]
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
exprs[j] = :(f($(exprs_vals...)))
end

Expand All @@ -140,27 +151,18 @@ end
## Internal broadcast! machinery for StaticArrays ##
####################################################

@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, as...) where {newsize}
sizes = [sz.parameters[1] for sz ∈ s.parameters]
sizes = tuple(sizes...)

# TODO: this could also be done outside the generated function:
sizematch(Size{newsize}(), Size(dest)) ||
throw(DimensionMismatch("Tried to broadcast to destination sized $newsize from inputs sized $sizes"))
@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, a...) where {newsize}
sizes = [sz.parameters[1] for sz in s.parameters]

indices = CartesianIndices(newsize)
exprs = similar(indices, Expr)
for (j, current_ind) ∈ enumerate(indices)
exprs_vals = [
(!(as[i] <: AbstractArray || as[i] <: Tuple) ? :(as[$i][]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))]))
for i = 1:length(sizes)
]
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
exprs[j] = :(dest[$j] = f($(exprs_vals...)))
end

return quote
@_propagate_inbounds_meta
@boundscheck sizematch($(Size{newsize}()), dest) || throw(DimensionMismatch("array could not be broadcast to match destination"))
@_inline_meta
@inbounds $(Expr(:block, exprs...))
return dest
end
Expand Down
6 changes: 5 additions & 1 deletion src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ function _precompile_()
# Some expensive generators
@assert precompile(Tuple{typeof(which(__broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any})
@assert precompile(Tuple{typeof(which(_zeros,(Size,Type{<:StaticArray},)).generator.gen),Any,Any,Any,Type,Any})
@assert precompile(Tuple{typeof(which(combine_sizes,(Tuple{Vararg{Size}},)).generator.gen),Any,Any})
@assert precompile(Tuple{typeof(which(_mapfoldl,(Any,Any,Colon,Any,Size,Vararg{StaticArray},)).generator.gen),Any,Any,Any,Any,Any,Any,Any,Any})

# broadcast_getindex
for m = 0:5, n = m:5
@assert precompile(Tuple{typeof(broadcast_getindex),NTuple{m,Int},Int,CartesianIndex{n}})
end
end
20 changes: 20 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,23 @@ end
end
end
end

@testset "instantiate with axes updated" begin
f(a; ax = nothing) = Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{ndims(a)}}(+,(a,),ax)
a = @SArray zeros(2,2,2)
ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2)
@test @inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{3,SOneTo}
ax = (ax..., Base.OneTo(2))
@test @inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{4,Base.OneTo}
ax = setindex(ax, Base.OneTo(1), 4)
@test @inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{4,Base.OneTo}
a = @SArray zeros(2,1,2)
ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2)
@test @inferred(Broadcast.instantiate(f(a; ax))).axes isa Tuple{SOneTo,Base.OneTo,SOneTo}
@test_throws DimensionMismatch Broadcast.instantiate(f(a; ax = ax[1:2]))

a = @SArray zeros(2,2,1)
ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2)
@test @inferred(Broadcast.instantiate(f(a; ax))).axes isa Tuple{SOneTo,SOneTo,Base.OneTo}
@test @inferred(Broadcast.instantiate(f(a; ax = ax[1:2]))).axes isa NTuple{2,SOneTo}
end