
Commit e6f7b9e

add train_autodiff macro
1 parent 7aa74e0 commit e6f7b9e

File tree

1 file changed: +43 -1 lines changed

src/train/Train.jl

Lines changed: 43 additions & 1 deletion
@@ -4,7 +4,7 @@ using LinearAlgebra
 using Optimisers: Optimisers
 using Functors: fmap
 
-export train!, update!, adjust!, FluxState,
+export train!, update!, adjust!, FluxState, @train_autodiff,
   Descent, Adam, Momentum, Nesterov, RMSProp,
   AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief #,
   # InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
@@ -108,6 +108,48 @@ include("implicit_train.jl") # Params etc, Zygote only
 
 explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor
 
+"""
+    @train_autodiff Diffractor
+    @train_autodiff Yota
+    @train_autodiff Zygote
+
+This macro allows the use of `train!` with various automatic differentiation packages,
+instead of the default Zygote.jl. You should load the package, then call this macro.
+
+Only affects "explicit-mode" versions `train!(loss, model, data, opt)` or `train!(loss, model, opt)`,
+since the (deprecated) "implicit-mode" `train!(loss, ps::Params, data, opt)` is Zygote-specific.
+
+!!! note
+    Experimental
+"""
+macro train_autodiff(pkg)
+  if pkg == :Diffractor
+    return quote
+      Diffractor.gradient(sin, 0.0)[1] ≈ 1.0  # ensures an error if not loaded
+      function Flux.Train.explicit_withgradient(f, args...)
+        y, back = Diffractor.∂⃖¹(f, args...)
+        dy1 = Flux.Zygote.sensitivity(y)  # Zygote is loaded, and this gives nice errors
+        return (; value = y, gradient = Base.tail(back(dy1)))
+      end
+    end |> esc
+  elseif pkg == :Yota
+    return quote
+      Yota.grad(sin, 0.0)  # [2][1] ≈ 1.0
+      function Flux.Train.explicit_withgradient(f, args...)
+        value, (_, gradient...) = Yota.grad(f, args...)
+        return (; value, gradient)
+      end
+    end |> esc
+  elseif pkg == :Zygote
+    return quote
+      Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...)
+    end |> esc
+  else
+    throw("@train_autodiff expects either Zygote, Yota, or Diffractor. No other arguments are understood.")
+  end
+end
+
+
 ### Misc. related utilities
 
 """
