Skip to content

Commit 8e8ce8a

Browse files
authored
Avoid deadlock in EnzymeExt (#478)
1 parent bf25c71 commit 8e8ce8a

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

ext/EnzymeExt.jl

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,41 @@ module EnzymeExt
1414

1515
EnzymeRules.inactive(::Type{StaticSize}, x...) = nothing
1616

17+
# https://github.com/EnzymeAD/Enzyme.jl/issues/1516
18+
# On the CPU `autodiff_deferred` can deadlock.
1719
function fwd(ctx, f, args...)
1820
EnzymeCore.autodiff_deferred(Forward, Const(f), Const{Nothing}, Const(ctx), args...)
1921
return nothing
2022
end
2123

2224
function aug_fwd(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
2325
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
24-
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
26+
forward, _ = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
2527
subtape[__groupindex(ctx)] = forward(Const(f), Const(ctx), args...)[1]
2628
return nothing
2729
end
2830

2931
function rev(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
3032
TapeType = EnzymeCore.tape_type(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
31-
forward, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
33+
_, reverse = EnzymeCore.autodiff_deferred_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
34+
tp = subtape[__groupindex(ctx)]
35+
reverse(Const(f), Const(ctx), args..., tp)
36+
return nothing
37+
end
38+
39+
function fwd_cpu(ctx, f, args...)
40+
EnzymeCore.autodiff(Forward, Const(f), Const{Nothing}, Const(ctx), args...)
41+
return nothing
42+
end
43+
44+
function aug_fwd_cpu(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
45+
forward, _ = EnzymeCore.autodiff_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
46+
subtape[__groupindex(ctx)] = forward(Const(f), Const(ctx), args...)[1]
47+
return nothing
48+
end
49+
50+
function rev_cpu(ctx, f::FT, ::Val{ModifiedBetween}, subtape, args...) where {ModifiedBetween, FT}
51+
_, reverse = EnzymeCore.autodiff_thunk(ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), Const{Core.Typeof(f)}, Const{Nothing}, Const{Core.Typeof(ctx)}, map(Core.Typeof, args)...)
3252
tp = subtape[__groupindex(ctx)]
3353
reverse(Const(f), Const(ctx), args..., tp)
3454
return nothing
@@ -42,6 +62,15 @@ module EnzymeExt
4262
fwd_kernel(f, args...; ndrange, workgroupsize)
4363
end
4464

65+
function EnzymeRules.forward(func::Const{<:Kernel{CPU}}, ::Type{Const{Nothing}}, args...; ndrange=nothing, workgroupsize=nothing)
66+
kernel = func.val
67+
f = kernel.f
68+
fwd_kernel = similar(kernel, fwd_cpu)
69+
70+
fwd_kernel(f, args...; ndrange, workgroupsize)
71+
end
72+
73+
4574
@inline function make_active_byref(f::F, ::Val{ActiveTys}) where {F, ActiveTys}
4675
if !any(ActiveTys)
4776
return f
@@ -103,7 +132,7 @@ module EnzymeExt
103132

104133
subtape = Array{TapeType}(undef, size(blocks(iterspace)))
105134

106-
aug_kernel = similar(kernel, aug_fwd)
135+
aug_kernel = similar(kernel, aug_fwd_cpu)
107136

108137
aug_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)
109138

@@ -115,7 +144,7 @@ module EnzymeExt
115144
return res
116145
end
117146

118-
function EnzymeRules.reverse(config::Config, func::Const{<:Kernel}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
147+
function EnzymeRules.reverse(config::Config, func::Const{<:Kernel{CPU}}, ::Type{<:EnzymeCore.Annotation}, tape, args::Vararg{Any, N}; ndrange=nothing, workgroupsize=nothing) where N
119148
subtape, arg_refs = tape
120149

121150
args2 = ntuple(Val(N)) do i
@@ -138,7 +167,7 @@ module EnzymeExt
138167

139168
ModifiedBetween = Val((overwritten(config)[1], false, overwritten(config)[2:end]...))
140169

141-
rev_kernel = similar(func.val, rev)
170+
rev_kernel = similar(func.val, rev_cpu)
142171
rev_kernel(f, ModifiedBetween, subtape, args2...; ndrange, workgroupsize)
143172
return ntuple(Val(N)) do i
144173
Base.@_inline_meta

0 commit comments

Comments
 (0)