Skip to content

Commit 98394e1

Browse files
committed
divide penalty by n_batches
1 parent 61c3801 commit 98394e1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/core.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function train!(loss, penalty, chain, optimiser, X, y)
3636
parameters = Flux.params(chain)
3737
gs = Flux.gradient(parameters) do
3838
yhat = chain(X[i])
39-
batch_loss = loss(yhat, y[i]) + penalty(parameters)
39+
batch_loss = loss(yhat, y[i]) + penalty(parameters)/n_batches
4040
training_loss += batch_loss
4141
return batch_loss
4242
end
@@ -96,7 +96,7 @@ function fit!(loss, penalty, chain, optimiser, epochs, verbosity, X, y)
9696

9797
parameters = Flux.params(chain)
9898
losses = (loss(chain(X[i]), y[i]) +
99-
penalty(parameters) for i in 1:n_batches)
99+
penalty(parameters)/n_batches for i in 1:n_batches)
100100
history = [mean(losses),]
101101

102102
for i in 1:epochs

0 commit comments

Comments
 (0)