Skip to content

Commit b09880b

Browse files
N5N3thchr
andauthored
Drop combine_size (#1008)
* Fix `Transpose`/`Adjoint`'s Style * Make `instantiate` generate static axes add `static_combine_axes` to handle Tuple correctly * Drop `combine_size` 1. `axes` should be static. 2. The implementation for broadcast(1) are unified) 3. simplify `broadcast_index` * Add test * add `return` Co-Authored-By: Thomas Christensen <tchr@mit.edu> * import clean * Remove broken branch. Co-authored-by: Thomas Christensen <tchr@mit.edu>
1 parent e35134c commit b09880b

File tree

4 files changed

+95
-68
lines changed

4 files changed

+95
-68
lines changed

src/abstractarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Base.axes(s::StaticArray) = _axes(Size(s))
1414
end
1515
Base.axes(rv::Adjoint{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...)
1616
Base.axes(rv::Transpose{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...)
17+
Base.axes(d::Diagonal{<:Any,<:StaticVector}) = (ax = axes(d.diag, 1); (ax, ax))
1718

1819
Base.eachindex(::IndexLinear, a::StaticArray) = SOneTo(length(a))
1920

src/broadcast.jl

Lines changed: 69 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,93 @@
22
## broadcast! ##
33
################
44

5-
import Base.Broadcast:
6-
BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize!
5+
using Base.Broadcast: AbstractArrayStyle, DefaultArrayStyle, Style, Broadcasted
6+
using Base.Broadcast: broadcast_shape, _broadcast_getindex, combine_axes
7+
import Base.Broadcast: BroadcastStyle, materialize!, instantiate
78
import Base.Broadcast: _bcs1 # for SOneTo axis information
89
using Base.Broadcast: _bcsm
910
# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
1011
# A constructor that changes the style parameter N (array dimension) is also required
1112
struct StaticArrayStyle{N} <: AbstractArrayStyle{N} end
1213
StaticArrayStyle{M}(::Val{N}) where {M,N} = StaticArrayStyle{N}()
1314
BroadcastStyle(::Type{<:StaticArray{<:Tuple, <:Any, N}}) where {N} = StaticArrayStyle{N}()
14-
BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}()
15-
BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray{<:Tuple, <:Any, N}}}) where {N} = StaticArrayStyle{N}()
15+
BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray}}) = StaticArrayStyle{2}()
16+
BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray}}) = StaticArrayStyle{2}()
1617
BroadcastStyle(::Type{<:Diagonal{<:Any, <:StaticArray{<:Tuple, <:Any, 1}}}) = StaticArrayStyle{2}()
1718
# Precedence rules
1819
BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
1920
DefaultArrayStyle(Val(max(M, N)))
2021
BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{0}) where {M} =
2122
StaticArrayStyle{M}()
23+
24+
# combine_axes overload (for Tuple)
25+
@inline static_combine_axes(A, B...) = broadcast_shape(static_axes(A), static_combine_axes(B...))
26+
static_combine_axes(A) = static_axes(A)
27+
static_axes(A) = axes(A)
28+
static_axes(x::Tuple) = (SOneTo{length(x)}(),)
29+
static_axes(bc::Broadcasted{Style{Tuple}}) = static_combine_axes(bc.args...)
30+
Broadcast._axes(bc::Broadcasted{<:StaticArrayStyle}, ::Nothing) = static_combine_axes(bc.args...)
31+
32+
# instantiate overload
33+
@inline function instantiate(B::Broadcasted{StaticArrayStyle{M}}) where M
34+
if B.axes isa Tuple{Vararg{SOneTo}} || B.axes isa Tuple && length(B.axes) > M
35+
return invoke(instantiate, Tuple{Broadcasted}, B)
36+
elseif B.axes isa Nothing
37+
ax = static_combine_axes(B.args...)
38+
return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax)
39+
else
40+
# We need to update B.axes for `broadcast!` if it's not static and `ndims(dest) < M`.
41+
ax = static_check_broadcast_shape(B.axes, static_combine_axes(B.args...))
42+
return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax)
43+
end
44+
end
45+
@inline function static_check_broadcast_shape(shp::Tuple, Ashp::Tuple{Vararg{SOneTo}})
46+
ax1 = if length(Ashp[1]) == 1
47+
shp[1]
48+
elseif Ashp[1] == shp[1]
49+
Ashp[1]
50+
else
51+
throw(DimensionMismatch("array could not be broadcast to match destination"))
52+
end
53+
return (ax1, static_check_broadcast_shape(Base.tail(shp), Base.tail(Ashp))...)
54+
end
55+
static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo,Vararg{SOneTo}}) =
56+
throw(DimensionMismatch("cannot broadcast array to have fewer non-singleton dimensions"))
57+
static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo{1},Vararg{SOneTo{1}}}) = ()
58+
static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
2259
# copy overload
2360
@inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M
2461
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
2562
argsizes = broadcast_sizes(as...)
26-
destsize = combine_sizes(argsizes)
27-
_broadcast(f, destsize, argsizes, as...)
63+
ax = axes(B)
64+
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug.")
65+
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
2866
end
2967
# copyto! overloads
3068
@inline Base.copyto!(dest, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)
3169
@inline Base.copyto!(dest::AbstractArray, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)
3270
@inline function _copyto!(dest, B::Broadcasted{StaticArrayStyle{M}}) where M
3371
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
3472
argsizes = broadcast_sizes(as...)
35-
destsize = combine_sizes((Size(dest), argsizes...))
36-
if Length(destsize) === Length{Dynamic()}()
37-
# destination dimension cannot be determined statically; fall back to generic broadcast!
38-
return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
73+
ax = axes(B)
74+
if ax isa Tuple{Vararg{SOneTo}}
75+
@boundscheck axes(dest) == ax || Broadcast.throwdm(axes(dest), ax)
76+
return _broadcast!(f, Size(map(length, ax)), dest, argsizes, as...)
3977
end
40-
_broadcast!(f, destsize, dest, argsizes, as...)
78+
# destination dimension cannot be determined statically; fall back to generic broadcast!
79+
return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
4180
end
4281

