Skip to content

Commit 4d22ce2

Browse files
author
Roger-luo
committed
handle dict & change style
1 parent a555a2c commit 4d22ce2

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

src/grad.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,5 @@
11
export grad, jacobian, jvp, j′vp, to_vec
2-
function replace_arg(x, xs::Tuple, k::Int)
3-
return ntuple(length(xs)) do p
4-
if p == k
5-
x
6-
else
7-
xs[p]
8-
end
9-
end
10-
end
2+
replace_arg(x, xs::Tuple, k::Int) = ntuple(p -> p == k ? x : xs[p], length(xs))
113

124
"""
135
grad(fdm, f, xs...)
@@ -31,6 +23,19 @@ end
3123
grad(fdm, f, x::Real) = fdm(f, x)
3224
grad(fdm, f, x::Tuple) = grad(fdm, (xs...)->f(xs), x...)
3325

26+
function grad(fdm, f, d::Dict{K, V}) where {K, V}
27+
dd = Dict{K, V}()
28+
for (k, v) in d
29+
function f′(x)
30+
tmp = copy(d)
31+
tmp[k] = x
32+
return f(tmp)
33+
end
34+
dd[k] = grad(fdm, f′, v)
35+
end
36+
return dd
37+
end
38+
3439
function grad(fdm, f, xs...)
3540
return ntuple(length(xs)) do k
3641
grad(fdm, x->f(replace_arg(x, xs, k)...), xs[k])

0 commit comments

Comments
 (0)