Skip to content

Commit 82096ee

Browse files
authored
add dx methods (#77)
* add `dx` methods * add tests for derivatives of `ZeroTangent` * fix typo
1 parent be3b62d commit 82096ee

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

src/interface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ However, users may provide additional overloads for custom representations of
5353
one dimensional Riemannian manifolds.
5454
"""
5555
dx(x::Real) = one(x)
56+
dx(::NoTangent) = NoTangent()
57+
dx(::ZeroTangent) = ZeroTangent()
5658
dx(x::Complex) = error("Tried to take the gradient of a complex-valued function.")
5759
dx(x) = error("Cotangent space not defined for `$(typeof(x))`. Try a real-valued function.")
5860

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ end
184184
@test bwd(x->f_crit_edge(true, true, false, x))(1.0) == 2.0
185185
@test bwd(x->f_crit_edge(false, true, true, x))(1.0) == 12.0
186186
@test bwd(x->f_crit_edge(false, false, true, x))(1.0) == 4.0
187+
@test bwd(bwd(x->5))(1.0) == ZeroTangent()
188+
@test fwd(fwd(x->5))(1.0) == ZeroTangent()
187189

188190
# Issue #27 - Mixup in lifting of getfield
189191
let var"'" = bwd

0 commit comments

Comments
 (0)