Skip to content

Commit f16cb00

Browse files
committed
fixup
1 parent fc3bf40 commit f16cb00

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

src/stage1/forward.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,9 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
119119
end
120120

121121
function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
122-
tangents = map(partials, args) do p, a
123-
TangentBundle{1}(a, (p,))
124-
end
125-
∂☆internal{1}()(tangents...)
122+
bundles = map((p,a) -> TangentBundle{1}(a, (p,)), partials, args)
123+
result = ∂☆internal{1}()(bundles...)
124+
primal(result), first_partial(result)
126125
end
127126

128127
function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,15 @@ end
198198

199199
# PR #43
200200
loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w)
201-
x = rand(10, 10)
202-
@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x) isa Tuple{Matrix{Float64}}
201+
x43 = rand(10, 10)
202+
@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}}
203203

204204
# PR # 45 - Calling back into AD from ChainRules
205-
y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2.0)
205+
y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2)
206206
@test y45 2.0
207207
@test back45(1) == (ZeroTangent(), 1.0)
208208

209-
z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2.0)
209+
z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
210210
@test z45 2.0
211211
@test delta45 1.0
212212

0 commit comments

Comments
 (0)