@@ -100,15 +100,15 @@ _augmented_return(::Kernel{<:GPU}, subtape, arg_refs, tape_type) =
100
100
101
101
function _create_tape_kernel (
102
102
kernel:: Kernel{CPU} ,
103
- ModifiedBetween ,
103
+ Mode ,
104
104
FT,
105
105
ctxTy,
106
106
ndrange,
107
107
iterspace,
108
108
args2... ,
109
109
)
110
110
TapeType = EnzymeCore. tape_type (
111
- ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween) ,
111
+ Mode ,
112
112
FT,
113
113
Const{Nothing},
114
114
Const{ctxTy},
121
121
122
122
function _create_tape_kernel (
123
123
kernel:: Kernel{<:GPU} ,
124
- ModifiedBetween ,
124
+ Mode ,
125
125
FT,
126
126
ctxTy,
127
127
ndrange,
@@ -139,7 +139,7 @@ function _create_tape_kernel(
139
139
EnzymeCore. compiler_job_from_backend (backend (kernel), typeof (() -> return ), Tuple{})
140
140
TapeType = EnzymeCore. tape_type (
141
141
job,
142
- ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween) ,
142
+ Mode ,
143
143
FT,
144
144
Const{Nothing},
145
145
Const{ctxTy},
@@ -159,14 +159,14 @@ _create_rev_kernel(kernel::Kernel{<:GPU}) = similar(kernel, gpu_rev)
159
159
function cpu_aug_fwd (
160
160
ctx,
161
161
f:: FT ,
162
- :: Val{ModifiedBetween} ,
162
+ mode :: Mode ,
163
163
subtape,
164
164
:: Val{TapeType} ,
165
165
args... ,
166
- ) where {ModifiedBetween , FT, TapeType}
166
+ ) where {Mode , FT, TapeType}
167
167
# A2 = Const{Nothing} -- since f->Nothing
168
168
forward, _ = EnzymeCore. autodiff_thunk (
169
- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)) ,
169
+ mode ,
170
170
Const{Core. Typeof (f)},
171
171
Const{Nothing},
172
172
Const{Core. Typeof (ctx)},
@@ -183,13 +183,13 @@ end
183
183
function cpu_rev (
184
184
ctx,
185
185
f:: FT ,
186
- :: Val{ModifiedBetween} ,
186
+ mode :: Mode ,
187
187
subtape,
188
188
:: Val{TapeType} ,
189
189
args... ,
190
- ) where {ModifiedBetween , FT, TapeType}
190
+ ) where {Mode , FT, TapeType}
191
191
_, reverse = EnzymeCore. autodiff_thunk (
192
- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)) ,
192
+ mode ,
193
193
Const{Core. Typeof (f)},
194
194
Const{Nothing},
195
195
Const{Core. Typeof (ctx)},
@@ -205,14 +205,14 @@ end
205
205
function gpu_aug_fwd (
206
206
ctx,
207
207
f:: FT ,
208
- :: Val{ModifiedBetween} ,
208
+ mode :: Mode ,
209
209
subtape,
210
210
:: Val{TapeType} ,
211
211
args... ,
212
- ) where {ModifiedBetween , FT, TapeType}
212
+ ) where {Mode , FT, TapeType}
213
213
# A2 = Const{Nothing} -- since f->Nothing
214
214
forward, _ = EnzymeCore. autodiff_deferred_thunk (
215
- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)) ,
215
+ mode ,
216
216
TapeType,
217
217
Const{Core. Typeof (f)},
218
218
Const{Nothing},
@@ -232,14 +232,14 @@ end
232
232
function gpu_rev (
233
233
ctx,
234
234
f:: FT ,
235
- :: Val{ModifiedBetween} ,
235
+ mode :: Mode ,
236
236
subtape,
237
237
:: Val{TapeType} ,
238
238
args... ,
239
- ) where {ModifiedBetween , FT, TapeType}
239
+ ) where {Mode , FT, TapeType}
240
240
# XXX : TapeType and A2 as args to autodiff_deferred_thunk
241
241
_, reverse = EnzymeCore. autodiff_deferred_thunk (
242
- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)) ,
242
+ mode ,
243
243
TapeType,
244
244
Const{Core. Typeof (f)},
245
245
Const{Nothing},
@@ -294,17 +294,17 @@ function EnzymeRules.augmented_primal(
294
294
args[i]
295
295
end
296
296
end
297
-
297
+ Mode = EnzymeCore . set_runtime_activity ( ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween), config)
298
298
TapeType, subtape, aug_kernel = _create_tape_kernel (
299
299
kernel,
300
- ModifiedBetween ,
300
+ Mode ,
301
301
FT,
302
302
ctxTy,
303
303
ndrange,
304
304
iterspace,
305
305
args2... ,
306
306
)
307
- aug_kernel (f, ModifiedBetween , subtape, Val (TapeType), args2... ; ndrange, workgroupsize)
307
+ aug_kernel (f, Mode , subtape, Val (TapeType), args2... ; ndrange, workgroupsize)
308
308
309
309
# TODO the fact that ctxTy is type unstable means this is all type unstable.
310
310
# Since custom rules require a fixed return type, explicitly cast to Any, rather
@@ -336,11 +336,11 @@ function EnzymeRules.reverse(
336
336
f = kernel. f
337
337
338
338
ModifiedBetween = Val ((overwritten (config)[1 ], false , overwritten (config)[2 : end ]. .. ))
339
-
339
+ Mode = EnzymeCore . set_runtime_activity ( ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween), config)
340
340
rev_kernel = _create_rev_kernel (kernel)
341
341
rev_kernel (
342
342
f,
343
- ModifiedBetween ,
343
+ Mode ,
344
344
subtape,
345
345
Val (tape_type),
346
346
args2... ;
0 commit comments