Commit 9c22c11

remove train_autodiff macro
1 parent c20fb9e commit 9c22c11

2 files changed: +0 -132 lines changed


src/train.jl

Lines changed: 0 additions & 56 deletions
@@ -90,8 +90,6 @@ The built-in loss functions accept 3 arguments, allowing for instance `train!(Fl
 
 Callback functions are not supported. But see 3-argument `train!(loss, model, opt)` for an
 easy way to construct more complicated training loops.
-
-To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref).
 """
 function train!(loss, model, data, opt)
   losses = Float32[]
@@ -144,8 +142,6 @@ for (i, d) in enumerate(data)
 end
 ```
 
-To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref).
-
 !!! note
     This method has no implicit `Params` analog in Flux ≤ 0.13.
 """
@@ -178,56 +174,4 @@ end
 
 explicit_withgradient(f, args...) = Zygote.withgradient(f, args...)  # can overload this to use e.g. Yota / Diffractor
 
-"""
-    Flux.@train_autodiff Tracker
-    Flux.@train_autodiff Zygote
-    Flux.@train_autodiff Yota
-    Flux.@train_autodiff Diffractor
-
-This macro allows the use of `train!` with various automatic differentiation (AD) packages,
-instead of the default Zygote.jl.
-
-You should load the AD package, then call this macro with the chosen name.
-The macro overwrites a method within Flux, so it is a global setting, lasting until you restart Julia.
-
-Only works with [Yota.jl](https://github.com/dfdx/Yota.jl),
-[Tracker.jl](https://github.com/FluxML/Tracker.jl) (Flux's old AD),
-[Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (which is not yet registered),
-and with the default [Zygote.jl](https://github.com/FluxML/Zygote.jl).
-
-!!! note
-    This mechanism is experimental! There are known bugs; in particular, Tracker will not automatically switch to training mode for `Dropout` etc.
-"""
-macro train_autodiff(pkg)
-  if pkg == :Diffractor
-    return quote
-      Diffractor.gradient(sin, 0.0)[1] ≈ 1.0  # ensures an error if not loaded
-      function Flux.Train.explicit_withgradient(f, args...)
-        y, back = Diffractor.∂⃖¹(f, args...)
-        dy1 = Flux.Zygote.sensitivity(y)  # Zygote is loaded, and this gives nice errors
-        return (; value = y, gradient = Base.tail(back(dy1)))
-      end
-    end |> esc
-  elseif pkg == :Yota
-    return quote
-      Yota.grad(sin, 0.0)  # [2][1] ≈ 1.0
-      function Flux.Train.explicit_withgradient(f, args...)
-        value, (_, gradient...) = Yota.grad(f, args...)
-        return (; value, gradient)
-      end
-    end |> esc
-  elseif pkg == :Tracker
-    return quote
-      Tracker.withgradient(sum, [1.0]).val == 1.0  # ensures an error if too-old version
-      Flux.Train.explicit_withgradient(f, args...) = Tracker.withgradient(f, args...)
-    end |> esc
-  elseif pkg == :Zygote
-    return quote
-      Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...)
-    end |> esc
-  else
-    throw("@train_autodiff expects one of Tracker, Zygote, Yota, or Diffractor. No other arguments are understood.")
-  end
-end
-
 end # module
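
The overload hook named in the surviving comment on `explicit_withgradient` still works after this removal. As a minimal sketch (not part of this commit), the AD backend can be swapped by hand with the same one-line definition the deleted `@train_autodiff Tracker` branch generated; this assumes Tracker.jl's `withgradient` is available, as the deleted code itself did:

```julia
import Flux, Tracker

# Global override, in effect until Julia is restarted: route the gradient call
# inside Flux.train! through Tracker instead of the default Zygote.
# Identical to the definition the removed macro expanded to.
Flux.Train.explicit_withgradient(f, args...) = Tracker.withgradient(f, args...)
```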

test/train.jl

Lines changed: 0 additions & 76 deletions
@@ -53,79 +53,3 @@ end
 # Test NaN / Inf early stop
 # Test that loss is returned
 end
-
-import Tracker
-Flux.@train_autodiff Tracker
-
-@testset "Explicit Flux.train! with Tracker" begin
-  Random.seed!(84)
-  w = randn(10, 10)
-  w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
-  @testset for rule in [Descent(0.1), Adam(), AdamW()]
-
-    loss(m, x) = begin
-      Flux.istraining() && error("This test is not in fact using Tracker!")
-      Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-    end
-    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
-    @test loss(model, rand(10, 10)) > 1
-
-    opt = Flux.setup(rule, model)
-    Flux.train!(loss, model, ((rand(10),) for _ in 1:10^5), opt)
-    @test loss(model, rand(10, 10)) < 0.01
-  end
-
-  # Test 3-arg `Flux.train!` method:
-  @testset for rule in [Descent(0.1), Adam()]
-
-    loss(m) = let x = rand(10)
-      Flux.istraining() && error("This test is not in fact using Tracker!")
-      Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-    end
-    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
-    @test loss(model) > 1
-
-    opt = Flux.setup(rule, model)
-    for i in 1:10^5
-      Flux.train!(loss, model, opt)
-    end
-    @test loss(model) < 0.01
-  end
-end
-
-import Yota
-Flux.@train_autodiff Yota
-
-@testset "Explicit Flux.train! with Yota" begin
-  Random.seed!(84)
-  w = randn(10, 10)
-  w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
-  @testset for rule in [Descent(0.1), Adam(), AdamW()]
-
-    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
-    @test loss(model, rand(10, 10)) > 1
-
-    opt = Flux.setup(rule, model)
-    Flux.train!(loss, model, ((rand(10),) for _ in 1:10^5), opt)
-    @test loss(model, rand(10, 10)) < 0.01
-  end
-
-  # Test 3-arg `Flux.train!` method:
-  @testset for rule in [Descent(0.1), Adam()]
-
-    loss(m) = let x = rand(10)
-      Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-    end
-    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
-    @test loss(model) > 1
-
-    opt = Flux.setup(rule, model)
-    for i in 1:10^5
-      Flux.train!(loss, model, opt)
-    end
-    @test loss(model) < 0.01
-  end
-end
-
-Flux.@train_autodiff Zygote
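
For orientation, this is a minimal sketch of the Zygote-backed workflow these deleted tests exercised, which remains after the commit. It is not part of the change; the model shape, optimiser and step count are illustrative, borrowed loosely from the tests above:

```julia
using Flux, Random

Random.seed!(84)
w = randn(10, 10)                                   # fixed "target" transformation
model = (weight = randn(10, 10), bias = zeros(10))  # any Optimisers-compatible structure works

# Loss takes the model first, then one datum, matching train!(loss, model, data, opt).
loss(m, x) = Flux.Losses.mse(w * x, m.weight * x .+ m.bias)

opt = Flux.setup(Descent(0.1), model)               # explicit optimiser state, no implicit Params
Flux.train!(loss, model, ((rand(10),) for _ in 1:10^4), opt)

loss(model, rand(10, 10))                           # should now be far below its starting value
```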
