Skip to content

Commit 722627e

Browse files
committed
Drop combine_size
1. `axes` should be static. 2. The implementation for broadcast(1) are unified) 3. simplify `broadcast_index`
1 parent f04ca91 commit 722627e

File tree

3 files changed

+28
-62
lines changed

3 files changed

+28
-62
lines changed

src/abstractarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Base.axes(s::StaticArray) = _axes(Size(s))
1414
end
1515
Base.axes(rv::Adjoint{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...)
1616
Base.axes(rv::Transpose{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...)
17+
Base.axes(d::Diagonal{<:Any,<:StaticVector}) = (ax = axes(d.diag, 1); (ax, ax))
1718

1819
Base.eachindex(::IndexLinear, a::StaticArray) = SOneTo(length(a))
1920

src/broadcast.jl

Lines changed: 22 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,25 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
6060
@inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M
6161
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
6262
argsizes = broadcast_sizes(as...)
63-
destsize = combine_sizes(argsizes)
64-
_broadcast(f, destsize, argsizes, as...)
63+
ax = axes(B)
64+
if ax isa Tuple{Vararg{SOneTo}}
65+
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
66+
end
67+
return copy(convert(Broadcasted{DefaultArrayStyle{M}}, B))
6568
end
6669
# copyto! overloads
6770
@inline Base.copyto!(dest, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)
6871
@inline Base.copyto!(dest::AbstractArray, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)
6972
@inline function _copyto!(dest, B::Broadcasted{StaticArrayStyle{M}}) where M
7073
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
7174
argsizes = broadcast_sizes(as...)
72-
destsize = combine_sizes((Size(dest), argsizes...))
73-
if Length(destsize) === Length{Dynamic()}()
74-
# destination dimension cannot be determined statically; fall back to generic broadcast!
75-
return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
75+
ax = axes(B)
76+
if ax isa Tuple{Vararg{SOneTo}}
77+
@boundscheck axes(dest) == ax || Broadcast.throwdm(axes(dest), ax)
78+
return _broadcast!(f, Size(map(length, ax)), dest, argsizes, as...)
7679
end
77-
_broadcast!(f, destsize, dest, argsizes, as...)
80+
# destination dimension cannot be determined statically; fall back to generic broadcast!
81+
return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
7882
end
7983

8084
# Resolving priority between dynamic and static axes
@@ -101,45 +105,13 @@ broadcast_indices(A::StaticArray) = indices(A)
101105
@inline broadcast_size(a::AbstractArray) = Size(a)
102106
@inline broadcast_size(a::Tuple) = Size(length(a))
103107

104-
function broadcasted_index(oldsize, newindex)
105-
index = ones(Int, length(oldsize))
106-
for i = 1:length(oldsize)
107-
if oldsize[i] != 1
108-
index[i] = newindex[i]
109-
end
110-
end
111-
return LinearIndices(oldsize)[index...]
112-
end
113-
114-
# similar to Base.Broadcast.combine_indices:
115-
@generated function combine_sizes(s::Tuple{Vararg{Size}})
116-
sizes = [sz.parameters[1] for sz s.parameters]
117-
ndims = 0
118-
for i = 1:length(sizes)
119-
ndims = max(ndims, length(sizes[i]))
120-
end
121-
newsize = StaticDimension[Dynamic() for _ = 1 : ndims]
122-
for i = 1:length(sizes)
123-
s = sizes[i]
124-
for j = 1:length(s)
125-
if s[j] isa Dynamic
126-
continue
127-
elseif newsize[j] isa Dynamic || newsize[j] == 1
128-
newsize[j] = s[j]
129-
elseif newsize[j] s[j] && s[j] 1
130-
throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes"))
131-
end
132-
end
133-
end
134-
quote
135-
@_inline_meta
136-
Size($(tuple(newsize...)))
137-
end
108+
broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I))
109+
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
110+
li = LinearIndices(oldsize)
111+
ind = _broadcast_getindex(li, newindex)
112+
return :(a[$i][$ind])
138113
end
139114

140-
scalar_getindex(x) = x
141-
scalar_getindex(x::Ref) = x[]
142-
143115
isstatic(::StaticArray) = true
144116
isstatic(::Transpose{<:Any, <:StaticArray}) = true
145117
isstatic(::Adjoint{<:Any, <:StaticArray}) = true
@@ -163,13 +135,11 @@ end
163135

164136
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
165137
sizes = [sz.parameters[1] for sz s.parameters]
138+
166139
indices = CartesianIndices(newsize)
167140
exprs = similar(indices, Expr)
168141
for (j, current_ind) enumerate(indices)
169-
exprs_vals = [
170-
(!(a[i] <: AbstractArray || a[i] <: Tuple) ? :(scalar_getindex(a[$i])) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))]))
171-
for i = 1:length(sizes)
172-
]
142+
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
173143
exprs[j] = :(f($(exprs_vals...)))
174144
end
175145

@@ -183,27 +153,18 @@ end
183153
## Internal broadcast! machinery for StaticArrays ##
184154
####################################################
185155

186-
@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, as...) where {newsize}
187-
sizes = [sz.parameters[1] for sz s.parameters]
188-
sizes = tuple(sizes...)
189-
190-
# TODO: this could also be done outside the generated function:
191-
sizematch(Size{newsize}(), Size(dest)) ||
192-
throw(DimensionMismatch("Tried to broadcast to destination sized $newsize from inputs sized $sizes"))
156+
@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, a...) where {newsize}
157+
sizes = [sz.parameters[1] for sz in s.parameters]
193158

194159
indices = CartesianIndices(newsize)
195160
exprs = similar(indices, Expr)
196161
for (j, current_ind) enumerate(indices)
197-
exprs_vals = [
198-
(!(as[i] <: AbstractArray || as[i] <: Tuple) ? :(as[$i][]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))]))
199-
for i = 1:length(sizes)
200-
]
162+
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
201163
exprs[j] = :(dest[$j] = f($(exprs_vals...)))
202164
end
203165

204166
return quote
205-
@_propagate_inbounds_meta
206-
@boundscheck sizematch($(Size{newsize}()), dest) || throw(DimensionMismatch("array could not be broadcast to match destination"))
167+
@_inline_meta
207168
@inbounds $(Expr(:block, exprs...))
208169
return dest
209170
end

src/precompile.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ function _precompile_()
2424
# Some expensive generators
2525
@assert precompile(Tuple{typeof(which(__broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any})
2626
@assert precompile(Tuple{typeof(which(_zeros,(Size,Type{<:StaticArray},)).generator.gen),Any,Any,Any,Type,Any})
27-
@assert precompile(Tuple{typeof(which(combine_sizes,(Tuple{Vararg{Size}},)).generator.gen),Any,Any})
2827
@assert precompile(Tuple{typeof(which(_mapfoldl,(Any,Any,Colon,Any,Size,Vararg{StaticArray},)).generator.gen),Any,Any,Any,Any,Any,Any,Any,Any})
28+
29+
# broadcast_getindex
30+
for m = 0:5, n = m:5
31+
@assert precompile(Tuple{typeof(broadcast_getindex),NTuple{m,Int},Int,CartesianIndex{n}})
32+
end
2933
end

0 commit comments

Comments
 (0)