@@ -14,21 +14,41 @@ module EnzymeExt
14
14
15
15
EnzymeRules. inactive (:: Type{StaticSize} , x... ) = nothing
16
16
17
+ # https://github.com/EnzymeAD/Enzyme.jl/issues/1516
18
+ # On the CPU `autodiff_deferred` can deadlock.
17
19
function fwd (ctx, f, args... )
18
20
EnzymeCore. autodiff_deferred (Forward, Const (f), Const{Nothing}, Const (ctx), args... )
19
21
return nothing
20
22
end
21
23
22
24
function aug_fwd (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
23
25
TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
24
- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
26
+ forward, _ = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
25
27
subtape[__groupindex (ctx)] = forward (Const (f), Const (ctx), args... )[1 ]
26
28
return nothing
27
29
end
28
30
29
31
function rev (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
30
32
TapeType = EnzymeCore. tape_type (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
31
- forward, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
33
+ _, reverse = EnzymeCore. autodiff_deferred_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), TapeType, Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
34
+ tp = subtape[__groupindex (ctx)]
35
+ reverse (Const (f), Const (ctx), args... , tp)
36
+ return nothing
37
+ end
38
+
39
+ function fwd_cpu (ctx, f, args... )
40
+ EnzymeCore. autodiff (Forward, Const (f), Const{Nothing}, Const (ctx), args... )
41
+ return nothing
42
+ end
43
+
44
+ function aug_fwd_cpu (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
45
+ forward, _ = EnzymeCore. autodiff_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
46
+ subtape[__groupindex (ctx)] = forward (Const (f), Const (ctx), args... )[1 ]
47
+ return nothing
48
+ end
49
+
50
+ function rev_cpu (ctx, f:: FT , :: Val{ModifiedBetween} , subtape, args... ) where {ModifiedBetween, FT}
51
+ _, reverse = EnzymeCore. autodiff_thunk (ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)), Const{Core. Typeof (f)}, Const{Nothing}, Const{Core. Typeof (ctx)}, map (Core. Typeof, args)... )
32
52
tp = subtape[__groupindex (ctx)]
33
53
reverse (Const (f), Const (ctx), args... , tp)
34
54
return nothing
@@ -42,6 +62,15 @@ module EnzymeExt
42
62
fwd_kernel (f, args... ; ndrange, workgroupsize)
43
63
end
44
64
65
+ function EnzymeRules. forward (func:: Const{<:Kernel{CPU}} , :: Type{Const{Nothing}} , args... ; ndrange= nothing , workgroupsize= nothing )
66
+ kernel = func. val
67
+ f = kernel. f
68
+ fwd_kernel = similar (kernel, fwd_cpu)
69
+
70
+ fwd_kernel (f, args... ; ndrange, workgroupsize)
71
+ end
72
+
73
+
45
74
@inline function make_active_byref (f:: F , :: Val{ActiveTys} ) where {F, ActiveTys}
46
75
if ! any (ActiveTys)
47
76
return f
@@ -103,7 +132,7 @@ module EnzymeExt
103
132
104
133
subtape = Array {TapeType} (undef, size (blocks (iterspace)))
105
134
106
- aug_kernel = similar (kernel, aug_fwd )
135
+ aug_kernel = similar (kernel, aug_fwd_cpu )
107
136
108
137
aug_kernel (f, ModifiedBetween, subtape, args2... ; ndrange, workgroupsize)
109
138
@@ -115,7 +144,7 @@ module EnzymeExt
115
144
return res
116
145
end
117
146
118
- function EnzymeRules. reverse (config:: Config , func:: Const{<:Kernel} , :: Type{<:EnzymeCore.Annotation} , tape, args:: Vararg{Any, N} ; ndrange= nothing , workgroupsize= nothing ) where N
147
+ function EnzymeRules. reverse (config:: Config , func:: Const{<:Kernel{CPU} } , :: Type{<:EnzymeCore.Annotation} , tape, args:: Vararg{Any, N} ; ndrange= nothing , workgroupsize= nothing ) where N
119
148
subtape, arg_refs = tape
120
149
121
150
args2 = ntuple (Val (N)) do i
@@ -138,7 +167,7 @@ module EnzymeExt
138
167
139
168
ModifiedBetween = Val ((overwritten (config)[1 ], false , overwritten (config)[2 : end ]. .. ))
140
169
141
- rev_kernel = similar (func. val, rev )
170
+ rev_kernel = similar (func. val, rev_cpu )
142
171
rev_kernel (f, ModifiedBetween, subtape, args2... ; ndrange, workgroupsize)
143
172
return ntuple (Val (N)) do i
144
173
Base. @_inline_meta
0 commit comments