55
55
function EnzymeCore. EnzymeRules. forward (ofn:: Const{typeof(cudaconvert)} ,
56
56
:: Type{RT} , x:: IT ) where {RT, IT}
57
57
if RT <: Duplicated
58
- RT (ofn. val (x. val), ofn. val (x. dval))
58
+ Duplicated (ofn. val (x. val), ofn. val (x. dval))
59
59
elseif RT <: Const
60
60
ofn. val (x. val):: eltype (RT)
61
61
elseif RT <: DuplicatedNoNeed
@@ -73,6 +73,33 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)},
73
73
end
74
74
end
75
75
76
+ function EnzymeCore. EnzymeRules. augmented_primal (config, ofn:: Const{typeof(cudaconvert)} , :: Type{RT} , x:: IT ) where {RT, IT}
77
+ primal = if EnzymeRules. needs_primal (config)
78
+ ofn. val (x. val)
79
+ else
80
+ nothing
81
+ end
82
+
83
+ shadow = if EnzymeRules. needs_shadow (config)
84
+ if EnzymeRules. width (config) == 1
85
+ ofn. val (x. dval)
86
+ else
87
+ ntuple (Val (EnzymeRules. width (config))) do i
88
+ Base. @_inline_meta
89
+ ofn. val (x. dval[i])
90
+ end
91
+ end
92
+ else
93
+ nothing
94
+ end
95
+ return EnzymeRules. AugmentedReturn {(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing} (primal, shadow, nothing )
96
+ end
97
+
98
+ function EnzymeCore. EnzymeRules. reverse (config, ofn:: Const{typeof(cudaconvert)} , :: Type{RT} , tape, x:: IT ) where {RT, IT}
99
+ (nothing ,)
100
+ end
101
+
102
+
76
103
function EnzymeCore. EnzymeRules. forward (ofn:: Const{Type{CT}} ,
77
104
:: Type{RT} , uval:: EnzymeCore.Annotation{UndefInitializer} , args... ) where {CT <: CuArray , RT}
78
105
primargs = ntuple (Val (length (args))) do i
0 commit comments