@@ -44,17 +44,18 @@ EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing
44
44
# https://github.com/EnzymeAD/Enzyme.jl/issues/1516
45
45
# On the CPU `autodiff_deferred` can deadlock.
46
46
# Hence a specialized CPU version
47
"""
    cpu_fwd(ctx, config, f, args...)

Forward-mode body launched in place of the user kernel for CPU kernels.

Differentiates `f` with `EnzymeCore.autodiff` instead of `autodiff_deferred`, because
`autodiff_deferred` can deadlock on the CPU
(https://github.com/EnzymeAD/Enzyme.jl/issues/1516).

The mode is `Forward` passed through `EnzymeCore.set_runtime_activity` together with the
rule's `config`. `f` and `ctx` are wrapped in `Const`, `args...` are forwarded unchanged,
and the requested return activity is `Const{Nothing}`. Always returns `nothing`.
"""
function cpu_fwd(ctx, config, f, args...)
    # Build the mode first so the autodiff call stays short and readable.
    mode = EnzymeCore.set_runtime_activity(Forward, config)
    EnzymeCore.autodiff(mode, Const(f), Const{Nothing}, Const(ctx), args...)
    return nothing
end
51
51
52
"""
    gpu_fwd(ctx, config, f, args...)

Forward-mode body launched in place of the user kernel for GPU kernels.

Differentiates `f` with `EnzymeCore.autodiff_deferred`. The mode is `Forward` passed
through `EnzymeCore.set_runtime_activity` together with the rule's `config`. `f` and
`ctx` are wrapped in `Const`, `args...` are forwarded unchanged, and the requested
return activity is `Const{Nothing}`. Always returns `nothing`.
"""
function gpu_fwd(ctx, config, f, args...)
    # Build the mode first so the autodiff_deferred call stays short and readable.
    mode = EnzymeCore.set_runtime_activity(Forward, config)
    EnzymeCore.autodiff_deferred(mode, Const(f), Const{Nothing}, Const(ctx), args...)
    return nothing
end
56
56
57
57
function EnzymeRules. forward (
58
+ config,
58
59
func:: Const{<:Kernel{CPU}} ,
59
60
:: Type{Const{Nothing}} ,
60
61
args... ;
@@ -65,10 +66,11 @@ function EnzymeRules.forward(
65
66
f = kernel. f
66
67
fwd_kernel = similar (kernel, cpu_fwd)
67
68
68
- fwd_kernel (f, args... ; ndrange, workgroupsize)
69
+ fwd_kernel (config, f, args... ; ndrange, workgroupsize)
69
70
end
70
71
71
72
function EnzymeRules. forward (
73
+ config,
72
74
func:: Const{<:Kernel{<:GPU}} ,
73
75
:: Type{Const{Nothing}} ,
74
76
args... ;
@@ -79,7 +81,7 @@ function EnzymeRules.forward(
79
81
f = kernel. f
80
82
fwd_kernel = similar (kernel, gpu_fwd)
81
83
82
- fwd_kernel (f, args... ; ndrange, workgroupsize)
84
+ fwd_kernel (config, f, args... ; ndrange, workgroupsize)
83
85
end
84
86
85
87
_enzyme_mkcontext (kernel:: Kernel{CPU} , ndrange, iterspace, dynamic) =
@@ -253,7 +255,7 @@ function gpu_rev(
253
255
end
254
256
255
257
function EnzymeRules. augmented_primal (
256
- config:: Config ,
258
+ config:: RevConfig ,
257
259
func:: Const{<:Kernel} ,
258
260
:: Type{Const{Nothing}} ,
259
261
args:: Vararg{Any, N} ;
@@ -311,7 +313,7 @@ function EnzymeRules.augmented_primal(
311
313
end
312
314
313
315
function EnzymeRules. reverse (
314
- config:: Config ,
316
+ config:: RevConfig ,
315
317
func:: Const{<:Kernel} ,
316
318
:: Type{<:EnzymeCore.Annotation} ,
317
319
tape,
364
366
# synchronize rule and then synchronize where the launch was. However, with the current
365
367
# kernel semantics this ensures correctness for now.
366
368
function EnzymeRules. augmented_primal (
367
- config:: Config ,
369
+ config:: RevConfig ,
368
370
func:: Const{typeof(synchronize)} ,
369
371
:: Type{Const{Nothing}} ,
370
372
backend:: T ,
@@ -374,7 +376,7 @@ function EnzymeRules.augmented_primal(
374
376
end
375
377
376
378
function EnzymeRules. reverse (
377
- config:: Config ,
379
+ config:: RevConfig ,
378
380
func:: Const{typeof(synchronize)} ,
379
381
:: Type{Const{Nothing}} ,
380
382
tape,
0 commit comments