Skip to content

Commit e2cfd85

Browse files
authored
Merge pull request #45 from mcabbott/ruleconfig
Add `RuleConfig`
2 parents 0a98045 + 48460e7 commit e2cfd85

File tree

4 files changed

+26
-6
lines changed

4 files changed

+26
-6
lines changed

src/runtime.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ChainRulesCore
2+
struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}} end
23

34
@Base.constprop :aggressive accum(a, b) = a + b
45
@Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b)

src/stage1/forward.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ first_partial(x::CompositeBundle) = map(first_partial, getfield(x, :tup))
1313

1414
# TODO: Which version do we want in ChainRules?
1515
function my_frule(args::ATB{1}...)
16-
frule(map(first_partial, args), map(primal, args)...)
16+
frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...)
1717
end
1818

1919
# Fast path for some hot cases
@@ -118,6 +118,12 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
118118
end
119119
end
120120

121+
function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
122+
bundles = map((p,a) -> TangentBundle{1}(a, (p,)), partials, args)
123+
result = ∂☆internal{1}()(bundles...)
124+
primal(result), first_partial(result)
125+
end
126+
121127
function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}
122128
∂☆p = ∂☆{minus1(N)}()
123129
∂☆p(ZeroBundle{minus1(N)}(my_frule), map(shuffle_down, args)...)

src/stage1/generated.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
210210
if N == 1
211211
# Base case (inlined to avoid ambiguities with manually specified
212212
# higher order rules)
213-
z = rrule(f, args...)
213+
z = rrule(DiffractorRuleConfig(), f, args...)
214214
if z === nothing
215215
return ∂⃖recurse{1}()(f, args...)
216216
end
@@ -226,6 +226,10 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
226226
end
227227
end
228228

229+
function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T}
230+
∂⃖{1}()(f, args...)
231+
end
232+
229233
@Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...)
230234
return rrule(Core.apply_type, head, args...)
231235
end

test/runtests.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using Diffractor
2-
using Diffractor: var"'", ∂⃖
2+
using Diffractor: var"'", ∂⃖, DiffractorRuleConfig
33
using ChainRules
44
using ChainRulesCore
5-
using ChainRules: ZeroTangent, NoTangent
5+
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
66
using Symbolics
77
using LinearAlgebra
88

@@ -198,7 +198,16 @@ end
198198

199199
# PR #43
200200
loss(res, z, w) = sum(res.U * Diagonal(res.S) * res.V) + sum(res.S .* w)
201-
x = rand(10, 10)
202-
@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x) isa Tuple{Matrix{Float64}}
201+
x43 = rand(10, 10)
202+
@test Diffractor.gradient(x->loss(svd(x), x[:,1], x[:,2]), x43) isa Tuple{Matrix{Float64}}
203+
204+
# PR # 45 - Calling back into AD from ChainRules
205+
y45, back45 = rrule_via_ad(DiffractorRuleConfig(), x -> log(exp(x)), 2)
206+
@test y45 2.0
207+
@test back45(1) == (ZeroTangent(), 1.0)
208+
209+
z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2)
210+
@test z45 2.0
211+
@test delta45 1.0
203212

204213
include("pinn.jl")

0 commit comments

Comments
 (0)