Skip to content

Commit 4c38c8a

Browse files
authored
Stop training on Inf/NaN loss (#2070)
* stop training on Inf/NaN loss * add a test * improve test * improve test * Update train.jl * Update optimise.jl
1 parent 090f043 commit 4c38c8a

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

src/optimise/train.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using ProgressLogging: @progress, @withprogress, @logprogress
2-
import Zygote: Params, gradient
2+
import Zygote: Params, gradient, withgradient
33

44

55
"""
@@ -105,8 +105,10 @@ The optimiser should be from the `Flux.Optimise` module (see [Optimisers](@ref))
105105
Different optimisers can be combined using [`Flux.Optimise.Optimiser`](@ref Flux.Optimiser).
106106
107107
This training loop iterates through `data` once.
108+
It will stop with a `DomainError` if the loss is `NaN` or infinite.
109+
108110
You can use [`@epochs`](@ref) to do this several times, or
109-
use for instance `Iterators.repeat` to make a longer `data` iterator.
111+
use for instance `IterTools.ncycle` to make a longer `data` iterator.
110112
111113
## Callbacks
112114
@@ -126,9 +128,12 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
126128
n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
127129
@withprogress for (i, d) in enumerate(data)
128130
try
129-
gs = gradient(ps) do
131+
l, gs = withgradient(ps) do
130132
loss(batchmemaybe(d)...)
131133
end
134+
if !isfinite(l)
135+
throw(DomainError("Loss is $l on data item $i, stopping training"))
136+
end
132137
update!(opt, ps, gs)
133138
cb()
134139
catch ex

test/optimise.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@ end
8787
Flux.train!(loss, Flux.params(r), (r,), Descent())
8888
end
8989

90+
@testset "Stop on NaN" begin
91+
m = Dense(1 => 1)
92+
m.weight .= 0
93+
CNT = 0
94+
@test_throws DomainError Flux.train!(Flux.params(m), 1:100, Descent(0.1)) do i
95+
CNT += 1
96+
(i == 51 ? NaN32 : 1f0) * sum(m([1.0]))
97+
end
98+
@test CNT == 51 # stopped early
99+
@test m.weight[1] ≈ -5  # did not corrupt weights
100+
end
101+
90102
@testset "ExpDecay" begin
91103

92104
@testset "Sanity Check" begin

0 commit comments

Comments (0)