Skip to content

Commit bd3b61b

Browse files
authored
Enzyme: Reversemode cudaconvert (#2476)
1 parent 4a215e3 commit bd3b61b

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

ext/EnzymeCoreExt.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ end
5555
function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)},
5656
::Type{RT}, x::IT) where {RT, IT}
5757
if RT <: Duplicated
58-
RT(ofn.val(x.val), ofn.val(x.dval))
58+
Duplicated(ofn.val(x.val), ofn.val(x.dval))
5959
elseif RT <: Const
6060
ofn.val(x.val)::eltype(RT)
6161
elseif RT <: DuplicatedNoNeed
@@ -73,6 +73,33 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)},
7373
end
7474
end
7575

76+
# Reverse-mode augmented-primal rule for `cudaconvert`.
#
# The wrapped function `ofn.val` is applied to `x.val` when the config asks for
# the primal, and to `x.dval` (or to each `x.dval[i]` when the config width is
# greater than 1) when the config asks for the shadow. Anything that is not
# requested is returned as `nothing`, and no tape is recorded (tape is `nothing`).
function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(cudaconvert)}, ::Type{RT}, x::IT) where {RT, IT}
    want_primal = EnzymeRules.needs_primal(config)
    want_shadow = EnzymeRules.needs_shadow(config)
    nbatch = EnzymeRules.width(config)

    primal = want_primal ? ofn.val(x.val) : nothing

    shadow = if !want_shadow
        nothing
    elseif nbatch == 1
        ofn.val(x.dval)
    else
        # Batched case: convert every shadow in the tuple individually.
        ntuple(Val(nbatch)) do i
            Base.@_inline_meta
            ofn.val(x.dval[i])
        end
    end

    # Type parameters of the AugmentedReturn mirror which values were produced above.
    PrimalT = want_primal ? eltype(RT) : Nothing
    ShadowT = if !want_shadow
        Nothing
    elseif nbatch == 1
        eltype(RT)
    else
        NTuple{nbatch, eltype(RT)}
    end

    return EnzymeRules.AugmentedReturn{PrimalT, ShadowT, Nothing}(primal, shadow, nothing)
end
97+
98+
# Reverse-pass rule for `cudaconvert`.
#
# Performs no work and returns a one-element tuple holding `nothing` for the
# single argument `x`; `config`, `ofn`, `RT` and `tape` are not used.
function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(cudaconvert)}, ::Type{RT}, tape, x::IT) where {RT, IT}
    return (nothing,)
end
101+
102+
76103
function EnzymeCore.EnzymeRules.forward(ofn::Const{Type{CT}},
77104
::Type{RT}, uval::EnzymeCore.Annotation{UndefInitializer}, args...) where {CT <: CuArray, RT}
78105
primargs = ntuple(Val(length(args))) do i

0 commit comments

Comments
 (0)