Skip to content

Commit 5963b5d

Browse files
authored
Merge pull request #196 from JuliaDiff/os/fix-Cuda-gradient
fix `Cuda` gradients
2 parents c98f1ab + 6fd0ea4 commit 5963b5d

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/gradients.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,8 @@ function finite_difference_gradient!(
239239
fx, c1, c2, c3 = cache.fx, cache.c1, cache.c2, cache.c3
240240
if fdtype != Val(:complex) && ArrayInterface.fast_scalar_indexing(c2)
241241
@. c2 = compute_epsilon(fdtype, one(eltype(x)), relstep, absstep, dir)
242-
copyto!(c1, x)
243242
end
244-
copyto!(c3, x)
243+
copyto!(c1, x)
245244
if fdtype == Val(:forward)
246245
@inbounds for i eachindex(x)
247246
if ArrayInterface.fast_scalar_indexing(c2)
@@ -273,6 +272,7 @@ function finite_difference_gradient!(
273272
end
274273
end
275274
elseif fdtype == Val(:central)
275+
copyto!(c3, x)
276276
@inbounds for i eachindex(x)
277277
if ArrayInterface.fast_scalar_indexing(c2)
278278
epsilon = ArrayInterface.allowed_getindex(c2, i) * dir
@@ -296,9 +296,8 @@ function finite_difference_gradient!(
296296
ArrayInterface.allowed_setindex!(c3, x_old, i)
297297
end
298298
elseif fdtype == Val(:complex) && returntype <: Real
299-
copyto!(c1, x)
300-
epsilon_complex = eps(real(eltype(x)))
301299
# we use c1 here to avoid typing issues with x
300+
epsilon_complex = eps(real(eltype(x)))
302301
@inbounds for i eachindex(x)
303302
c1_old = ArrayInterface.allowed_getindex(c1, i)
304303
ArrayInterface.allowed_setindex!(c1, c1_old + im * epsilon_complex, i)

0 commit comments

Comments
 (0)