@@ -8,33 +8,40 @@ Approximate the gradient of `f` at `xs...` using `fdm`. Assumes that `f(xs...)`
8
8
"""
9
9
function grad end
10
10
11
"""
    _grad(fdm, f, x::AbstractArray{<:Number})

Approximate the gradient of `f` at `x` with the finite-difference method `fdm`, perturbing
one element of `x` at a time.

`x` must be mutable: each element is temporarily overwritten in place and then reset to
its saved value, including when `f` throws. Returns a 1-tuple holding an array
`similar(x)` with one partial derivative per index of `x`.
"""
function _grad(fdm, f, x::AbstractArray{T}) where T<:Number
    # x must be mutable, we will mutate it and then mutate it back.
    dx = similar(x)
    for k in eachindex(x)
        dx[k] = fdm(zero(T)) do ϵ
            xk = x[k]
            x[k] = xk + ϵ
            try
                return f(x)
            finally
                # Reset from the saved value: `x[k] -= ϵ` would not round-trip because
                # floating-point math is not associative. Doing it in `finally` keeps
                # the caller's array intact even if `f` throws.
                x[k] = xk
            end
        end
    end
    return (dx,)
end
23
25
26
# `Array`s are mutable, so `_grad` can perturb `x` in place without a defensive copy.
# NOTE: `x` is temporarily modified while `f` is being evaluated.
function grad(fdm, f, x::Array{<:Number})
    return _grad(fdm, f, x)
end
27
# Fallback for when we don't know `x` will be mutable: hand `_grad` a freshly allocated
# copy (`similar(x)` filled with the values of `x`) so the caller's `x` is never written to.
function grad(fdm, f, x::AbstractArray{<:Number})
    scratch = similar(x)
    scratch .= x
    return _grad(fdm, f, scratch)
end
29
+
24
30
# Scalar case: apply `fdm` to `f` at `x` and wrap the result in a 1-tuple, matching the
# tuple-shaped return of the other `grad` methods.
function grad(fdm, f, x::Real)
    return (fdm(f, x),)
end
25
31
# Tuple case: splat the elements of `x` into a multi-argument `grad` call (that method is
# not visible in this part of the file) using a wrapper that re-packs its arguments into a
# tuple before calling `f`. The result is wrapped in a 1-tuple.
function grad(fdm, f, x::Tuple)
    return (grad(fdm, (xs...) -> f(xs), x...),)
end
26
32
27
33
# Dict case: differentiate `f` with respect to each value of `d` in turn.
#
# `d` is perturbed in place instead of being copied for every evaluation, so it must not be
# relied upon by other code while `f` is running. Each entry is put back to its original
# value once its gradient has been computed, including when `f` or the nested `grad` throws.
# Returns a 1-tuple holding a `Dict{K,V}` that maps each key of `d` to the gradient with
# respect to that key's value.
function grad(fdm, f, d::Dict{K,V}) where {K,V}
    ∇d = Dict{K,V}()
    for (k, v) in d
        # View `f` as a function of the single entry stored under `k`.
        function f′(x)
            d[k] = x
            return f(d)
        end
        try
            ∇d[k] = grad(fdm, f′, v)[1]
        finally
            # Overwriting the value of an existing key; `v` is the entry's original value.
            d[k] = v
        end
    end
    return (∇d,)
end
39
46
40
47
function grad (fdm, f, x)
0 commit comments