|
84 | 84 | y1 = m(x) # before
|
85 | 85 |
|
86 | 86 | # Implicit gradient
|
87 |
| - gold = gradient(() -> m(x), Flux.params(m)) |
| 87 | + gold = Zygote.gradient(() -> m(x), Flux.params(m)) |
88 | 88 | @test gold isa Flux.Zygote.Grads
|
89 | 89 | @test_throws ErrorException Flux.update!(Flux.Adam(), m, gold) # friendly
|
90 | 90 | Flux.update!(Flux.Adam(), Flux.params(m), gold)
|
91 | 91 | y2 = m(x)
|
92 | 92 | @test y2 < y1
|
93 | 93 |
|
94 | 94 | # Explicit gradient
|
95 |
| - gs = gradient(marg -> marg(x), m) |
| 95 | + gs = Zygote.gradient(marg -> marg(x), m) |
96 | 96 | @test gs isa Tuple
|
97 | 97 | @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs) # friendly
|
98 | 98 | @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs[1]) # friendly
|
@@ -133,17 +133,20 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
|
133 | 133 | diff1 = model.weight .- init_weight
|
134 | 134 |
|
135 | 135 | # Take 2: the same, but with Flux.params. Was broken for a bit, no tests!
|
136 |
| - model.weight .= init_weight |
137 |
| - model.bias .= 0 |
138 |
| - pen2(x::AbstractArray) = sum(abs2, x)/2 |
139 |
| - opt = Flux.setup(Adam(0.1), model) |
140 |
| - trainfn!(model, data, opt) do m, x, y |
141 |
| - err = Flux.mse(m(x), y) |
142 |
| - l2 = sum(pen2, Flux.params(m)) |
143 |
| - err + 0.33 * l2 |
| 136 | + # skipping this test for Enzyme cause implicit params is unsupported |
| 137 | + if name == "Zygote" |
| 138 | + model.weight .= init_weight |
| 139 | + model.bias .= 0 |
| 140 | + pen2(x::AbstractArray) = sum(abs2, x)/2 |
| 141 | + opt = Flux.setup(Adam(0.1), model) |
| 142 | + trainfn!(model, data, opt) do m, x, y |
| 143 | + err = Flux.mse(m(x), y) |
| 144 | + l2 = sum(pen2, Flux.params(m)) |
| 145 | + err + 0.33 * l2 |
| 146 | + end |
| 147 | + diff2 = model.weight .- init_weight |
| 148 | + @test diff1 ≈ diff2 |
144 | 149 | end
|
145 |
| - diff2 = model.weight .- init_weight |
146 |
| - @test diff1 ≈ diff2 |
147 | 150 |
|
148 | 151 | # Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
|
149 | 152 | model.weight .= init_weight
|
|
0 commit comments