Skip to content

Commit c8cc053

Browse files
committed
add tests, frule is broken
1 parent 276eead commit c8cc053

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

test/runtests.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using Diffractor
2-
using Diffractor: var"'", ∂⃖
2+
using Diffractor: var"'", ∂⃖, DiffractorRuleConfig
33
using ChainRules
44
using ChainRulesCore
5-
using ChainRules: ZeroTangent, NoTangent
5+
using ChainRules: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
66
using Symbolics
77
using LinearAlgebra
88

@@ -201,4 +201,13 @@ loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w)
201201
x = rand(10, 10)
202202
@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x) isa Tuple{Matrix{Float64}}
203203

204+
# PR # 45 - Calling back into AD from ChainRules
205+
y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2.0)
206+
@test y45 2.0
207+
@test back45(1) == (ZeroTangent(), 1.0)
208+
209+
# z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2.0)
210+
# @test z45 ≈ 2.0
211+
# @test delta45 ≈ 1.0
212+
204213
include("pinn.jl")

0 commit comments

Comments
 (0)