Skip to content

Commit 7826d5c

Browse files
committed
Make instantiate generate static axes
add `static_combine_axes` to handle Tuple correctly
1 parent ca83953 commit 7826d5c

File tree

1 file changed

+45
-2
lines changed

1 file changed

+45
-2
lines changed

src/broadcast.jl

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import Base.Broadcast:
66
BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize!
7+
import Base.Broadcast: combine_axes, instantiate, _broadcast_getindex, broadcast_shape, Style
78
import Base.Broadcast: _bcs1 # for SOneTo axis information
89
using Base.Broadcast: _bcsm
910
# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
@@ -19,6 +20,42 @@ BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
1920
DefaultArrayStyle(Val(max(M, N)))
2021
BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{0}) where {M} =
2122
StaticArrayStyle{M}()
23+
24+
# combine_axes overload (for Tuple)
25+
@inline static_combine_axes(A, B...) = broadcast_shape(static_axes(A), static_combine_axes(B...))
26+
static_combine_axes(A) = static_axes(A)
27+
static_axes(A) = axes(A)
28+
static_axes(x::Tuple) = (SOneTo{length(x)}(),)
29+
static_axes(bc::Broadcasted{Style{Tuple}}) = static_combine_axes(bc.args...)
30+
Broadcast._axes(bc::Broadcasted{<:StaticArrayStyle}, ::Nothing) = static_combine_axes(bc.args...)
31+
32+
# instantiate overload
33+
@inline function instantiate(B::Broadcasted{StaticArrayStyle{M}}) where M
34+
if B.axes isa Tuple{Vararg{SOneTo}} || B.axes isa Tuple && length(B.axes) > M
35+
return invoke(instantiate, Tuple{Broadcasted}, B)
36+
elseif B.axes isa Nothing
37+
ax = static_combine_axes(B.args...)
38+
return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax)
39+
else
40+
# We need to update B.axes for `broadcast!` if it's not static and `ndims(dest) < M`.
41+
ax = static_check_broadcast_shape(B.axes, static_combine_axes(B.args...))
42+
return Broadcasted{StaticArrayStyle{M}}(B.f, B.args, ax)
43+
end
44+
end
45+
@inline function static_check_broadcast_shape(shp::Tuple, Ashp::Tuple{Vararg{SOneTo}})
46+
ax1 = if length(Ashp[1]) == 1
47+
shp[1]
48+
elseif Ashp[1] == shp[1]
49+
Ashp[1]
50+
else
51+
throw(DimensionMismatch("array could not be broadcast to match destination"))
52+
end
53+
(ax1, static_check_broadcast_shape(Base.tail(shp), Base.tail(Ashp))...)
54+
end
55+
static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo,Vararg{SOneTo}}) =
56+
throw(DimensionMismatch("cannot broadcast array to have fewer non-singleton dimensions"))
57+
static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo{1},Vararg{SOneTo{1}}}) = ()
58+
static_check_broadcast_shape(::Tuple{}, ::Tuple{}) = ()
2259
# copy overload
2360
@inline function Base.copy(B::Broadcasted{StaticArrayStyle{M}}) where M
2461
flat = Broadcast.flatten(B); as = flat.args; f = flat.f
@@ -42,8 +79,14 @@ end
4279

4380
# Resolving priority between dynamic and static axes
4481
_bcs1(a::SOneTo, b::SOneTo) = _bcsm(b, a) ? b : (_bcsm(a, b) ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
45-
_bcs1(a::SOneTo, b::Base.OneTo) = _bcs1(Base.OneTo(a), b)
46-
_bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(a, Base.OneTo(b))
82+
function _bcs1(a::SOneTo, b::Base.OneTo)
83+
length(a) == 1 && return b
84+
if length(b) != length(a) && length(b) != 1
85+
throw(DimensionMismatch("arrays could not be broadcast to a common size"))
86+
end
87+
return a
88+
end
89+
_bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(b, a)
4790

4891
###################################################
4992
## Internal broadcast machinery for StaticArrays ##

0 commit comments

Comments
 (0)