Skip to content

Commit fc70f36

Browse files
committed
renable configured frules
1 parent 03e160f commit fc70f36

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/stage1/forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ function shuffle_base(r)
106106
end
107107

108108
function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
109-
r = frule(#=DiffractorRuleConfig(),=# map(first_partial, args), map(primal, args)...)
109+
r = frule(DiffractorRuleConfig(), map(first_partial, args), map(primal, args)...)
110110
if r === nothing
111111
return ∂☆recurse{1}()(args...)
112112
else

test/forward.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,4 +184,22 @@ end
184184
)
185185
end
186186

187+
188+
@testset "configured frule" begin
189+
my_func(x) = sin(x)
190+
frule_hits = 0
191+
function ChainRulesCore.frule(config::RuleConfig{>:HasForwardsMode}, (_, dx), ::typeof(my_func), x)
192+
res=my_func(x)
193+
_, der_fwd = ChainRulesCore.frule_via_ad(config, (ChainRulesCore.NoTangent(), dx), sin, x)
194+
frule_hits +=1
195+
return res, der_fwd
196+
end
197+
198+
let var"'" = Diffractor.PrimeDerivativeFwd
199+
@assert frule_hits == 0
200+
@test my_func'(1.0) == cos(1.0)
201+
@test frule_hits == 1
202+
end
203+
end
204+
187205
end # module

0 commit comments

Comments
 (0)