@@ -4,7 +4,7 @@ using LinearAlgebra
4
4
using Optimisers: Optimisers
5
5
using Functors: fmap
6
6
7
- export train!, update!, adjust!, FluxState,
7
+ export train!, update!, adjust!, FluxState, @train_autodiff ,
8
8
Descent, Adam, Momentum, Nesterov, RMSProp,
9
9
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief # ,
10
10
# InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
@@ -108,6 +108,48 @@ include("implicit_train.jl") # Params etc, Zygote only
108
108
109
109
explicit_withgradient (f, args... ) = Zygote. withgradient (f, args... ) # can overload this to use e.g. Yota / Diffractor
110
110
111
+ """
112
+ @train_autodiff Diffractor
113
+ @train_autodiff Yota
114
+ @train_autodiff Zygote
115
+
116
+ This macro allows the use of `train!` with various automatic differentiation packages,
117
+ instead of the default Zygote.jl. You should load the package, then call this macro.
118
+
119
+ Only affects "explicit-mode" versions `train!(loss, model, data, opt)` or `train!(loss, model, opt)`,
120
+ since the (deprecated) "implicit-mode" `train!(loss, ps::Params, data, opt)` is Zygote-specific.
121
+
122
+ !!! note
123
+ Experimental
124
+ """
125
+ macro train_autodiff (pkg)
126
+ if pkg == :Diffractor
127
+ return quote
128
+ Diffractor. gradient (sin, 0.0 )[1 ] ≈ 1.0 # ensures an error if not loaded
129
+ function Flux. Train. explicit_withgradient (f, args... )
130
+ y, back = Diffractor.∂⃖¹ (f, args... )
131
+ dy1 = Flux. Zygote. sensitivity (y) # Zygote is loaded, and this gives nice errors
132
+ return (; value = y, gradient = Base. tail (back (dy1)))
133
+ end
134
+ end |> esc
135
+ elseif pkg == :Yota
136
+ return quote
137
+ Yota. grad (sin, 0.0 ) # [2][1] ≈ 1.0
138
+ function Flux. Train. explicit_withgradient (f, args... )
139
+ value, (_, gradient... ) = Yota. grad (f, args... )
140
+ return (; value, gradient)
141
+ end
142
+ end |> esc
143
+ elseif pkg == :Zygote
144
+ return quote
145
+ Flux. Train. explicit_withgradient (f, args... ) = Flux. Zygote. withgradient (f, args... )
146
+ end |> esc
147
+ else
148
+ throw (" @train_autodiff expects either Zygote, Yota, or Diffractor. No other arguments are understood." )
149
+ end
150
+ end
151
+
152
+
111
153
# ## Misc. related utilities
112
154
113
155
"""
0 commit comments