Skip to content

Commit d72cdaa

Browse files
authored
Add Enzyme sum derivatives (#2471)
1 parent 9f93343 commit d72cdaa

File tree

2 files changed

+173
-0
lines changed

2 files changed

+173
-0
lines changed

ext/EnzymeCoreExt.jl

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ else
1212
using ..EnzymeCore
1313
using ..EnzymeCore.EnzymeRules
1414
end
15+
using GPUArrays
1516

1617
function EnzymeCore.EnzymeRules.inactive(::typeof(CUDA.CUBLAS.handle))
1718
return nothing
@@ -516,5 +517,163 @@ function EnzymeCore.EnzymeRules.noalias(::Type{CT}, ::UndefInitializer, args...)
516517
return nothing
517518
end
518519

520+
function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays.mapreducedim!)},
521+
::Type{RT},
522+
f::EnzymeCore.Const{typeof(Base.identity)},
523+
op::EnzymeCore.Const{typeof(Base.add_sum)},
524+
R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T}
525+
if R isa Const || R isa Duplicated || R isa BatchDuplicated
526+
ofn.val(f.val, op.val, R.val, A.val; init)
527+
end
528+
529+
if A isa Duplicated || A isa DuplicatedNoNeed
530+
if A isa Const
531+
Base.fill!(R.dval, zero(T))
532+
else
533+
ofn.val(f.val, op.val, R.dval, A.dval)
534+
end
535+
elseif R isa BatchDuplicated || R isa BatchDuplicatedNoNeed
536+
ntuple(Val(EnzymeRules.batch_width(R))) do i
537+
Base.@_inline_meta
538+
if A isa Const
539+
Base.fill!(R.dval[i], zero(T))
540+
else
541+
ofn.val(f.val, op.val, R.dval[i], A.dval[i])
542+
end
543+
nothing
544+
end
545+
end
546+
547+
if RT <: Duplicated
548+
return R
549+
elseif RT <: Const
550+
return R.val
551+
elseif RT <: DuplicatedNoNeed
552+
return R.dval
553+
elseif RT <: BatchDuplicated
554+
return R
555+
else
556+
return R.dval
557+
end
558+
end
559+
560+
561+
function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(GPUArrays.mapreducedim!)},
562+
::Type{RT},
563+
f::EnzymeCore.Const{typeof(Base.identity)},
564+
op::EnzymeCore.Const{typeof(Base.add_sum)},
565+
R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T<:AbstractFloat}
566+
if A isa Const || A isa Duplicated || A isa BatchDuplicated
567+
ofn.val(f.val, op.val, R.val, A.val)
568+
end
569+
570+
primal = if EnzymeRules.needs_primal(config)
571+
R.val
572+
else
573+
nothing
574+
end
575+
576+
shadow = if EnzymeRules.needs_shadow(config)
577+
R.dval
578+
else
579+
nothing
580+
end
581+
return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
582+
end
583+
584+
function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(GPUArrays.mapreducedim!)},
585+
::Type{RT},
586+
tape,
587+
f::EnzymeCore.Const{typeof(Base.identity)},
588+
op::EnzymeCore.Const{typeof(Base.add_sum)},
589+
R::EnzymeCore.Annotation{<:AnyCuArray{T}}, A; init) where {RT, T<:AbstractFloat}
590+
591+
if !(A isa Const) && !(R isa Const)
592+
if A isa Duplicated || A isa DuplicatedNoNeed
593+
A.dval .+= R.dval
594+
Base.fill!(R.dval, zero(T))
595+
elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed
596+
ntuple(Val(EnzymeRules.batch_width(A))) do i
597+
Base.@_inline_meta
598+
A.dval[i] .+= R.dval[i]
599+
Base.fill!(R.dval[i], zero(T))
600+
nothing
601+
end
602+
end
603+
end
604+
605+
return (nothing, nothing, nothing, nothing)
606+
end
607+
608+
function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays._mapreduce)},
609+
::Type{RT},
610+
f::EnzymeCore.Const{typeof(Base.identity)},
611+
op::EnzymeCore.Const{typeof(Base.add_sum)},
612+
A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T, D}
613+
if RT <: Const
614+
ofn.val(f.val, op.val, A.val; dims, init)
615+
elseif RT <: Duplicated
616+
(
617+
ofn.val(f.val, op.val, A.val; dims, init),
618+
ofn.val(f.val, op.val, A.dval; dims, init)
619+
)
620+
elseif RT <: DuplicatedNoNeed
621+
ofn.val(f.val, op.val, A.dval; dims, init)
622+
elseif RT <: BatchDuplicated
623+
(
624+
ofn.val(f.val, op.val, A.val; dims, init),
625+
ntuple(Val(EnzymeRules.batch_width(RT))) do i
626+
Base.@_inline_meta
627+
ofn.val(f.val, op.val, A.dval[i]; dims, init)
628+
end
629+
)
630+
else
631+
@assert RT <: BatchDuplicatedNoNeed
632+
ntuple(Val(EnzymeRules.batch_width(RT))) do i
633+
Base.@_inline_meta
634+
ofn.val(f.val, op.val, A.dval[i]; dims, init)
635+
end
636+
end
637+
end
638+
639+
function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(GPUArrays._mapreduce)},
640+
::Type{Active{RT}},
641+
f::EnzymeCore.Const{typeof(Base.identity)},
642+
op::EnzymeCore.Const{typeof(Base.add_sum)},
643+
A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T<:AbstractFloat, D}
644+
primal = if EnzymeRules.needs_primal(config)
645+
ofn.val(f.val, op.val, A.val; dims, init)
646+
else
647+
nothing
648+
end
649+
650+
shadow = if EnzymeRules.needs_shadow(config)
651+
A.dval
652+
else
653+
nothing
654+
end
655+
return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
656+
end
657+
658+
function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(GPUArrays._mapreduce)},
659+
dres::Active{RT},
660+
tape,
661+
f::EnzymeCore.Const{typeof(Base.identity)},
662+
op::EnzymeCore.Const{typeof(Base.add_sum)},
663+
A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T<:AbstractFloat, D}
664+
665+
if A isa Duplicated || A isa DuplicatedNoNeed
666+
A.dval .+= dres.val
667+
elseif A isa BatchDuplicated || A isa BatchDuplicatedNoNeed
668+
ntuple(Val(EnzymeRules.batch_width(A))) do i
669+
Base.@_inline_meta
670+
A.dval[i] .+= dres.val
671+
nothing
672+
end
673+
end
674+
675+
return (nothing, nothing, nothing)
676+
end
677+
519678
end # module
520679

test/extensions/enzyme.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,20 @@ firstsum(x, y) = first(x .+ y)
103103
#@test res[2] ≈ 1.2
104104
end
105105

106+
@testset "Forward sum" begin
107+
x = CuArray([1.0, 2.0, 3.0, 4.0])
108+
dx = CuArray([100., 300.0, 500.0, 700.0])
109+
res = Enzyme.autodiff(Forward, sum, Duplicated(x, dx))
110+
@test res[1] 100+300+500+700.
111+
end
112+
113+
@testset "Reverse sum" begin
114+
x = CuArray([1.0, 2.0, 3.0, 4.0])
115+
dx = CuArray([0., 0.0, 0.0, 0.0])
116+
Enzyme.autodiff(Reverse, sum, Duplicated(x, dx))
117+
@test all(dx .≈ 1.0)
118+
end
119+
106120
# TODO once reverse kernels are in
107121
# function togpu(x)
108122
# x = CuArray(x)

0 commit comments

Comments
 (0)