Skip to content

Commit e3bb529

Browse files
committed
add missing unthunks to tests
1 parent 2df22ea commit e3bb529

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

test/testers.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ Base.iterate(f::Foo, state) = iterate(f.a, state)
3333
function ChainRulesCore.rrule(::Type{Foo}, a)
3434
foo = Foo(a)
3535
function Foo_pullback(Δfoo)
36-
return NoTangent(), Δfoo.a
36+
return NoTangent(), unthunk(Δfoo).a
3737
end
3838
return foo, Foo_pullback
3939
end
4040
function ChainRulesCore.frule((_, Δa), ::Type{Foo}, a)
41-
return Foo(a), Foo(Δa)
41+
return Foo(a), Foo(unthunk(Δa))
4242
end
4343

4444
# functor
@@ -49,7 +49,8 @@ function ChainRulesCore.rrule(f::Foo, x)
4949
end
5050
return y, Foo_pullback
5151
end
52-
function ChainRulesCore.frule((Δf, Δx), f::Foo, x)
52+
function ChainRulesCore.frule((Δf_, Δx), f::Foo, x)
53+
Δf = unthunk(Δf_)
5354
return f(x), Δf.a + Δx
5455
end
5556

@@ -158,7 +159,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
158159

159160
@testset "check not inferred in frule" begin
160161
function ChainRulesCore.frule((_, Δx), ::typeof(f_noninferrable_frule), x)
161-
return (x, x > 0 ? Float64(Δx) : Float32(Δx))
162+
return (x, x > 0 ? Float64(unthunk(Δx)) : Float32(unthunk(Δx)))
162163
end
163164
function ChainRulesCore.rrule(::typeof(f_noninferrable_frule), x)
164165
f_noninferrable_frule_pullback(Δy) = (NoTangent(), Δy)
@@ -205,7 +206,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
205206
@testset "check not inferred in pullback" begin
206207
function ChainRulesCore.rrule(::typeof(f_noninferrable_pullback), x)
207208
function f_noninferrable_pullback_pullback(Δy)
208-
return (NoTangent(), x > 0 ? Float64(Δy) : Float32(Δy))
209+
return (NoTangent(), (x > 0 ? Float64 : Float32)(unthunk(Δy)))
209210
end
210211
return x, f_noninferrable_pullback_pullback
211212
end
@@ -219,7 +220,7 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
219220
@testset "check not inferred in thunk" begin
220221
function ChainRulesCore.rrule(::typeof(f_noninferrable_thunk), x, y)
221222
function f_noninferrable_thunk_pullback(Δz)
222-
∂x = @thunk(x > 0 ? Float64(Δz) : Float32(Δz))
223+
∂x = @thunk(x > 0 ? Float64(unthunk(Δz)) : Float32(unthunk(Δz)))
223224
return (NoTangent(), ∂x, Δz)
224225
end
225226
return x + y, f_noninferrable_thunk_pullback
@@ -233,10 +234,13 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
233234

234235
@testset "check non-inferrable primal still passes if pullback inferrable" begin
235236
function ChainRulesCore.frule((_, Δx), ::typeof(f_inferrable_pullback_only), x)
236-
return (x > 0 ? Float64(x) : Float32(x), x > 0 ? Float64(Δx) : Float32(Δx))
237+
T = x > 0 ? Float64 : Float32
238+
return T(x), T(unthunk(Δx))
237239
end
238240
function ChainRulesCore.rrule(::typeof(f_inferrable_pullback_only), x)
239-
f_inferrable_pullback_only_pullback(Δy) = (NoTangent(), oftype(x, Δy))
241+
function f_inferrable_pullback_only_pullback(Δy)
242+
return NoTangent(), oftype(x, unthunk(Δy))
243+
end
240244
return x > 0 ? Float64(x) : Float32(x), f_inferrable_pullback_only_pullback
241245
end
242246
test_frule(f_inferrable_pullback_only, 2.0; check_inferred=true)
@@ -441,7 +445,9 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
441445
= [4.0, 5.0, 6.0]
442446
xcopy, ẋcopy = copy(x), copy(ẋ)
443447
y = [1, 2]
444-
test_frule(finplace!, x ẋ; fkwargs=(y=y,))
448+
# Don't test tangent transforms, we do not support thunks for mutating frules
449+
# TODO: Should we disable testing thunks for frules in general
450+
test_frule(finplace!, x ẋ; fkwargs=(y=y,), tangent_transforms=[])
445451
@test x == xcopy
446452
@test== ẋcopy
447453
@test y == [1, 2]
@@ -462,7 +468,8 @@ struct MySpecialConfig <: RuleConfig{Union{MySpecialTrait}} end
462468
return s
463469
end
464470

465-
function ChainRulesCore.frule((_, Δiter), ::typeof(iterfun), iter)
471+
function ChainRulesCore.frule((_, Δiter_), ::typeof(iterfun), iter)
472+
Δiter = unthunk(Δiter_)
466473
iter_Δiter = zip(iter, Δiter)
467474
state = iterate(iter_Δiter)
468475
state === nothing && error()

0 commit comments

Comments
 (0)