Skip to content

Commit bb50d0b

Browse files
authored
perf: allocate Enzyme shadow memory during preparation (#782)
* perf: allocate Enzyme shadow memory during preparation
* Fix
* Fix
* Fixes
* Allow JuliaFormatter 2
* Code coverage
* Re-add matrix tests
* Add finer tests and comments
1 parent bdddafb commit bb50d0b

File tree

7 files changed: +341 additions, -185 deletions (lines changed)

7 files changed

+341
-185
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.52"
4+
version = "0.6.53"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 68 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,104 @@
11
## Pushforward
22

3+
struct EnzymeOneArgPushforwardPrep{SIG,DF,DC} <: DI.PushforwardPrep{SIG}
4+
_sig::Val{SIG}
5+
df::DF
6+
context_shadows::DC
7+
end
8+
39
function DI.prepare_pushforward_nokwarg(
410
strict::Val,
511
f::F,
612
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
713
x,
8-
tx::NTuple,
14+
tx::NTuple{B},
915
contexts::Vararg{DI.Context,C};
10-
) where {F,C}
16+
) where {F,C,B}
1117
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
12-
return DI.NoPushforwardPrep(_sig)
18+
df = function_shadow(f, backend, Val(B))
19+
mode = forward_withprimal(backend)
20+
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
21+
return EnzymeOneArgPushforwardPrep(_sig, df, context_shadows)
1322
end
1423

1524
function DI.value_and_pushforward(
1625
f::F,
17-
prep::DI.NoPushforwardPrep,
26+
prep::EnzymeOneArgPushforwardPrep,
1827
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
1928
x,
2029
tx::NTuple{1},
2130
contexts::Vararg{DI.Context,C},
2231
) where {F,C}
2332
DI.check_prep(f, prep, backend, x, tx, contexts...)
33+
(; df, context_shadows) = prep
2434
mode = forward_withprimal(backend)
25-
f_and_df = get_f_and_df(f, backend, mode)
35+
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1))
2636
dx = only(tx)
2737
x_and_dx = Duplicated(x, dx)
28-
annotated_contexts = translate(backend, mode, Val(1), contexts...)
38+
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1))
2939
dy, y = autodiff(mode, f_and_df, x_and_dx, annotated_contexts...)
3040
return y, (dy,)
3141
end
3242

3343
function DI.value_and_pushforward(
3444
f::F,
35-
prep::DI.NoPushforwardPrep,
45+
prep::EnzymeOneArgPushforwardPrep,
3646
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
3747
x,
3848
tx::NTuple{B},
3949
contexts::Vararg{DI.Context,C},
4050
) where {F,B,C}
4151
DI.check_prep(f, prep, backend, x, tx, contexts...)
52+
(; df, context_shadows) = prep
4253
mode = forward_withprimal(backend)
43-
f_and_df = get_f_and_df(f, backend, mode, Val(B))
54+
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
4455
x_and_tx = BatchDuplicated(x, tx)
45-
annotated_contexts = translate(backend, mode, Val(B), contexts...)
56+
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
4657
ty, y = autodiff(mode, f_and_df, x_and_tx, annotated_contexts...)
4758
return y, values(ty)
4859
end
4960

5061
function DI.pushforward(
5162
f::F,
52-
prep::DI.NoPushforwardPrep,
63+
prep::EnzymeOneArgPushforwardPrep,
5364
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
5465
x,
5566
tx::NTuple{1},
5667
contexts::Vararg{DI.Context,C},
5768
) where {F,C}
5869
DI.check_prep(f, prep, backend, x, tx, contexts...)
70+
(; df, context_shadows) = prep
5971
mode = forward_noprimal(backend)
60-
f_and_df = get_f_and_df(f, backend, mode)
72+
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(1))
6173
dx = only(tx)
6274
x_and_dx = Duplicated(x, dx)
63-
annotated_contexts = translate(backend, mode, Val(1), contexts...)
75+
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1))
6476
dy = only(autodiff(mode, f_and_df, x_and_dx, annotated_contexts...))
6577
return (dy,)
6678
end
6779

