@@ -291,7 +291,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{Type{CT}}, :
291
291
else
292
292
nothing
293
293
end
294
- return EnzymeRules. AugmentedReturn {(EnzymeRules.needs_primal(config) ? CT : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? CT : NTuple{EnzymeRules.width(config), CT }) : Nothing), Nothing} (primal, shadow, nothing )
294
+ return EnzymeRules. AugmentedReturn {(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT) }) : Nothing), Nothing} (primal, shadow, nothing )
295
295
end
296
296
297
297
function EnzymeCore. EnzymeRules. reverse (config, ofn:: Const{Type{CT}} , :: Type{RT} , tape, A:: EnzymeCore.Annotation{UndefInitializer} , args:: Vararg{EnzymeCore.Annotation, N} ) where {CT <: CuArray , RT, N}
@@ -325,7 +325,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{Type{CT}}, :
325
325
else
326
326
nothing
327
327
end
328
- return EnzymeRules. AugmentedReturn {(EnzymeRules.needs_primal(config) ? CT : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? CT : NTuple{EnzymeRules.width(config), CT }) : Nothing), Nothing} (primal, shadow, nothing )
328
+ return EnzymeRules. AugmentedReturn {(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT) }) : Nothing), Nothing} (primal, shadow, nothing )
329
329
end
330
330
331
331
function EnzymeCore. EnzymeRules. reverse (config, ofn:: Const{Type{CT}} , :: Type{RT} , tape, A:: EnzymeCore.Annotation{DR} , args:: Vararg{EnzymeCore.Annotation, N} ; kwargs... ) where {CT <: CuArray , DR <: CUDA.DataRef , RT, N}
0 commit comments