1
1
# # Pushforward
2
2
3
+ struct EnzymeOneArgPushforwardPrep{SIG,DF,DC} <: DI.PushforwardPrep{SIG}
4
+ _sig:: Val{SIG}
5
+ df:: DF
6
+ context_shadows:: DC
7
+ end
8
+
3
9
function DI. prepare_pushforward_nokwarg (
4
10
strict:: Val ,
5
11
f:: F ,
6
12
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
7
13
x,
8
- tx:: NTuple ,
14
+ tx:: NTuple{B} ,
9
15
contexts:: Vararg{DI.Context,C} ;
10
- ) where {F,C}
16
+ ) where {F,C,B }
11
17
_sig = DI. signature (f, backend, x, tx, contexts... ; strict)
12
- return DI. NoPushforwardPrep (_sig)
18
+ df = function_shadow (f, backend, Val (B))
19
+ mode = forward_withprimal (backend)
20
+ context_shadows = make_context_shadows (backend, mode, Val (B), contexts... )
21
+ return EnzymeOneArgPushforwardPrep (_sig, df, context_shadows)
13
22
end
14
23
15
24
function DI. value_and_pushforward (
16
25
f:: F ,
17
- prep:: DI.NoPushforwardPrep ,
26
+ prep:: EnzymeOneArgPushforwardPrep ,
18
27
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
19
28
x,
20
29
tx:: NTuple{1} ,
21
30
contexts:: Vararg{DI.Context,C} ,
22
31
) where {F,C}
23
32
DI. check_prep (f, prep, backend, x, tx, contexts... )
33
+ (; df, context_shadows) = prep
24
34
mode = forward_withprimal (backend)
25
- f_and_df = get_f_and_df ( f, backend, mode )
35
+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val ( 1 ) )
26
36
dx = only (tx)
27
37
x_and_dx = Duplicated (x, dx)
28
- annotated_contexts = translate (backend, mode , Val (1 ), contexts ... )
38
+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (1 ))
29
39
dy, y = autodiff (mode, f_and_df, x_and_dx, annotated_contexts... )
30
40
return y, (dy,)
31
41
end
32
42
33
43
function DI. value_and_pushforward (
34
44
f:: F ,
35
- prep:: DI.NoPushforwardPrep ,
45
+ prep:: EnzymeOneArgPushforwardPrep ,
36
46
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
37
47
x,
38
48
tx:: NTuple{B} ,
39
49
contexts:: Vararg{DI.Context,C} ,
40
50
) where {F,B,C}
41
51
DI. check_prep (f, prep, backend, x, tx, contexts... )
52
+ (; df, context_shadows) = prep
42
53
mode = forward_withprimal (backend)
43
- f_and_df = get_f_and_df (f, backend, mode , Val (B))
54
+ f_and_df = get_f_and_df_prepared! (df, f, backend , Val (B))
44
55
x_and_tx = BatchDuplicated (x, tx)
45
- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
56
+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
46
57
ty, y = autodiff (mode, f_and_df, x_and_tx, annotated_contexts... )
47
58
return y, values (ty)
48
59
end
49
60
50
61
function DI. pushforward (
51
62
f:: F ,
52
- prep:: DI.NoPushforwardPrep ,
63
+ prep:: EnzymeOneArgPushforwardPrep ,
53
64
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
54
65
x,
55
66
tx:: NTuple{1} ,
56
67
contexts:: Vararg{DI.Context,C} ,
57
68
) where {F,C}
58
69
DI. check_prep (f, prep, backend, x, tx, contexts... )
70
+ (; df, context_shadows) = prep
59
71
mode = forward_noprimal (backend)
60
- f_and_df = get_f_and_df ( f, backend, mode )
72
+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val ( 1 ) )
61
73
dx = only (tx)
62
74
x_and_dx = Duplicated (x, dx)
63
- annotated_contexts = translate (backend, mode , Val (1 ), contexts ... )
75
+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (1 ))
64
76
dy = only (autodiff (mode, f_and_df, x_and_dx, annotated_contexts... ))
65
77
return (dy,)
66
78
end
67
79
68
80
function DI. pushforward (
69
81
f:: F ,
70
- prep:: DI.NoPushforwardPrep ,
82
+ prep:: EnzymeOneArgPushforwardPrep ,
71
83
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
72
84
x,
73
85
tx:: NTuple{B} ,
74
86
contexts:: Vararg{DI.Context,C} ,
75
87
) where {F,B,C}
76
88
DI. check_prep (f, prep, backend, x, tx, contexts... )
89
+ (; df, context_shadows) = prep
77
90
mode = forward_noprimal (backend)
78
- f_and_df = get_f_and_df (f, backend, mode , Val (B))
91
+ f_and_df = get_f_and_df_prepared! (df, f, backend , Val (B))
79
92
x_and_tx = BatchDuplicated (x, tx)
80
- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
93
+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
81
94
ty = only (autodiff (mode, f_and_df, x_and_tx, annotated_contexts... ))
82
95
return values (ty)
83
96
end
84
97
85
98
function DI. value_and_pushforward! (
86
99
f:: F ,
87
100
ty:: NTuple ,
88
- prep:: DI.NoPushforwardPrep ,
101
+ prep:: EnzymeOneArgPushforwardPrep ,
89
102
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
90
103
x,
91
104
tx:: NTuple ,
101
114
function DI. pushforward! (
102
115
f:: F ,
103
116
ty:: NTuple ,
104
- prep:: DI.NoPushforwardPrep ,
117
+ prep:: EnzymeOneArgPushforwardPrep ,
105
118
backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
106
119
x,
107
120
tx:: NTuple ,
@@ -116,10 +129,12 @@ end
116
129
117
130
# # Gradient
118
131
119
- struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
132
+ struct EnzymeForwardGradientPrep{SIG,B,DF,DC, O} <: DI.GradientPrep{SIG}
120
133
_sig:: Val{SIG}
121
134
_valB:: Val{B}
122
- shadows:: O
135
+ df:: DF
136
+ context_shadows:: DC
137
+ basis_shadows:: O
123
138
end
124
139
125
140
function DI. prepare_gradient_nokwarg (
@@ -131,8 +146,11 @@ function DI.prepare_gradient_nokwarg(
131
146
) where {F,C}
132
147
_sig = DI. signature (f, backend, x, contexts... ; strict)
133
148
valB = to_val (DI. pick_batchsize (backend, x))
134
- shadows = create_shadows (valB, x)
135
- return EnzymeForwardGradientPrep (_sig, valB, shadows)
149
+ df = function_shadow (f, backend, valB)
150
+ mode = forward_withprimal (backend)
151
+ context_shadows = make_context_shadows (backend, mode, valB, contexts... )
152
+ basis_shadows = create_shadows (valB, x)
153
+ return EnzymeForwardGradientPrep (_sig, valB, df, context_shadows, basis_shadows)
136
154
end
137
155
138
156
function DI. gradient (
@@ -143,11 +161,12 @@ function DI.gradient(
143
161
contexts:: Vararg{DI.Constant,C} ,
144
162
) where {F,SIG,B,C}
145
163
DI. check_prep (f, prep, backend, x, contexts... )
164
+ (; df, context_shadows, basis_shadows) = prep
146
165
mode = forward_noprimal (backend)
147
- f_and_df = get_f_and_df ( f, backend, mode )
148
- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
166
+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val (B) )
167
+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
149
168
derivs = gradient (
150
- mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep . shadows
169
+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= basis_shadows
151
170
)
152
171
return first (derivs)
153
172
end
@@ -160,11 +179,12 @@ function DI.value_and_gradient(
160
179
contexts:: Vararg{DI.Constant,C} ,
161
180
) where {F,SIG,B,C}
162
181
DI. check_prep (f, prep, backend, x, contexts... )
182
+ (; df, context_shadows, basis_shadows) = prep
163
183
mode = forward_withprimal (backend)
164
- f_and_df = get_f_and_df ( f, backend, mode )
165
- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
184
+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val (B) )
185
+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
166
186
(; derivs, val) = gradient (
167
- mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep . shadows
187
+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= basis_shadows
168
188
)
169
189
return val, first (derivs)
170
190
end
@@ -196,10 +216,12 @@ end
196
216
197
217
# # Jacobian
198
218
199
- struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
219
+ struct EnzymeForwardOneArgJacobianPrep{SIG,B,DF,DC, O} <: DI.JacobianPrep{SIG}
200
220
_sig:: Val{SIG}
201
221
_valB:: Val{B}
202
- shadows:: O
222
+ df:: DF
223
+ context_shadows:: DC
224
+ basis_shadows:: O
203
225
output_length:: Int
204
226
end
205
227
@@ -213,8 +235,13 @@ function DI.prepare_jacobian_nokwarg(
213
235
_sig = DI. signature (f, backend, x, contexts... ; strict)
214
236
y = f (x, map (DI. unwrap, contexts)... )
215
237
valB = to_val (DI. pick_batchsize (backend, x))
216
- shadows = create_shadows (valB, x)
217
- return EnzymeForwardOneArgJacobianPrep (_sig, valB, shadows, length (y))
238
+ mode = forward_withprimal (backend)
239
+ df = function_shadow (f, backend, valB)
240
+ context_shadows = make_context_shadows (backend, mode, valB, contexts... )
241
+ basis_shadows = create_shadows (valB, x)
242
+ return EnzymeForwardOneArgJacobianPrep (
243
+ _sig, valB, df, context_shadows, basis_shadows, length (y)
244
+ )
218
245
end
219
246
220
247
function DI. jacobian (
@@ -225,14 +252,15 @@ function DI.jacobian(
225
252
contexts:: Vararg{DI.Constant,C} ,
226
253
) where {F,SIG,B,C}
227
254
DI. check_prep (f, prep, backend, x, contexts... )
255
+ (; df, context_shadows, basis_shadows, output_length) = prep
228
256
mode = forward_noprimal (backend)
229
- f_and_df = get_f_and_df ( f, backend, mode )
230
- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
257
+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val (B) )
258
+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
231
259
derivs = jacobian (
232
- mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep . shadows
260
+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= basis_shadows
233
261
)
234
262
jac_tensor = first (derivs)
235
- return maybe_reshape (jac_tensor, prep . output_length, length (x))
263
+ return maybe_reshape (jac_tensor, output_length, length (x))
236
264
end
237
265
238
266
function DI. value_and_jacobian (
@@ -243,14 +271,15 @@ function DI.value_and_jacobian(
243
271
contexts:: Vararg{DI.Constant,C} ,
244
272
) where {F,SIG,B,C}
245
273
DI. check_prep (f, prep, backend, x, contexts... )
274
+ (; df, context_shadows, basis_shadows, output_length) = prep
246
275
mode = forward_withprimal (backend)
247
- f_and_df = get_f_and_df ( f, backend, mode )
248
- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
276
+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val (B) )
277
+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
249
278
(; derivs, val) = jacobian (
250
- mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep . shadows
279
+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= basis_shadows
251
280
)
252
281
jac_tensor = first (derivs)
253
- return val, maybe_reshape (jac_tensor, prep . output_length, length (x))
282
+ return val, maybe_reshape (jac_tensor, output_length, length (x))
254
283
end
255
284
256
285
function DI. jacobian! (
0 commit comments