4382
# Resolving priority between dynamic and static axes
4483
_bcs1(a::SOneTo, b::SOneTo) = _bcsm(b, a) ? b : (_bcsm(a, b) ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
45-
_bcs1(a::SOneTo, b::Base.OneTo) = _bcs1(Base.OneTo(a), b)
46-
_bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(a, Base.OneTo(b))
84+
function _bcs1(a::SOneTo, b::Base.OneTo)
85+
length(a) == 1 && return b
86+
if length(b) != length(a) && length(b) != 1
87+
throw(DimensionMismatch("arrays could not be broadcast to a common size"))
88+
end
89+
return a
90+
end
91+
_bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(b, a)
4792

4893
###################################################
4994
## Internal broadcast machinery for StaticArrays ##
@@ -56,45 +101,13 @@ _bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(a, Base.OneTo(b))
56101
@inline broadcast_size(a::AbstractArray) = Size(a)
57102
@inline broadcast_size(a::Tuple) = Size(length(a))
58103

59-
function broadcasted_index(oldsize, newindex)
60-
index = ones(Int, length(oldsize))
61-
for i = 1:length(oldsize)
62-
if oldsize[i] != 1
63-
index[i] = newindex[i]
64-
end
65-
end
66-
return LinearIndices(oldsize)[index...]
67-
end
68-
69-
# similar to Base.Broadcast.combine_indices:
70-
@generated function combine_sizes(s::Tuple{Vararg{Size}})
71-
sizes = [sz.parameters[1] for sz s.parameters]
72-
ndims = 0
73-
for i = 1:length(sizes)
74-
ndims = max(ndims, length(sizes[i]))
75-
end
76-
newsize = StaticDimension[Dynamic() for _ = 1 : ndims]
77-
for i = 1:length(sizes)
78-
s = sizes[i]
79-
for j = 1:length(s)
80-
if s[j] isa Dynamic
81-
continue
82-
elseif newsize[j] isa Dynamic || newsize[j] == 1
83-
newsize[j] = s[j]
84-
elseif newsize[j] s[j] && s[j] 1
85-
throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes"))
86-
end
87-
end
88-
end
89-
quote
90-
@_inline_meta
91-
Size($(tuple(newsize...)))
92-
end
104+
broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I))
105+
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
106+
li = LinearIndices(oldsize)
107+
ind = _broadcast_getindex(li, newindex)
108+
return :(a[$i][$ind])
93109
end
94110

95-
scalar_getindex(x) = x
96-
scalar_getindex(x::Ref) = x[]
97-
98111
isstatic(::StaticArray) = true
99112
isstatic(::Transpose{<:Any, <:StaticArray}) = true
100113
isstatic(::Adjoint{<:Any, <:StaticArray}) = true
@@ -118,13 +131,11 @@ end
118131

