Skip to content

Commit dc16d92

Browse files
authored
Fix accumulate bug (#2005)
1 parent 0971891 commit dc16d92

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/accumulate.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,9 @@ function scan!(f::Function, output::AnyCuArray{T}, input::AnyCuArray;
177177

178178
# get the total of each thread block (except the first) of the partial scans
179179
aggregates = fill(neutral, Base.setindex(size(input), blocks_dim, dims))
180-
copyto!(aggregates, selectdim(output, dims, partial:partial:length(Rdim)))
180+
partials = selectdim(output, dims, partial:partial:length(Rdim))
181+
indices = CartesianIndices(partials)
182+
copyto!(aggregates, indices, partials, indices)
181183

182184
# scan these totals to get totals for the entire partial scan
183185
accumulate!(f, aggregates, aggregates; dims=dims)

test/base/array.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ end
318318
@testset "accumulate" begin
319319
for n in (0, 1, 2, 3, 10, 10_000, 16384, 16384+1) # small, large, odd & even, pow2 and not
320320
@test testf(x->accumulate(+, x), rand(n))
321+
@test testf(x->accumulate(+, x), rand(n,2))
321322
@test testf((x,y)->accumulate(+, x; init=y), rand(n), rand())
322323
end
323324

0 commit comments

Comments
 (0)