Skip to content

Commit e0d3cbe

Browse files
authored
revert changes to frules
1 parent e3bb529 commit e0d3cbe

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

test/testers.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ function ChainRulesCore.rrule(f::Foo, x)
4949
end
5050
return y, Foo_pullback
5151
end
52-
function ChainRulesCore.frule((Δf_, Δx), f::Foo, x)
53-
Δf = unthunk(Δf_)
52+
function ChainRulesCore.frule((Δf, Δx), f::Foo, x)
5453
return f(x), Δf.a + Δx
5554
end
5655

@@ -159,7 +158,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
159158

160159
@testset "check not inferred in frule" begin
161160
function ChainRulesCore.frule((_, Δx), ::typeof(f_noninferrable_frule), x)
162-
return (x, x > 0 ? Float64(unthunk(Δx)) : Float32(unthunk(Δx)))
161+
return (x, x > 0 ? Float64(Δx) : Float32(Δx))
163162
end
164163
function ChainRulesCore.rrule(::typeof(f_noninferrable_frule), x)
165164
f_noninferrable_frule_pullback(Δy) = (NoTangent(), Δy)
@@ -235,7 +234,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
235234
@testset "check non-inferrable primal still passes if pullback inferrable" begin
236235
function ChainRulesCore.frule((_, Δx), ::typeof(f_inferrable_pullback_only), x)
237236
T = x > 0 ? Float64 : Float32
238-
return T(x), T(unthunk(Δx))
237+
return T(x), T(Δx)
239238
end
240239
function ChainRulesCore.rrule(::typeof(f_inferrable_pullback_only), x)
241240
function f_inferrable_pullback_only_pullback(Δy)
@@ -468,8 +467,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
468467
return s
469468
end
470469

471-
function ChainRulesCore.frule((_, Δiter_), ::typeof(iterfun), iter)
472-
Δiter = unthunk(Δiter_)
470+
function ChainRulesCore.frule((_, Δiter), ::typeof(iterfun), iter)
473471
iter_Δiter = zip(iter, Δiter)
474472
state = iterate(iter_Δiter)
475473
state === nothing && error()

0 commit comments

Comments
 (0)