Skip to content

Commit fc3bf40

Browse files
committed
closer?
1 parent c8cc053 commit fc3bf40

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

src/stage1/forward.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,12 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
118118
end
119119
end
120120

121-
ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, args...) = ∂☆internal{1}()(args...)
121+
function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
122+
tangents = map(partials, args) do p, a
123+
TangentBundle{1}(a, (p,))
124+
end
125+
∂☆internal{1}()(tangents...)
126+
end
122127

123128
function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}
124129
∂☆p = ∂☆{minus1(N)}()

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2.0)
206206
@test y45 2.0
207207
@test back45(1) == (ZeroTangent(), 1.0)
208208

209-
# z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2.0)
210-
# @test z45 ≈ 2.0
211-
# @test delta45 ≈ 1.0
209+
z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2.0)
210+
@test z45 2.0
211+
@test delta45 1.0
212212

213213
include("pinn.jl")

0 commit comments

Comments
 (0)