Skip to content

Commit 67ecefa

Browse files
authored
Merge pull request #15 from YingboMa/myb/hessian
Hessian
2 parents 3fd3205 + 67ac079 commit 67ecefa

File tree

5 files changed

+12
-8
lines changed

5 files changed

+12
-8
lines changed
File renamed without changes.

src/ForwardDiff2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ include("dual_context.jl")
88
include("api.jl")
99

1010
# Experimental
11-
#include("aosoa.jl")
11+
#include("experiment/aosoa.jl")
1212

1313
end # module

src/api.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,13 @@ function D(f)
4040
return diffres
4141
end
4242
# scalar
43-
function deriv(arg)
44-
dualrun() do
45-
dualized = map(x->Dual(x, one(x)), arg)
46-
res = f(dualized)
47-
return map(partials, res)
43+
function deriv(x)
44+
dx = one(x)
45+
res = dualrun() do
46+
dualized = Dual(x, dx)
47+
f(dualized)
4848
end
49+
return map(partials, res)
4950
end
5051
return deriv
5152
end

src/dualarray.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
using StaticArrays: SVector
22

33
partial_type(::Dual{T,V,P}) where {T,V,P} = P
4-
# TODO: Tagging?
5-
# TODO: Integrate better with SVector. Maybe even use SIMD.jl?
64

75
struct DualArray{T,E,M,V<:AbstractArray,D<:AbstractArray} <: AbstractArray{E,M}
86
data::V

test/api.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,15 @@ using StaticArrays
44

55
@testset begin
66
@test D(sin)(1.0) == cos(1.0)
7+
# Gradient
78
@test D(x->[x, x^2])(3) == [1, 6]
89
@test D(sum)([1,2,3]) == ones(3)'
10+
# Jacobian
911
@test D(x->@SVector([x[1]^x[2], x[3]^3, x[3]*x[2]*x[1]]))(@SVector[1,2,3.]) === @SMatrix [2.0 0 0; 0 0 27; 6 3 2]
1012
@test D(cumsum)(@SVector([1,2,3])) == @SMatrix [1 0 0; 1 1 0; 1 1 1]
1113
@test D(cumsum)([1,2,3]) == [1 0 0; 1 1 0; 1 1 1]
1214
@test D(x->@SVector([x[1], x[2]]))(@SVector([1,2,3])) === @SMatrix [1 0 0; 0 1 0]
15+
# Hessian
16+
@test D(D(x->x[1]^x[2] + x[3]^3 + x[3]*x[2]*x[1]))(@SVector[1,2,3]) === @SMatrix [2 4 2; 4 0 1; 2 1 18.]
17+
@test D(D(x->x[1]^x[2] + x[3]^3 + x[3]*x[2]*x[1]))([1,2,3]) == [2 4 2; 4 0 1; 2 1 18.]
1318
end

0 commit comments

Comments
 (0)