Skip to content

Commit a6208a7

Browse files
committed
Drop combine_size
1. `axes` should be static. 2. The implementation for broadcast(1) are unified) 3. simplify `broadcast_index`
1 parent 7826d5c commit a6208a7

File tree

3 files changed

+28
-62
lines changed

3 files changed

+28
-62
lines changed

src/abstractarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Base.axes(s::StaticArray) = _axes(Size(s))
1414
end
1515
Base.axes(rv::Adjoint{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...)
1616
Base.axes(rv::Transpose{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...)
17+
Base.axes(d::Diagonal{<:Any,<:StaticVector}) = (ax = axes(d.diag, 1); (ax, ax))
1718

1819
Base.eachindex(::IndexLinear, a::StaticArray) = SOneTo(length(a))
1920

src/broadcast.jl

Lines changed: 22 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,25 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
6060
@inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M
6161
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
6262
argsizes = broadcast_sizes(as...)
63-
destsize = combine_sizes(argsizes)
64-
_broadcast(f, destsize, argsizes, as...)
63+
ax = axes(B)
64+
if ax isa Tuple{Vararg{SOneTo}}
65+
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
66+
end
67+
return copy(convert(Broadcasted{DefaultArrayStyle{M}}, B))
6568
end
6669
# copyto! overloads
6770
@inline Base.copyto!(dest, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)
6871
@inline Base.copyto!(dest::AbstractArray, B::Broadcasted{<:StaticArrayStyle}) = _copyto!(dest, B)
6972
@inline function _copyto!(dest, B::Broadcasted{StaticArrayStyle{M}}) where M
7073
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
7174
argsizes = broadcast_sizes(as...)
72-
destsize = combine_sizes((Size(dest), argsizes...))
73-
if Length(destsize) === Length{Dynamic()}()
74-
# destination dimension cannot be determined statically; fall back to generic broadcast!
75-
return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
75+
ax = axes(B)
76+
if ax isa Tuple{Vararg{SOneTo}}
77+
@boundscheck axes(dest) == ax || Broadcast.throwdm(axes(dest), ax)
78+
return _broadcast!(f, Size(map(length, ax)), dest, argsizes, as...)
7679
end
77-
_broadcast!(f, destsize, dest, argsizes, as...)
80+
# destination dimension cannot be determined statically; fall back to generic broadcast!
81+
return copyto!(dest, convert(Broadcasted{DefaultArrayStyle{M}}, B))
7882
end
7983

8084
# Resolving priority between dynamic and static axes
@@ -99,45 +103,13 @@ _bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(b, a)
99103
@inline broadcast_size(a::AbstractArray) = Size(a)
100104
@inline broadcast_size(a::Tuple) = Size(length(a))
101105

102-
function broadcasted_index(oldsize, newindex)
103-
index = ones(Int, length(oldsize))
104-
for i = 1:length(oldsize)
105-
if oldsize[i] != 1
106-
index[i] = newindex[i]
107-
end
108-
end
109-
return LinearIndices(oldsize)[index...]
110-
end
111-
112-
# similar to Base.Broadcast.combine_indices:
113-
@generated function combine_sizes(s::Tuple{Vararg{Size}})
114-
sizes = [sz.parameters[1] for sz s.parameters]
115-
ndims = 0
116-
for i = 1:length(sizes)
117-
ndims = max(ndims, length(sizes[i]))
118-
end
119-
newsize = StaticDimension[Dynamic() for _ = 1 : ndims]
120-
for i = 1:length(sizes)
121-
s = sizes[i]
122-
for j = 1:length(s)
123-
if s[j] isa Dynamic
124-
continue
125-
elseif newsize[j] isa Dynamic || newsize[j] == 1
126-
newsize[j] = s[j]
127-
elseif newsize[j] s[j] && s[j] 1
128-
throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes"))
129-
end
130-
end
131-
end
132-
quote
133-
@_inline_meta
134-
Size($(tuple(newsize...)))
135-
end
106+
broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I))
107+
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
108+
li = LinearIndices(oldsize)
109+
ind = _broadcast_getindex(li, newindex)
110+
return :(a[$i][$ind])
136111
end
137112

138-
scalar_getindex(x) = x
139-
scalar_getindex(x::Ref) = x[]
140-
141113
isstatic(::StaticArray) = true
142114
isstatic(::Transpose{<:Any, <:StaticArray}) = true
143115
isstatic(::Adjoint{<:Any, <:StaticArray}) = true
@@ -161,13 +133,11 @@ end
161133

162134
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
163135
sizes = [sz.parameters[1] for sz s.parameters]
136+
164137
indices = CartesianIndices(newsize)
165138
exprs = similar(indices, Expr)
166139
for (j, current_ind) enumerate(indices)
167-
exprs_vals = [
168-
(!(a[i] <: AbstractArray || a[i] <: Tuple) ? :(scalar_getindex(a[$i])) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))]))
169-
for i = 1:length(sizes)
170-
]
140+
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
171141
exprs[j] = :(f($(exprs_vals...)))
172142
end
173143

@@ -181,27 +151,18 @@ end
181151
## Internal broadcast! machinery for StaticArrays ##
182152
####################################################
183153

184-
@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, as...) where {newsize}
185-
sizes = [sz.parameters[1] for sz s.parameters]
186-
sizes = tuple(sizes...)
187-
188-
# TODO: this could also be done outside the generated function:
189-
sizematch(Size{newsize}(), Size(dest)) ||
190-
throw(DimensionMismatch("Tried to broadcast to destination sized $newsize from inputs sized $sizes"))
154+
@generated function _broadcast!(f, ::Size{newsize}, dest::AbstractArray, s::Tuple{Vararg{Size}}, a...) where {newsize}
155+
sizes = [sz.parameters[1] for sz in s.parameters]
191156

192157
indices = CartesianIndices(newsize)
193158
exprs = similar(indices, Expr)
194159
for (j, current_ind) enumerate(indices)
195-
exprs_vals = [
196-
(!(as[i] <: AbstractArray || as[i] <: Tuple) ? :(as[$i][]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))]))
197-
for i = 1:length(sizes)
198-
]
160+
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
199161
exprs[j] = :(dest[$j] = f($(exprs_vals...)))
200162
end
201163

202164
return quote
203-
@_propagate_inbounds_meta
204-
@boundscheck sizematch($(Size{newsize}()), dest) || throw(DimensionMismatch("array could not be broadcast to match destination"))
165+
@_inline_meta
205166
@inbounds $(Expr(:block, exprs...))
206167
return dest
207168
end

src/precompile.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ function _precompile_()
2424
# Some expensive generators
2525
@assert precompile(Tuple{typeof(which(__broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any})
2626
@assert precompile(Tuple{typeof(which(_zeros,(Size,Type{<:StaticArray},)).generator.gen),Any,Any,Any,Type,Any})
27-
@assert precompile(Tuple{typeof(which(combine_sizes,(Tuple{Vararg{Size}},)).generator.gen),Any,Any})
2827
@assert precompile(Tuple{typeof(which(_mapfoldl,(Any,Any,Colon,Any,Size,Vararg{StaticArray},)).generator.gen),Any,Any,Any,Any,Any,Any,Any,Any})
28+
29+
# broadcast_getindex
30+
for m = 0:5, n = m:5
31+
@assert precompile(Tuple{typeof(broadcast_getindex),NTuple{m,Int},Int,CartesianIndex{n}})
32+
end
2933
end

0 commit comments

Comments
 (0)