diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 2f8f48fd..5ba3588d 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -52,6 +52,7 @@ Base.getindex(o::OpticBundle, i::Int) = i == 1 ? o.x : Base.iterate(o::OpticBundle) = (o.x, nothing) Base.iterate(o::OpticBundle, ::Nothing) = (o.clos, missing) Base.iterate(o::OpticBundle, ::Missing) = nothing +Base.length(o::OpticBundle) = 2 # Desturucture using `getfield` rather than iterate to make # inference happier @@ -227,7 +228,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N} end function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T} - ∂⃖{1}()(f, args...) + Tuple{Any, Any}(∂⃖{1}()(f, args...)) end @Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...) diff --git a/test/runtests.jl b/test/runtests.jl index 4b1832e6..570c78a9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -208,7 +208,9 @@ x43 = rand(10, 10) @test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}} # PR # 45 - Calling back into AD from ChainRules -y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2) +r45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2) +@test r45 isa Tuple +y45, back45 = r45 @test y45 ≈ 2.0 @test back45(1) == (ZeroTangent(), 1.0)