Skip to content

Commit 2433157

Browse files
committed
allow more types for _broadcast + fix round
1 parent 6b6f519 commit 2433157

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

src/broadcast.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
_broadcast(f, broadcast_sizes(a...), a...)
99
end
1010

11+
@propagate_inbounds function broadcast{T}(f::Function, a::Type{T}, x::StaticArray)
12+
_broadcast(f, (Size(), Size(x)), T, x)
13+
end
14+
1115
@inline broadcast_sizes(a...) = _broadcast_sizes((), a...)
1216
@inline _broadcast_sizes(t::Tuple) = t
1317
@inline _broadcast_sizes(t::Tuple, a::StaticArray, as...) = _broadcast_sizes((t..., Size(a)), as...)
@@ -23,7 +27,7 @@ function broadcasted_index(oldsize, newindex)
2327
return sub2ind(oldsize, index...)
2428
end
2529

26-
@generated function _broadcast(f, s::Tuple{Vararg{Size}}, a::Union{Number, StaticArray}...)
30+
@generated function _broadcast(f, s::Tuple{Vararg{Size}}, a...)
2731
first_staticarray = 0
2832
for i = 1:length(a)
2933
if a[i] <: StaticArray
@@ -57,7 +61,7 @@ end
5761
current_ind = ones(Int, length(newsize))
5862

5963
while more
60-
exprs_vals = [(a[i] <: Number ? :(a[$i]) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)]
64+
exprs_vals = [(!(a[i] <: AbstractArray) ? :(a[$i]) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)]
6165
exprs[current_ind...] = :(f($(exprs_vals...)))
6266

6367
# increment current_ind (maybe use CartesianRange?)
@@ -77,7 +81,7 @@ end
7781
end
7882
end
7983

80-
eltype_exprs = [:(eltype($t)) for t a]
84+
eltype_exprs = [t <: AbstractArray ? :($(eltype(t))) : :($t) for t a]
8185
newtype_expr = :(Core.Inference.return_type(f, Tuple{$(eltype_exprs...)}))
8286

8387
return quote

0 commit comments

Comments
 (0)