Skip to content

Commit 573d645

Browse files
authored
Merge pull request #59 from JuliaDiff/ox/mut1
Do gradient via mutating and unmutating cell
2 parents 70847c8 + f677183 commit 573d645

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

src/grad.jl

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,33 +8,40 @@ Approximate the gradient of `f` at `xs...` using `fdm`. Assumes that `f(xs...)`
88
"""
99
function grad end
1010

11-
function grad(fdm, f, x::AbstractArray{T}) where T <: Number
11+
function _grad(fdm, f, x::AbstractArray{T}) where T <: Number
12+
# x must be mutable, we will mutate it and then mutate it back.
1213
dx = similar(x)
13-
tmp = similar(x)
1414
for k in eachindex(x)
1515
dx[k] = fdm(zero(T)) do ϵ
16-
tmp .= x
17-
tmp[k] += ϵ
18-
return f(tmp)
16+
xk = x[k]
17+
x[k] = xk + ϵ
18+
ret = f(x)
19+
x[k] = xk # Can't do `x[k] -= ϵ` as floating-point math is not associative
20+
return ret
1921
end
2022
end
2123
return (dx, )
2224
end
2325

26+
grad(fdm, f, x::Array{<:Number}) = _grad(fdm, f, x)
27+
# Fallback for when we don't know `x` will be mutable:
28+
grad(fdm, f, x::AbstractArray{<:Number}) = _grad(fdm, f, similar(x).=x)
29+
2430
grad(fdm, f, x::Real) = (fdm(f, x), )
2531
grad(fdm, f, x::Tuple) = (grad(fdm, (xs...)->f(xs), x...), )
2632

2733
function grad(fdm, f, d::Dict{K, V}) where {K, V}
28-
dd = Dict{K, V}()
34+
∇d = Dict{K, V}()
2935
for (k, v) in d
36+
dk = d[k]
3037
function f′(x)
31-
tmp = copy(d)
32-
tmp[k] = x
33-
return f(tmp)
38+
d[k] = x
39+
return f(d)
3440
end
35-
dd[k] = grad(fdm, f′, v)[1]
41+
∇d[k] = grad(fdm, f′, v)[1]
42+
d[k] = dk
3643
end
37-
return (dd, )
44+
return (∇d, )
3845
end
3946

4047
function grad(fdm, f, x)

0 commit comments

Comments
 (0)