Skip to content

Commit 70e91d7

Browse files
authored
feat: support DuplicatedNoNeed as function annotation for Enzyme (#805)
1 parent 6a58124 commit 70e91d7

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.7.0"
4+
version = "0.7.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
1-
const AnyDuplicated = Union{
2-
Duplicated,
3-
MixedDuplicated,
4-
BatchDuplicated,
5-
BatchMixedDuplicated,
6-
DuplicatedNoNeed,
7-
BatchDuplicatedNoNeed,
8-
}
1+
const AnyDuplicated = Union{Duplicated,MixedDuplicated,BatchDuplicated,BatchMixedDuplicated}
2+
3+
const AnyDuplicatedNoNeed = Union{DuplicatedNoNeed,BatchDuplicatedNoNeed}
94

105
# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
116
function DI.pick_batchsize(::AutoEnzyme, N::Integer)
@@ -40,13 +35,25 @@ function get_f_and_df_prepared!(
4035
end
4136
end
4237

38+
function get_f_and_df_prepared!(
39+
df, f::F, ::AutoEnzyme{M,<:AnyDuplicatedNoNeed}, ::Val{B}
40+
) where {F,M,B}
41+
if B == 1
42+
return DuplicatedNoNeed(f, df)
43+
else
44+
return BatchDuplicatedNoNeed(f, df)
45+
end
46+
end
47+
4348
function function_shadow(
4449
::F, ::AutoEnzyme{M,<:Union{Const,Nothing}}, ::Val{B}
4550
) where {M,B,F}
4651
return nothing
4752
end
4853

49-
function function_shadow(f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B}) where {F,M,B}
54+
function function_shadow(
55+
f::F, ::AutoEnzyme{M,<:Union{AnyDuplicated,AnyDuplicatedNoNeed}}, ::Val{B}
56+
) where {F,M,B}
5057
if B == 1
5158
return make_zero(f)
5259
else

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ backends = [
3232

3333
duplicated_backends = [
3434
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Duplicated),
35-
AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Duplicated),
35+
AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.DuplicatedNoNeed),
3636
]
3737

3838
@testset "Checks" begin

0 commit comments

Comments
 (0)