4
4
5
5
import Base. Broadcast:
6
6
BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize!
7
+ import Base. Broadcast: combine_axes, instantiate, _broadcast_getindex, broadcast_shape, Style
7
8
import Base. Broadcast: _bcs1 # for SOneTo axis information
8
9
using Base. Broadcast: _bcsm
9
10
# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
@@ -19,6 +20,42 @@ BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
19
20
DefaultArrayStyle (Val (max (M, N)))
20
21
BroadcastStyle (:: StaticArrayStyle{M} , :: DefaultArrayStyle{0} ) where {M} =
21
22
StaticArrayStyle {M} ()
23
+
24
+ # combine_axes overload (for Tuple)
25
+ @inline static_combine_axes (A, B... ) = broadcast_shape (static_axes (A), static_combine_axes (B... ))
26
+ static_combine_axes (A) = static_axes (A)
27
+ static_axes (A) = axes (A)
28
+ static_axes (x:: Tuple ) = (SOneTo {length(x)} (),)
29
+ static_axes (bc:: Broadcasted{Style{Tuple}} ) = static_combine_axes (bc. args... )
30
+ Broadcast. _axes (bc:: Broadcasted{<:StaticArrayStyle} , :: Nothing ) = static_combine_axes (bc. args... )
31
+
32
+ # instantiate overload
33
+ @inline function instantiate (B:: Broadcasted{StaticArrayStyle{M}} ) where M
34
+ if B. axes isa Tuple{Vararg{SOneTo}} || B. axes isa Tuple && length (B. axes) > M
35
+ return invoke (instantiate, Tuple{Broadcasted}, B)
36
+ elseif B. axes isa Nothing
37
+ ax = static_combine_axes (B. args... )
38
+ return Broadcasted {StaticArrayStyle{M}} (B. f, B. args, ax)
39
+ else
40
+ # We need to update B.axes for `broadcast!` if it's not static and `ndims(dest) < M`.
41
+ ax = static_check_broadcast_shape (B. axes, static_combine_axes (B. args... ))
42
+ return Broadcasted {StaticArrayStyle{M}} (B. f, B. args, ax)
43
+ end
44
+ end
45
+ @inline function static_check_broadcast_shape (shp:: Tuple , Ashp:: Tuple{Vararg{SOneTo}} )
46
+ ax1 = if length (Ashp[1 ]) == 1
47
+ shp[1 ]
48
+ elseif Ashp[1 ] == shp[1 ]
49
+ Ashp[1 ]
50
+ else
51
+ throw (DimensionMismatch (" array could not be broadcast to match destination" ))
52
+ end
53
+ (ax1, static_check_broadcast_shape (Base. tail (shp), Base. tail (Ashp))... )
54
+ end
55
+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{SOneTo,Vararg{SOneTo}} ) =
56
+ throw (DimensionMismatch (" cannot broadcast array to have fewer non-singleton dimensions" ))
57
+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{SOneTo{1},Vararg{SOneTo{1}}} ) = ()
58
+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{} ) = ()
22
59
# copy overload
23
60
@inline function Base. copy (B:: Broadcasted{StaticArrayStyle{M}} ) where M
24
61
flat = Broadcast. flatten (B); as = flat. args; f = flat. f
42
79
43
80
# Resolving priority between dynamic and static axes
44
81
_bcs1 (a:: SOneTo , b:: SOneTo ) = _bcsm (b, a) ? b : (_bcsm (a, b) ? a : throw (DimensionMismatch (" arrays could not be broadcast to a common size" )))
45
- _bcs1 (a:: SOneTo , b:: Base.OneTo ) = _bcs1 (Base. OneTo (a), b)
46
- _bcs1 (a:: Base.OneTo , b:: SOneTo ) = _bcs1 (a, Base. OneTo (b))
82
+ function _bcs1 (a:: SOneTo , b:: Base.OneTo )
83
+ length (a) == 1 && return b
84
+ if length (b) != length (a) && length (b) != 1
85
+ throw (DimensionMismatch (" arrays could not be broadcast to a common size" ))
86
+ end
87
+ return a
88
+ end
89
+ _bcs1 (a:: Base.OneTo , b:: SOneTo ) = _bcs1 (b, a)
47
90
48
91
# ##################################################
49
92
# # Internal broadcast machinery for StaticArrays ##
0 commit comments