Skip to content

Commit a66f9a5

Browse files
authored
ensure that update without a gradient is an error (#50)
1 parent 07c16ee commit a66f9a5

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

src/interface.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@ end
2323

2424
subtract!(x, x̄) = iswriteable(x) ? (x .= x .- x̄) : (x .- x̄)
2525

26-
update!(::Nothing, x, ::Zero...) = nothing, x
26+
update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x
2727
update!(::Nothing, x, x̄s...) = nothing, x
2828

29-
update!(ℓ::Leaf, x, ::Zero...) = ℓ, x
29+
update!(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x
3030
function update!(ℓ::Leaf, x, x̄s...)
3131
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...)
3232
Leaf(ℓ.rule, s′), subtract!(x, x̄′)
3333
end
3434

35-
update!(tree, x, ::Zero...) = tree, x
35+
update!(tree, x, ::Zero, ::Zero...) = tree, x
3636
function update!(tree, x, x̄s...)
3737
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
3838
x′, re = functor(typeof(x), x)

test/runtests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,16 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
116116
@test Optimisers.update!(s, m, g...)[2] isa Foo
117117
end
118118

119+
@testset "forgotten gradient" begin
120+
x = [1.0, 2.0]
121+
sx = Optimisers.setup(Descent(), x)
122+
@test_throws MethodError Optimisers.update(sx, x)
123+
124+
m = (x = x, y = sin)
125+
sm = Optimisers.setup(Descent(), m)
126+
@test_throws MethodError Optimisers.update(sm, m)
127+
end
128+
119129
@testset "broadcasting macros" begin
120130
x = [1.0, 2.0]; y = [3,4]; z = [5,6]
121131
@test (@lazy x + y * z) isa Broadcast.Broadcasted

0 commit comments

Comments
 (0)