119132
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
120133
sizes = [sz.parameters[1] for sz s.parameters]
134+
121135
indices = CartesianIndices(newsize)
122136
exprs = similar(indices, Expr)
123137
for (j, current_ind) enumerate(indices)
124-
exprs_vals = [
125-
(!(a[i] <: AbstractArray || a[i] <: Tuple) ? :(scalar_getindex(a[$i])) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))]))
126-
for i = 1:length(sizes)
127-
]
138+
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
128139
exprs[j] = :(f($(exprs_vals...)))
129140
end
130141

@@ -138,27 +149,18 @@ end
138149
## Internal broadcast! machinery for StaticArrays ##
139150
####################################################
140151

141-
@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, as...) where {newsize}
142-
sizes = [sz.parameters[1] for sz s.parameters]
143-
sizes = tuple(sizes...)
144-
145-
# TODO: this could also be done outside the generated function:
146-
sizematch(Size{newsize}(), Size(dest)) ||
147-
throw(DimensionMismatch("Tried to broadcast to destination sized $newsize from inputs sized $sizes"))
152+
@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, a...) where {newsize}
153+
sizes = [sz.parameters[1] for sz in s.parameters]
148154

149155
indices = CartesianIndices(newsize)
150156
exprs = similar(indices, Expr)
151157
for (j, current_ind) enumerate(indices)
152-
exprs_vals = [
153-
(!(as[i] <: AbstractArray || as[i] <: Tuple) ? :(as[$i][]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))]))
154-
for i = 1:length(sizes)
155-
]
158+
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
156159
exprs[j] = :(dest[$j] = f($(exprs_vals...)))
157160
end
158161

159162
return quote
160-
@_propagate_inbounds_meta
161-
@boundscheck sizematch($(Size{newsize}()), dest) || throw(DimensionMismatch("array could not be broadcast to match destination"))
163+
@_inline_meta
162164
@inbounds $(Expr(:block, exprs...))
163165
return dest
164166
end

src/precompile.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ function _precompile_()
2424
# Some expensive generators
2525
@assert precompile(Tuple{typeof(which(__broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any})
2626
@assert precompile(Tuple{typeof(which(_zeros,(Size,Type{<:StaticArray},)).generator.gen),Any,Any,Any,Type,Any})
27-
@assert precompile(Tuple{typeof(which(combine_sizes,(Tuple{Vararg{Size}},)).generator.gen),Any,Any})
2827
@assert precompile(Tuple{typeof(which(_mapfoldl,(Any,Any,Colon,Any,Size,Vararg{StaticArray},)).generator.gen),Any,Any,Any,Any,Any,Any,Any,Any})
28+
29+
# broadcast_getindex
30+
for m = 0:5, n = m:5
31+
@assert precompile(Tuple{typeof(broadcast_getindex),NTuple{m,Int},Int,CartesianIndex{n}})
32+
end
2933
end

test/broadcast.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,23 @@ end
315315
end
316316
end
317317
end
318+
319+
@testset "instantiate with axes updated" begin
320+
f(a; ax = nothing) = Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{ndims(a)}}(+,(a,),ax)
321+
a = @SArray zeros(2,2,2)
322+
ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2)
323+
@test @inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{3,SOneTo}
324+
ax = (ax..., Base.OneTo(2))
325+
@test @inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{4,Base.OneTo}
326+
ax = setindex(ax, Base.OneTo(1), 4)
327+
@test @inferred(Broadcast.instantiate(f(a; ax))).axes isa NTuple{4,Base.OneTo}
328+
a = @SArray zeros(2,1,2)
329+
ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2)
330+
@test @inferred(Broadcast.instantiate(f(a; ax))).axes isa Tuple{SOneTo,Base.OneTo,SOneTo}
331+
@test_throws DimensionMismatch Broadcast.instantiate(f(a; ax = ax[1:2]))
332+
333+
a = @SArray zeros(2,2,1)
334+
ax = Base.OneTo(2), Base.OneTo(2), Base.OneTo(2)
335+
@test @inferred(Broadcast.instantiate(f(a; ax))).axes isa Tuple{SOneTo,SOneTo,Base.OneTo}
336+
@test @inferred(Broadcast.instantiate(f(a; ax = ax[1:2]))).axes isa NTuple{2,SOneTo}
337+
end

0 commit comments

Comments
 (0)