Skip to content

Commit f6bd884

Browse files
authored
Do gradient via mutating and unmutating cell
1 parent 70847c8 commit f6bd884

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/grad.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,25 @@ Approximate the gradient of `f` at `xs...` using `fdm`. Assumes that `f(xs...)`
88
"""
99
function grad end
1010

11-
function grad(fdm, f, x::AbstractArray{T}) where T <: Number
11+
function _grad(fdm, f, x::AbstractArray{T}) where T <: Number
12+
# x must be mutable, we will mutate it and then mutate it back.
1213
dx = similar(x)
13-
tmp = similar(x)
1414
for k in eachindex(x)
1515
dx[k] = fdm(zero(T)) do ϵ
16-
tmp .= x
17-
tmp[k] += ϵ
18-
return f(tmp)
16+
xk = x[k]
17+
x[k] = xk + ϵ
18+
ret = f(x)
19+
x[k] = xk # Can't do `x[k] -= ϵ` as floating-point math is not associative
20+
return ret
1921
end
2022
end
2123
return (dx, )
2224
end
2325

26+
grad(fdm, f, x::Array{<:Number}) = _grad(fdm, f, x)
27+
# Fallback for when we don't know `x` will be mutable:
28+
grad(fdm, f, x::AbstractArray{<:Number}) = _grad(fdm, f, similar(x).=x)
29+
2430
grad(fdm, f, x::Real) = (fdm(f, x), )
2531
grad(fdm, f, x::Tuple) = (grad(fdm, (xs...)->f(xs), x...), )
2632

0 commit comments

Comments
 (0)