Skip to content

Commit febdc19

Browse files
authored
Merge pull request #663 from mateuszbaran/mbaran/fix485
Broadcasting with tuples (fixes #485)
2 parents 4c8e95f + 502eee4 commit febdc19

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

src/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ broadcast_indices(A::StaticArray) = indices(A)
5555
@inline broadcast_sizes() = ()
5656
@inline broadcast_size(a) = Size()
5757
@inline broadcast_size(a::AbstractArray) = Size(a)
58+
@inline broadcast_size(a::NTuple{N}) where N = Size(N)
5859

5960
function broadcasted_index(oldsize, newindex)
6061
index = ones(Int, length(oldsize))
@@ -94,7 +95,6 @@ end
9495

9596
scalar_getindex(x) = x
9697
scalar_getindex(x::Ref) = x[]
97-
scalar_getindex(x::Tuple{<: Any}) = x[1]
9898

9999
@generated function _broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
100100
first_staticarray = 0
@@ -110,7 +110,7 @@ scalar_getindex(x::Tuple{<: Any}) = x[1]
110110
exprs = similar(indices, Expr)
111111
for (j, current_ind) enumerate(indices)
112112
exprs_vals = [
113-
(!(a[i] <: AbstractArray) ? :(scalar_getindex(a[$i])) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))]))
113+
(!(a[i] <: AbstractArray || a[i] <: Tuple) ? :(scalar_getindex(a[$i])) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))]))
114114
for i = 1:length(sizes)
115115
]
116116
exprs[j] = :(f($(exprs_vals...)))
@@ -139,7 +139,7 @@ end
139139
exprs = similar(indices, Expr)
140140
for (j, current_ind) enumerate(indices)
141141
exprs_vals = [
142-
(!(as[i] <: AbstractArray) ? :(as[$i][]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))]))
142+
(!(as[i] <: AbstractArray || as[i] <: Tuple) ? :(as[$i][]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))]))
143143
for i = 1:length(sizes)
144144
]
145145
exprs[j] = :(dest[$j] = f($(exprs_vals...)))

test/broadcast.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,4 +213,21 @@ end
213213
foo493(X) = normalize.(X)
214214
@test foo493(X) isa Core.Compiler.return_type(foo493, Tuple{typeof(X)})
215215
end
216+
217+
@testset "broadcasting with tuples" begin
218+
# issue 485
219+
@test @inferred(SA[1,2,3] .+ (1,)) === SA{Int}[2, 3, 4]
220+
@test @inferred(SA[1,2,3] .+ (10, 20, 30)) === SA{Int}[11, 22, 33]
221+
@test @inferred((1,2) .+ (SA[10 20; 30 40])) === SA{Int}[11 21; 32 42]
222+
@test @inferred((SA[10 20; 30 40]) .+ (1,2)) === SA{Int}[11 21; 32 42]
223+
224+
add_bc!(m, v) = m .+= v # Helper function; @inferred gets confused by .+= syntax
225+
@test @inferred(add_bc!(MVector((1,2,3)), (10, 20, 30))) ::MVector{3,Int} == SA[11, 22, 33]
226+
@test @inferred(add_bc!(MMatrix(SA[10 20; 30 40]), (1,2))) ::MMatrix{2,2,Int} == SA[11 21; 32 42]
227+
228+
# Tuples of SA
229+
@test SA[1,2,3] .* (SA[1,0],) === SVector{3,SVector{2,Int}}(((1,0), (2,0), (3,0)))
230+
# Unfortunately this case of nested broadcasting is not inferred
231+
@test_broken @inferred(SA[1,2,3] .* (SA[1,0],))
232+
end
216233
end

0 commit comments

Comments
 (0)