Skip to content

Commit 4fa67c9

Browse files
author
Roger-luo
committed
update
1 parent b45bfcf commit 4fa67c9

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

src/grad.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ function grad(fdm, f, xs...)
4444
end
4545

4646
"""
47-
jacobian(fdm, f, xs::Union{Real, AbstractArray{<:Real}}; dim::Int=length(f(x)))
47+
jacobian(fdm, f, xs::Union{Real, AbstractArray{<:Real}}; len::Int=length(f(x)))
4848
4949
Approximate the Jacobian of `f` at `x` using `fdm`. `f(x)` must be a length `D` vector. If
5050
`D` is not provided, then `f(x)` is computed once to determine the output size.
5151
"""
52-
function jacobian(fdm, f, x::Union{T, AbstractArray{T}}; dim::Int=length(f(x))) where {T <: Real}
53-
J = Matrix{float(T)}(undef, dim, length(x))
54-
for d in 1:dim
52+
function jacobian(fdm, f, x::Union{T, AbstractArray{T}}; len::Int=length(f(x))) where {T <: Real}
53+
J = Matrix{float(T)}(undef, len, length(x))
54+
for d in 1:len
5555
gs = grad(fdm, x->f(x)[d], x)
5656
for k in 1:length(x)
5757
J[d, k] = gs[k]
@@ -60,9 +60,9 @@ function jacobian(fdm, f, x::Union{T, AbstractArray{T}}; dim::Int=length(f(x)))
6060
return J
6161
end
6262

63-
function jacobian(fdm, f, xs...; dim::Int=length(f(xs...)))
63+
function jacobian(fdm, f, xs...; len::Int=length(f(xs...)))
6464
return ntuple(length(xs)) do k
65-
jacobian(fdm, x->f(replace_arg(x, xs, k)...), xs[k]; dim=dim)
65+
jacobian(fdm, x->f(replace_arg(x, xs, k)...), xs[k]; len=len)
6666
end
6767
end
6868

@@ -78,7 +78,7 @@ _jvp(fdm, f, x::Vector{<:Number}, ẋ::AV{<:Number}) = fdm(ε -> f(x .+ ε .* x
7878
7979
Convenience function to compute `transpose(jacobian(f, x)) * ȳ`.
8080
"""
81-
_j′vp(fdm, f, ȳ::AV{<:Number}, x::Vector{<:Number}) = transpose(jacobian(fdm, f, x; dim=length(ȳ))) *
81+
_j′vp(fdm, f, ȳ::AV{<:Number}, x::Vector{<:Number}) = transpose(jacobian(fdm, f, x; len=length(ȳ))) *
8282

8383
"""
8484
jvp(fdm, f, x, ẋ)

test/grad.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ Base.length(x::DummyType) = size(x.X, 1)
2626

2727
function check_jac_and_jvp_and_j′vp(fdm, f, ȳ, x, ẋ, J_exact)
2828
xc = copy(x)
29-
@test jacobian(fdm, f, x; dim=length(ȳ)) J_exact
30-
@test jacobian(fdm, f, x) == jacobian(fdm, f, x; dim=length(ȳ))
29+
@test jacobian(fdm, f, x; len=length(ȳ)) J_exact
30+
@test jacobian(fdm, f, x) == jacobian(fdm, f, x; len=length(ȳ))
3131
@test _jvp(fdm, f, x, ẋ) J_exact *
3232
@test _j′vp(fdm, f, ȳ, x) transpose(J_exact) *
3333
@test xc == x

0 commit comments

Comments
 (0)