@@ -33,12 +33,12 @@ Base.iterate(f::Foo, state) = iterate(f.a, state)
33
33
function ChainRulesCore. rrule (:: Type{Foo} , a)
34
34
foo = Foo (a)
35
35
function Foo_pullback (Δfoo)
36
- return NoTangent (), Δfoo. a
36
+ return NoTangent (), unthunk ( Δfoo) . a
37
37
end
38
38
return foo, Foo_pullback
39
39
end
40
40
function ChainRulesCore. frule ((_, Δa), :: Type{Foo} , a)
41
- return Foo (a), Foo (Δa )
41
+ return Foo (a), Foo (unthunk (Δa) )
42
42
end
43
43
44
44
# functor
@@ -49,7 +49,8 @@ function ChainRulesCore.rrule(f::Foo, x)
49
49
end
50
50
return y, Foo_pullback
51
51
end
52
- function ChainRulesCore. frule ((Δf, Δx), f:: Foo , x)
52
+ function ChainRulesCore. frule ((Δf_, Δx), f:: Foo , x)
53
+ Δf = unthunk (Δf_)
53
54
return f (x), Δf. a + Δx
54
55
end
55
56
@@ -158,7 +159,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
158
159
159
160
@testset " check not inferred in frule" begin
160
161
function ChainRulesCore. frule ((_, Δx), :: typeof (f_noninferrable_frule), x)
161
- return (x, x > 0 ? Float64 (Δx) : Float32 (Δx ))
162
+ return (x, x > 0 ? Float64 (unthunk ( Δx)) : Float32 (unthunk (Δx) ))
162
163
end
163
164
function ChainRulesCore. rrule (:: typeof (f_noninferrable_frule), x)
164
165
f_noninferrable_frule_pullback (Δy) = (NoTangent (), Δy)
@@ -205,7 +206,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
205
206
@testset " check not inferred in pullback" begin
206
207
function ChainRulesCore. rrule (:: typeof (f_noninferrable_pullback), x)
207
208
function f_noninferrable_pullback_pullback (Δy)
208
- return (NoTangent (), x > 0 ? Float64 (Δy) : Float32 (Δy ))
209
+ return (NoTangent (), ( x > 0 ? Float64 : Float32)( unthunk (Δy) ))
209
210
end
210
211
return x, f_noninferrable_pullback_pullback
211
212
end
@@ -219,7 +220,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
219
220
@testset " check not inferred in thunk" begin
220
221
function ChainRulesCore. rrule (:: typeof (f_noninferrable_thunk), x, y)
221
222
function f_noninferrable_thunk_pullback (Δz)
222
- ∂x = @thunk (x > 0 ? Float64 (Δz) : Float32 (Δz ))
223
+ ∂x = @thunk (x > 0 ? Float64 (unthunk ( Δz)) : Float32 (unthunk (Δz) ))
223
224
return (NoTangent (), ∂x, Δz)
224
225
end
225
226
return x + y, f_noninferrable_thunk_pullback
@@ -233,10 +234,13 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
233
234
234
235
@testset " check non-inferrable primal still passes if pullback inferrable" begin
235
236
function ChainRulesCore. frule ((_, Δx), :: typeof (f_inferrable_pullback_only), x)
236
- return (x > 0 ? Float64 (x) : Float32 (x), x > 0 ? Float64 (Δx) : Float32 (Δx))
237
+ T = x > 0 ? Float64 : Float32
238
+ return T (x), T (unthunk (Δx))
237
239
end
238
240
function ChainRulesCore. rrule (:: typeof (f_inferrable_pullback_only), x)
239
- f_inferrable_pullback_only_pullback (Δy) = (NoTangent (), oftype (x, Δy))
241
+ function f_inferrable_pullback_only_pullback (Δy)
242
+ return NoTangent (), oftype (x, unthunk (Δy))
243
+ end
240
244
return x > 0 ? Float64 (x) : Float32 (x), f_inferrable_pullback_only_pullback
241
245
end
242
246
test_frule (f_inferrable_pullback_only, 2.0 ; check_inferred= true )
@@ -441,7 +445,9 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
441
445
ẋ = [4.0 , 5.0 , 6.0 ]
442
446
xcopy, ẋcopy = copy (x), copy (ẋ)
443
447
y = [1 , 2 ]
444
- test_frule (finplace!, x ⊢ ẋ; fkwargs= (y= y,))
448
+ # Don't test tangent transforms, we do not support thunks for mutating frules
449
+ # TODO : Should we disable testing thunks for frules in general
450
+ test_frule (finplace!, x ⊢ ẋ; fkwargs= (y= y,), tangent_transforms= [])
445
451
@test x == xcopy
446
452
@test ẋ == ẋcopy
447
453
@test y == [1 , 2 ]
@@ -462,7 +468,8 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
462
468
return s
463
469
end
464
470
465
- function ChainRulesCore. frule ((_, Δiter), :: typeof (iterfun), iter)
471
+ function ChainRulesCore. frule ((_, Δiter_), :: typeof (iterfun), iter)
472
+ Δiter = unthunk (Δiter_)
466
473
iter_Δiter = zip (iter, Δiter)
467
474
state = iterate (iter_Δiter)
468
475
state === nothing && error ()
0 commit comments