Skip to content

Commit 27ded01

Browse files
authored
Enzyme: fix propagation of runtime activity (#534)
1 parent bc89f91 commit 27ded01

File tree

1 file changed

+21
-21
lines changed

1 file changed

+21
-21
lines changed

ext/EnzymeExt.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,15 @@ _augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type) =
100100

101101
function _create_tape_kernel(
102102
kernel::Kernel{CPU},
103-
ModifiedBetween,
103+
Mode,
104104
FT,
105105
ctxTy,
106106
ndrange,
107107
iterspace,
108108
args2...,
109109
)
110110
TapeType = EnzymeCore.tape_type(
111-
ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween),
111+
Mode,
112112
FT,
113113
Const{Nothing},
114114
Const{ctxTy},
@@ -121,7 +121,7 @@ end
121121

122122
function _create_tape_kernel(
123123
kernel::Kernel{<:GPU},
124-
ModifiedBetween,
124+
Mode,
125125
FT,
126126
ctxTy,
127127
ndrange,
@@ -139,7 +139,7 @@ function _create_tape_kernel(
139139
EnzymeCore.compiler_job_from_backend(backend(kernel), typeof(() -> return), Tuple{})
140140
TapeType = EnzymeCore.tape_type(
141141
job,
142-
ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween),
142+
Mode,
143143
FT,
144144
Const{Nothing},
145145
Const{ctxTy},
@@ -159,14 +159,14 @@ _create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev)
159159
function cpu_aug_fwd(
160160
ctx,
161161
f::FT,
162-
::Val{ModifiedBetween},
162+
mode::Mode,
163163
subtape,
164164
::Val{TapeType},
165165
args...,
166-
) where {ModifiedBetween, FT, TapeType}
166+
) where {Mode, FT, TapeType}
167167
# A2 = Const{Nothing} -- since f->Nothing
168168
forward, _ = EnzymeCore.autodiff_thunk(
169-
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
169+
mode,
170170
Const{Core.Typeof(f)},
171171
Const{Nothing},
172172
Const{Core.Typeof(ctx)},
@@ -183,13 +183,13 @@ end
183183
function cpu_rev(
184184
ctx,
185185
f::FT,
186-
::Val{ModifiedBetween},
186+
mode::Mode,
187187
subtape,
188188
::Val{TapeType},
189189
args...,
190-
) where {ModifiedBetween, FT, TapeType}
190+
) where {Mode, FT, TapeType}
191191
_, reverse = EnzymeCore.autodiff_thunk(
192-
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
192+
mode,
193193
Const{Core.Typeof(f)},
194194
Const{Nothing},
195195
Const{Core.Typeof(ctx)},
@@ -205,14 +205,14 @@ end
205205
function gpu_aug_fwd(
206206
ctx,
207207
f::FT,
208-
::Val{ModifiedBetween},
208+
mode::Mode,
209209
subtape,
210210
::Val{TapeType},
211211
args...,
212-
) where {ModifiedBetween, FT, TapeType}
212+
) where {Mode, FT, TapeType}
213213
# A2 = Const{Nothing} -- since f->Nothing
214214
forward, _ = EnzymeCore.autodiff_deferred_thunk(
215-
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
215+
mode,
216216
TapeType,
217217
Const{Core.Typeof(f)},
218218
Const{Nothing},
@@ -232,14 +232,14 @@ end
232232
function gpu_rev(
233233
ctx,
234234
f::FT,
235-
::Val{ModifiedBetween},
235+
mode::Mode,
236236
subtape,
237237
::Val{TapeType},
238238
args...,
239-
) where {ModifiedBetween, FT, TapeType}
239+
) where {Mode, FT, TapeType}
240240
# XXX: TapeType and A2 as args to autodiff_deferred_thunk
241241
_, reverse = EnzymeCore.autodiff_deferred_thunk(
242-
ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)),
242+
mode,
243243
TapeType,
244244
Const{Core.Typeof(f)},
245245
Const{Nothing},
@@ -294,17 +294,17 @@ function EnzymeRules.augmented_primal(
294294
args[i]
295295
end
296296
end
297-
297+
Mode = EnzymeCore.set_runtime_activity(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), config)
298298
TapeType, subtape, aug_kernel = _create_tape_kernel(
299299
kernel,
300-
ModifiedBetween,
300+
Mode,
301301
FT,
302302
ctxTy,
303303
ndrange,
304304
iterspace,
305305
args2...,
306306
)
307-
aug_kernel(f, ModifiedBetween, subtape, Val(TapeType), args2...; ndrange, workgroupsize)
307+
aug_kernel(f, Mode, subtape, Val(TapeType), args2...; ndrange, workgroupsize)
308308

309309
# TODO the fact that ctxTy is type unstable means this is all type unstable.
310310
# Since custom rules require a fixed return type, explicitly cast to Any, rather
@@ -336,11 +336,11 @@ function EnzymeRules.reverse(
336336
f = kernel.f
337337

338338
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
339-
339+
Mode = EnzymeCore.set_runtime_activity(ReverseSplitModified(ReverseSplitWithPrimal, ModifiedBetween), config)
340340
rev_kernel = _create_rev_kernel(kernel)
341341
rev_kernel(
342342
f,
343-
ModifiedBetween,
343+
Mode,
344344
subtape,
345345
Val(tape_type),
346346
args2...;

0 commit comments

Comments
 (0)