Skip to content

Commit f4e300c

Browse files
committed
kwargs
1 parent b7654bd commit f4e300c

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

test/train.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ using Test
66
using Random
77
using Enzyme
88

9-
function train_enzyme!(fn, model, args...)
10-
Flux.train!(fn, Duplicated(model, Enzyme.make_zero(model)), args...)
9+
function train_enzyme!(fn, model, args...; kwargs...)
10+
Flux.train!(fn, Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
1111
end
1212

1313
for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
@@ -47,13 +47,17 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
4747
@testset "Stop on NaN" begin
4848
m1 = Dense(1 => 1)
4949
m1.weight .= 0
50-
CNT = 0
50+
CNT = Ref(0)
5151
@test_throws DomainError trainfn!(m1, tuple.(1:100), Descent(0.1)) do m, i
52-
CNT += 1
52+
CNT[] += 1
5353
(i == 51 ? NaN32 : 1f0) * sum(m([1.0]))
5454
end
55-
@test CNT == 51 # stopped early
56-
@test m1.weight[1] -5 # did not corrupt weights
55+
@test CNT[] == 51 # stopped early
56+
if name != "Enzyme"
57+
@test m1.weight[1] 0.0 # did not corrupt weights
58+
else
59+
@test m1.weight[1] -5 # did not corrupt weights
60+
end
5761
end
5862

5963
@testset "non-tuple data" begin

0 commit comments

Comments
 (0)