Skip to content

Commit a854a36

Browse files
committed
Adapt to pending Enzymecore changes
1 parent d9062a3 commit a854a36

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ UnsafeAtomicsLLVM = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
2121
[compat]
2222
Adapt = "0.4, 1.0, 2.0, 3.0, 4"
2323
Atomix = "0.1"
24-
EnzymeCore = "0.7.5"
24+
EnzymeCore = "0.8"
2525
InteractiveUtils = "1.6"
2626
LinearAlgebra = "1.6"
2727
MacroTools = "0.5"

ext/EnzymeExt.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,18 @@ EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing
4444
# https://github.com/EnzymeAD/Enzyme.jl/issues/1516
4545
# On the CPU `autodiff_deferred` can deadlock.
4646
# Hence a specialized CPU version
47-
function cpu_fwd(ctx, f, args...)
48-
EnzymeCore.autodiff(Forward, Const(f), Const{Nothing}, Const(ctx), args...)
47+
function cpu_fwd(config, ctx, f, args...)
48+
EnzymeCore.autodiff(EnzymeCore.set_runtime_activity(Forward, config), Const(f), Const{Nothing}, Const(ctx), args...)
4949
return nothing
5050
end
5151

5252
function gpu_fwd(ctx, f, args...)
53-
EnzymeCore.autodiff_deferred(Forward, Const(f), Const{Nothing}, Const(ctx), args...)
53+
EnzymeCore.autodiff_deferred(EnzymeCore.set_runtime_activity(Forward, config), Const(f), Const{Nothing}, Const(ctx), args...)
5454
return nothing
5555
end
5656

5757
function EnzymeRules.forward(
58+
config,
5859
func::Const{<:Kernel{CPU}},
5960
::Type{Const{Nothing}},
6061
args...;
@@ -63,12 +64,13 @@ function EnzymeRules.forward(
6364
)
6465
kernel = func.val
6566
f = kernel.f
66-
fwd_kernel = similar(kernel, cpu_fwd)
67+
fwd_kernel = similar(config, kernel, cpu_fwd)
6768

6869
fwd_kernel(f, args...; ndrange, workgroupsize)
6970
end
7071

7172
function EnzymeRules.forward(
73+
config,
7274
func::Const{<:Kernel{<:GPU}},
7375
::Type{Const{Nothing}},
7476
args...;
@@ -77,7 +79,7 @@ function EnzymeRules.forward(
7779
)
7880
kernel = func.val
7981
f = kernel.f
80-
fwd_kernel = similar(kernel, gpu_fwd)
82+
fwd_kernel = similar(config, kernel, gpu_fwd)
8183

8284
fwd_kernel(f, args...; ndrange, workgroupsize)
8385
end

0 commit comments

Comments
 (0)