Skip to content

Commit 4fca843

Browse files
Fix rrule_via_ad return type (#86)
* fix rrule_via_ad return type * Update src/stage1/generated.jl Co-authored-by: Jeff Bezanson <jeff.bezanson@gmail.com> * add length method * Update src/stage1/generated.jl Co-authored-by: Jeff Bezanson <jeff.bezanson@gmail.com>
1 parent 018510d commit 4fca843

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

src/stage1/generated.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Base.getindex(o::OpticBundle, i::Int) = i == 1 ? o.x :
5252
Base.iterate(o::OpticBundle) = (o.x, nothing)
5353
Base.iterate(o::OpticBundle, ::Nothing) = (o.clos, missing)
5454
Base.iterate(o::OpticBundle, ::Missing) = nothing
55+
Base.length(o::OpticBundle) = 2
5556

5657
# Desturucture using `getfield` rather than iterate to make
5758
# inference happier
@@ -227,7 +228,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
227228
end
228229

229230
function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T}
230-
∂⃖{1}()(f, args...)
231+
Tuple{Any, Any}(∂⃖{1}()(f, args...))
231232
end
232233

233234
@Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...)

test/runtests.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ x43 = rand(10, 10)
208208
@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}}
209209

210210
# PR # 45 - Calling back into AD from ChainRules
211-
y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2)
211+
r45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2)
212+
@test r45 isa Tuple
213+
y45, back45 = r45
212214
@test y45 2.0
213215
@test back45(1) == (ZeroTangent(), 1.0)
214216

0 commit comments

Comments
 (0)