@@ -49,8 +49,7 @@ function ChainRulesCore.rrule(f::Foo, x)
49
49
end
50
50
return y, Foo_pullback
51
51
end
52
- function ChainRulesCore. frule ((Δf_, Δx), f:: Foo , x)
53
- Δf = unthunk (Δf_)
52
+ function ChainRulesCore. frule ((Δf, Δx), f:: Foo , x)
54
53
return f (x), Δf. a + Δx
55
54
end
56
55
@@ -159,7 +158,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
159
158
160
159
@testset " check not inferred in frule" begin
161
160
function ChainRulesCore. frule ((_, Δx), :: typeof (f_noninferrable_frule), x)
162
- return (x, x > 0 ? Float64 (unthunk ( Δx)) : Float32 (unthunk (Δx) ))
161
+ return (x, x > 0 ? Float64 (Δx) : Float32 (Δx ))
163
162
end
164
163
function ChainRulesCore. rrule (:: typeof (f_noninferrable_frule), x)
165
164
f_noninferrable_frule_pullback (Δy) = (NoTangent (), Δy)
@@ -235,7 +234,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
235
234
@testset " check non-inferrable primal still passes if pullback inferrable" begin
236
235
function ChainRulesCore. frule ((_, Δx), :: typeof (f_inferrable_pullback_only), x)
237
236
T = x > 0 ? Float64 : Float32
238
- return T (x), T (unthunk (Δx) )
237
+ return T (x), T (Δx )
239
238
end
240
239
function ChainRulesCore. rrule (:: typeof (f_inferrable_pullback_only), x)
241
240
function f_inferrable_pullback_only_pullback (Δy)
@@ -468,8 +467,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
468
467
return s
469
468
end
470
469
471
- function ChainRulesCore. frule ((_, Δiter_), :: typeof (iterfun), iter)
472
- Δiter = unthunk (Δiter_)
470
+ function ChainRulesCore. frule ((_, Δiter), :: typeof (iterfun), iter)
473
471
iter_Δiter = zip (iter, Δiter)
474
472
state = iterate (iter_Δiter)
475
473
state === nothing && error ()
0 commit comments