Skip to content

Commit 8b8eebc

Browse files
committed
Switch to update! only
1 parent 6a284ba commit 8b8eebc

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

src/optimise/Optimise.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import Zygote
66
import Zygote: Params, gradient
77
using AbstractDifferentiation
88
import Optimisers
9-
import Optimisers: update, update!
9+
import Optimisers: update!
1010
using LinearAlgebra
1111
import ArrayInterface
1212
using ProgressLogging: @progress, @withprogress, @logprogress

src/optimise/train.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ function Optimisers.update!(opt::AbstractOptimiser, xs::Params, gs)
2424

2525
return opt, xs
2626
end
27-
Optimisers.update(opt::AbstractOptimiser, xs::Params, gs) = update!(opt, xs, gs)
2827

2928
# Callback niceties
3029
call(f, xs...) = f(xs...)
@@ -139,7 +138,7 @@ function train!(loss, ad::AD.AbstractBackend, model, data, optstate; cb = () ->
139138
try
140139
_loss = _build_loss(ad, loss, batchmemaybe(d))
141140
gs = _gradient_only(AD.gradient(ad, _loss, model))
142-
optstate, model = update(optstate, model, gs)
141+
optstate, model = update!(optstate, model, gs)
143142
cb()
144143
catch ex
145144
if ex isa StopException

0 commit comments

Comments
 (0)