6880
function DI.pushforward(
6981
f::F,
70-
prep::DI.NoPushforwardPrep,
82+
prep::EnzymeOneArgPushforwardPrep,
7183
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
7284
x,
7385
tx::NTuple{B},
7486
contexts::Vararg{DI.Context,C},
7587
) where {F,B,C}
7688
DI.check_prep(f, prep, backend, x, tx, contexts...)
89+
(; df, context_shadows) = prep
7790
mode = forward_noprimal(backend)
78-
f_and_df = get_f_and_df(f, backend, mode, Val(B))
91+
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
7992
x_and_tx = BatchDuplicated(x, tx)
80-
annotated_contexts = translate(backend, mode, Val(B), contexts...)
93+
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
8194
ty = only(autodiff(mode, f_and_df, x_and_tx, annotated_contexts...))
8295
return values(ty)
8396
end
8497

8598
function DI.value_and_pushforward!(
8699
f::F,
87100
ty::NTuple,
88-
prep::DI.NoPushforwardPrep,
101+
prep::EnzymeOneArgPushforwardPrep,
89102
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
90103
x,
91104
tx::NTuple,
@@ -101,7 +114,7 @@ end
101114
function DI.pushforward!(
102115
f::F,
103116
ty::NTuple,
104-
prep::DI.NoPushforwardPrep,
117+
prep::EnzymeOneArgPushforwardPrep,
105118
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
106119
x,
107120
tx::NTuple,
@@ -116,10 +129,12 @@ end
116129

117130
## Gradient
118131

119-
struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
132+
struct EnzymeForwardGradientPrep{SIG,B,DF,DC,O} <: DI.GradientPrep{SIG}
120133
_sig::Val{SIG}
121134
_valB::Val{B}
122-
shadows::O
135+
df::DF
136+
context_shadows::DC
137+
basis_shadows::O
123138
end
124139

125140
function DI.prepare_gradient_nokwarg(
@@ -131,8 +146,11 @@ function DI.prepare_gradient_nokwarg(
131146
) where {F,C}
132147
_sig = DI.signature(f, backend, x, contexts...; strict)
133148
valB = to_val(DI.pick_batchsize(backend, x))
134-
shadows = create_shadows(valB, x)
135-
return EnzymeForwardGradientPrep(_sig, valB, shadows)
149+
df = function_shadow(f, backend, valB)
150+
mode = forward_withprimal(backend)
151+
context_shadows = make_context_shadows(backend, mode, valB, contexts...)
152+
basis_shadows = create_shadows(valB, x)
153+
return EnzymeForwardGradientPrep(_sig, valB, df, context_shadows, basis_shadows)
136154
end
137155

138156
function DI.gradient(
@@ -143,11 +161,12 @@ function DI.gradient(
143161
contexts::Vararg{DI.Constant,C},
144162
) where {F,SIG,B,C}
145163
DI.check_prep(f, prep, backend, x, contexts...)
164+
(; df, context_shadows, basis_shadows) = prep
146165
mode = forward_noprimal(backend)
147-
f_and_df = get_f_and_df(f, backend, mode)
148-
annotated_contexts = translate(backend, mode, Val(B), contexts...)
166+
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
167+
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
149168
derivs = gradient(
150-
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
169+
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
151170
)
152171
return first(derivs)
153172
end
@@ -160,11 +179,12 @@ function DI.value_and_gradient(
160179
contexts::Vararg{DI.Constant,C},
161180
) where {F,SIG,B,C}
162181
DI.check_prep(f, prep, backend, x, contexts...)
182+
(; df, context_shadows, basis_shadows) = prep
163183
mode = forward_withprimal(backend)
164-
f_and_df = get_f_and_df(f, backend, mode)
165-
annotated_contexts = translate(backend, mode, Val(B), contexts...)
184+
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
185+
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
166186
(; derivs, val) = gradient(
167-
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
187+
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
168188
)
169189
return val, first(derivs)
170190
end
@@ -196,10 +216,12 @@ end
196216

197217
## Jacobian
198218

199-
struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
219+
struct EnzymeForwardOneArgJacobianPrep{SIG,B,DF,DC,O} <: DI.JacobianPrep{SIG}
200220
_sig::Val{SIG}
201221
_valB::Val{B}
202-
shadows::O
222+
df::DF
223+
context_shadows::DC
224+
basis_shadows::O
203225
output_length::Int
204226
end
205227

@@ -213,8 +235,13 @@ function DI.prepare_jacobian_nokwarg(
213235
_sig = DI.signature(f, backend, x, contexts...; strict)
214236
y = f(x, map(DI.unwrap, contexts)...)
215237
valB = to_val(DI.pick_batchsize(backend, x))
216-
shadows = create_shadows(valB, x)
217-
return EnzymeForwardOneArgJacobianPrep(_sig, valB, shadows, length(y))
238+
mode = forward_withprimal(backend)
239+
df = function_shadow(f, backend, valB)
240+
context_shadows = make_context_shadows(backend, mode, valB, contexts...)
241+
basis_shadows = create_shadows(valB, x)
242+
return EnzymeForwardOneArgJacobianPrep(
243+
_sig, valB, df, context_shadows, basis_shadows, length(y)
244+
)
218245
end
219246

220247
function DI.jacobian(
@@ -225,14 +252,15 @@ function DI.jacobian(
225252
contexts::Vararg{DI.Constant,C},
226253
) where {F,SIG,B,C}
227254
DI.check_prep(f, prep, backend, x, contexts...)
255+
(; df, context_shadows, basis_shadows, output_length) = prep
228256
mode = forward_noprimal(backend)
229-
f_and_df = get_f_and_df(f, backend, mode)
230-
annotated_contexts = translate(backend, mode, Val(B), contexts...)
257+
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
258+
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
231259
derivs = jacobian(
232-
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
260+
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
233261
)
234262
jac_tensor = first(derivs)
235-
return maybe_reshape(jac_tensor, prep.output_length, length(x))
263+
return maybe_reshape(jac_tensor, output_length, length(x))
236264
end
237265

238266
function DI.value_and_jacobian(
@@ -243,14 +271,15 @@ function DI.value_and_jacobian(
243271
contexts::Vararg{DI.Constant,C},
244272
) where {F,SIG,B,C}
245273
DI.check_prep(f, prep, backend, x, contexts...)
274+
(; df, context_shadows, basis_shadows, output_length) = prep
246275
mode = forward_withprimal(backend)
247-
f_and_df = get_f_and_df(f, backend, mode)
248-
annotated_contexts = translate(backend, mode, Val(B), contexts...)
276+
f_and_df = get_f_and_df_prepared!(df, f, backend, Val(B))
277+
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
249278
(; derivs, val) = jacobian(
250-
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=prep.shadows
279+
mode, f_and_df, x, annotated_contexts...; chunk=Val(B), shadows=basis_shadows
251280
)
252281
jac_tensor = first(derivs)
253-
return val, maybe_reshape(jac_tensor, prep.output_length, length(x))
282+
return val, maybe_reshape(jac_tensor, output_length, length(x))
254283
end
255284

256285
function DI.jacobian!(

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,74 @@
11
## Pushforward
22

3+
struct EnzymeTwoArgPushforwardPrep{SIG,DF,DC} <: DI.PushforwardPrep{SIG}
4+
_sig::Val{SIG}
5+
df!::DF
6+
context_shadows::DC
7+
end
8+
39
function DI.prepare_pushforward_nokwarg(
410
strict::Val,
511
f!::F,
612
y,
713
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
814
x,
9-
tx::NTuple,
15+
tx::NTuple{B},
1016
contexts::Vararg{DI.Context,C};
11-
) where {F,C}
17+
) where {F,B,C}
1218
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
13-
return DI.NoPushforwardPrep(_sig)
19+
df! = function_shadow(f!, backend, Val(B))
20+
mode = forward_noprimal(backend)
21+
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
22+
return EnzymeTwoArgPushforwardPrep(_sig, df!, context_shadows)
1423
end
1524

1625
function DI.value_and_pushforward(
1726
f!::F,
1827
y,
19-
prep::DI.NoPushforwardPrep,
28+
prep::EnzymeTwoArgPushforwardPrep,
2029
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
2130
x,
2231
tx::NTuple{1},
2332
contexts::Vararg{DI.Context,C},
2433
) where {F,C}
2534
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
35+
(; df!, context_shadows) = prep
2636
mode = forward_noprimal(backend)
27-
f!_and_df! = get_f_and_df(f!, backend, mode)
37+
f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(1))
2838
dx = only(tx)
2939
dy = make_zero(y)
3040
x_and_dx = Duplicated(x, dx)
3141
y_and_dy = Duplicated(y, dy)
32-
annotated_contexts = translate(backend, mode, Val(1), contexts...)
42+
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(1))
3343
autodiff(mode, f!_and_df!, Const, y_and_dy, x_and_dx, annotated_contexts...)
3444
return y, (dy,)
3545
end
3646

3747
function DI.value_and_pushforward(
3848
f!::F,
3949
y,
40-
prep::DI.NoPushforwardPrep,
50+
prep::EnzymeTwoArgPushforwardPrep,
4151
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
4252
x,
4353
tx::NTuple{B},
4454
contexts::Vararg{DI.Context,C},
4555
) where {F,B,C}
4656
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
57+
(; df!, context_shadows) = prep
4758
mode = forward_noprimal(backend)
48-
f!_and_df! = get_f_and_df(f!, backend, mode, Val(B))
59+
f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(B))
4960
ty = ntuple(_ -> make_zero(y), Val(B))
5061
x_and_tx = BatchDuplicated(x, tx)
5162
y_and_ty = BatchDuplicated(y, ty)
52-
annotated_contexts = translate(backend, mode, Val(B), contexts...)
63+
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
5364
autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...)
5465
return y, ty
5566
end
5667

5768
function DI.pushforward(
5869
f!::F,
5970
y,
60-
prep::DI.NoPushforwardPrep,
71+
prep::EnzymeTwoArgPushforwardPrep,
6172
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
6273
x,
6374
tx::NTuple,
@@ -72,18 +83,19 @@ function DI.value_and_pushforward!(
7283
f!::F,
7384
y,
7485
ty::NTuple{B},
75-
prep::DI.NoPushforwardPrep,
86+
prep::EnzymeTwoArgPushforwardPrep,
7687
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
7788
x,
7889
tx::NTuple{B},
7990
contexts::Vararg{DI.Context,C},
8091
) where {F,B,C}
8192
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
93+
(; df!, context_shadows) = prep
8294
mode = forward_noprimal(backend)
83-
f!_and_df! = get_f_and_df(f!, backend, mode, Val(B))
95+
f!_and_df! = get_f_and_df_prepared!(df!, f!, backend, Val(B))
8496
x_and_tx = BatchDuplicated(x, tx)
8597
y_and_ty = BatchDuplicated(y, ty)
86-
annotated_contexts = translate(backend, mode, Val(B), contexts...)
98+
annotated_contexts = translate_prepared!(context_shadows, contexts, Val(B))
8799
autodiff(mode, f!_and_df!, Const, y_and_ty, x_and_tx, annotated_contexts...)
88100
return y, ty
89101
end
@@ -92,7 +104,7 @@ function DI.pushforward!(
92104
f!::F,
93105
y,
94106
ty::NTuple,
95-
prep::DI.NoPushforwardPrep,
107+
prep::EnzymeTwoArgPushforwardPrep,
96108
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
97109
x,
98110
tx::NTuple,

0 commit comments

Comments
 (0)