Skip to content

Commit f669465

Browse files
TheBBc42f
authored andcommitted
Simplify broadcast code generation (#643)
This simplifies code by using CartesianIndices instead of generating indices manually. Fixes #642
1 parent c607c0d commit f669465

File tree

2 files changed

+18
-56
lines changed

2 files changed

+18
-56
lines changed

src/broadcast.jl

Lines changed: 16 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -105,30 +105,15 @@ scalar_getindex(x::Tuple{<: Any}) = x[1]
105105
end
106106
end
107107

108-
exprs = Array{Expr}(undef, newsize)
109-
more = prod(newsize) > 0
110-
current_ind = ones(Int, length(newsize))
111108
sizes = [sz.parameters[1] for sz s.parameters]
112-
113-
while more
114-
exprs_vals = [(!(a[i] <: AbstractArray) ? :(scalar_getindex(a[$i])) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)]
115-
exprs[current_ind...] = :(f($(exprs_vals...)))
116-
117-
# increment current_ind (maybe use CartesianIndices?)
118-
current_ind[1] += 1
119-
for i 1:length(newsize)
120-
if current_ind[i] > newsize[i]
121-
if i == length(newsize)
122-
more = false
123-
break
124-
else
125-
current_ind[i] = 1
126-
current_ind[i+1] += 1
127-
end
128-
else
129-
break
130-
end
131-
end
109+
indices = CartesianIndices(newsize)
110+
exprs = similar(indices, Expr)
111+
for (j, current_ind) enumerate(indices)
112+
exprs_vals = [
113+
(!(a[i] <: AbstractArray) ? :(scalar_getindex(a[$i])) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))]))
114+
for i = 1:length(sizes)
115+
]
116+
exprs[j] = :(f($(exprs_vals...)))
132117
end
133118

134119
return quote
@@ -150,39 +135,14 @@ end
150135
sizematch(Size{newsize}(), Size(dest)) ||
151136
throw(DimensionMismatch("Tried to broadcast to destination sized $newsize from inputs sized $sizes"))
152137

153-
ndims = 0
154-
for i = 1:length(sizes)
155-
ndims = max(ndims, length(sizes[i]))
156-
end
157-
158-
exprs = Array{Expr}(undef, newsize)
159-
j = 1
160-
more = prod(newsize) > 0
161-
current_ind = ones(Int, max(length(newsize), length.(sizes)...))
162-
while more
163-
exprs_vals = [(!(as[i] <: AbstractArray) ? :(as[$i][]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)]
164-
exprs[current_ind...] = :(dest[$j] = f($(exprs_vals...)))
165-
166-
# increment current_ind (maybe use CartesianIndices?)
167-
if length(current_ind) >= 1
168-
current_ind[1] += 1
169-
for i 1:length(newsize)
170-
if current_ind[i] > newsize[i]
171-
if i == length(newsize)
172-
more = false
173-
break
174-
else
175-
current_ind[i] = 1
176-
current_ind[i+1] += 1
177-
end
178-
else
179-
break
180-
end
181-
end
182-
else
183-
break
184-
end
185-
j += 1
138+
indices = CartesianIndices(newsize)
139+
exprs = similar(indices, Expr)
140+
for (j, current_ind) enumerate(indices)
141+
exprs_vals = [
142+
(!(as[i] <: AbstractArray) ? :(as[$i][]) : :(as[$i][$(broadcasted_index(sizes[i], current_ind))]))
143+
for i = 1:length(sizes)
144+
]
145+
exprs[j] = :(dest[$j] = f($(exprs_vals...)))
186146
end
187147

188148
return quote

test/broadcast.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ Broadcast.broadcastable(x::ScalarTest) = Ref(x)
99
@test x == @inferred(x .+ ScalarTest())
1010
@test x .+ 1 == @inferred(x .+ Ref(1))
1111
end
12+
13+
@test Scalar(3) == @inferred(Scalar(1) .+ 2)
1214
end
1315

1416
@testset "Broadcast sizes" begin

0 commit comments

Comments
 (0)