From 7cd29c0835f6988521dd640188268db2da00a4ba Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 12 Aug 2022 22:18:33 -0700 Subject: [PATCH 1/4] fix rrule_via_ad return type --- src/stage1/generated.jl | 2 +- test/runtests.jl | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 2f8f48fd..19cfc8c8 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -227,7 +227,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N} end function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T} - ∂⃖{1}()(f, args...) + ∂⃖{1}()(f, args...) |> Tuple{<:Any, <:Any} end @Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...) diff --git a/test/runtests.jl b/test/runtests.jl index 4b1832e6..570c78a9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -208,7 +208,9 @@ x43 = rand(10, 10) @test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}} # PR # 45 - Calling back into AD from ChainRules -y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2) +r45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2) +@test r45 isa Tuple +y45, back45 = r45 @test y45 ≈ 2.0 @test back45(1) == (ZeroTangent(), 1.0) From c1e4006dd576213b17a9a17f6ead520e53c699b5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 18 Aug 2022 17:10:46 -0400 Subject: [PATCH 2/4] Update src/stage1/generated.jl Co-authored-by: Jeff Bezanson --- src/stage1/generated.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 19cfc8c8..1f03d859 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -227,7 +227,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N} end function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T} - ∂⃖{1}()(f, args...) |> Tuple{<:Any, <:Any} + ∂⃖{1}()(f, args...) |> Tuple{Any, Any} end @Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...) From af13bf89b68a4f5f6a24f908f2706d8c00a1c3d9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 18 Aug 2022 17:12:56 -0400 Subject: [PATCH 3/4] add length method --- src/stage1/generated.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 1f03d859..9c7261cd 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -52,6 +52,7 @@ Base.getindex(o::OpticBundle, i::Int) = i == 1 ? o.x : Base.iterate(o::OpticBundle) = (o.x, nothing) Base.iterate(o::OpticBundle, ::Nothing) = (o.clos, missing) Base.iterate(o::OpticBundle, ::Missing) = nothing +Base.length(o::OpticBundle) = 2 # Desturucture using `getfield` rather than iterate to make # inference happier From e6d15eb0873212ef87a8bb68b23f1481b8480466 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 26 Aug 2022 20:55:11 -0400 Subject: [PATCH 4/4] Update src/stage1/generated.jl --- src/stage1/generated.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 9c7261cd..5ba3588d 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -228,7 +228,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N} end function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T} - ∂⃖{1}()(f, args...) |> Tuple{Any, Any} + Tuple{Any, Any}(∂⃖{1}()(f, args...)) end @Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...)