Skip to content

Commit 276eead

Browse files
committed
add RuleConfig
1 parent 0a98045 commit 276eead

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

src/runtime.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ChainRulesCore
2+
struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}} end
23

34
@Base.constprop :aggressive accum(a, b) = a + b
45
@Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b)

src/stage1/forward.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ first_partial(x::CompositeBundle) = map(first_partial, getfield(x, :tup))
1313

1414
# TODO: Which version do we want in ChainRules?
1515
function my_frule(args::ATB{1}...)
16-
frule(map(first_partial, args), map(primal, args)...)
16+
frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...)
1717
end
1818

1919
# Fast path for some hot cases
@@ -118,6 +118,8 @@ function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
118118
end
119119
end
120120

121+
ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, args...) = ∂☆internal{1}()(args...)
122+
121123
function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}
122124
∂☆p = ∂☆{minus1(N)}()
123125
∂☆p(ZeroBundle{minus1(N)}(my_frule), map(shuffle_down, args)...)

src/stage1/generated.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
210210
if N == 1
211211
# Base case (inlined to avoid ambiguities with manually specified
212212
# higher order rules)
213-
z = rrule(f, args...)
213+
z = rrule(DiffractorRuleConfig(), f, args...)
214214
if z === nothing
215215
return ∂⃖recurse{1}()(f, args...)
216216
end
@@ -226,6 +226,10 @@ function (::∂⃖{N})(f::T, args...) where {T, N}
226226
end
227227
end
228228

229+
function ChainRulesCore.rrule_via_ad(::DiffractorRuleConfig, f::T, args...) where {T}
230+
∂⃖{1}()(f, args...)
231+
end
232+
229233
@Base.pure function (::∂⃖{1})(::typeof(Core.apply_type), head, args...)
230234
return rrule(Core.apply_type, head, args...)
231235
end

0 commit comments

Comments
 (0)