From 450bc49d482b7a52d96bd46baccd5af84c14f5b2 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 May 2023 15:58:47 +0200 Subject: [PATCH 1/4] Add EnzymeCore extension --- Project.toml | 4 ++++ ext/AbstractFFTsEnzymeCoreExt.jl | 3 +++ 2 files changed, 7 insertions(+) create mode 100644 ext/AbstractFFTsEnzymeCoreExt.jl diff --git a/Project.toml b/Project.toml index 8e7206c1..f49bb8d0 100644 --- a/Project.toml +++ b/Project.toml @@ -8,12 +8,16 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" +AbstractFFTsEnzymeCoreExt = ["EnzymeCore", "LinearAlgebra"] [compat] ChainRulesCore = "1" +EnzymeCore = "0.3" julia = "^1.0" [extras] diff --git a/ext/AbstractFFTsEnzymeCoreExt.jl b/ext/AbstractFFTsEnzymeCoreExt.jl new file mode 100644 index 00000000..f9f17642 --- /dev/null +++ b/ext/AbstractFFTsEnzymeCoreExt.jl @@ -0,0 +1,3 @@ +module AbstractFFTsEnzymeCoreExt + +end # module From f4c7a7ef4179142fd16d4cb7ec705d953f76543d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 May 2023 15:59:00 +0200 Subject: [PATCH 2/4] Add Enzyme as test dependency --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f49bb8d0..8868dca2 100644 --- a/Project.toml +++ b/Project.toml @@ -23,9 +23,10 @@ julia = "^1.0" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"] +test = ["ChainRulesCore", "ChainRulesTestUtils", "Enzyme", "Random", "Test", "Unitful"] From 06bef3ab20217690d6a62f3311f1469983683b57 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 May 2023 22:05:45 +0200 Subject: [PATCH 3/4] List LinearAlgebra only as dependency --- Project.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 8868dca2..8e23fc0d 100644 --- a/Project.toml +++ b/Project.toml @@ -9,11 +9,10 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" -AbstractFFTsEnzymeCoreExt = ["EnzymeCore", "LinearAlgebra"] +AbstractFFTsEnzymeCoreExt = "EnzymeCore" [compat] ChainRulesCore = "1" From 859abf0ea26f0ae728ab0b6ef863687617fc6912 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 May 2023 23:59:32 +0200 Subject: [PATCH 4/4] Add forward-mode rules --- ext/AbstractFFTsEnzymeCoreExt.jl | 55 ++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/ext/AbstractFFTsEnzymeCoreExt.jl b/ext/AbstractFFTsEnzymeCoreExt.jl index f9f17642..75a08f78 100644 --- a/ext/AbstractFFTsEnzymeCoreExt.jl +++ b/ext/AbstractFFTsEnzymeCoreExt.jl @@ -1,3 +1,58 @@ module AbstractFFTsEnzymeCoreExt +using AbstractFFTs +using AbstractFFTs.LinearAlgebra +using EnzymeCore +using EnzymeCore.EnzymeRules + +###################### +# Forward-mode rules # +###################### + +const DuplicatedOrBatchDuplicated{T} = Union{Duplicated{T},BatchDuplicated{T}} + +# since FFTs are linear, implement all forward-model rules generically at a low-level + +function EnzymeRules.forward( + func::Const{typeof(mul!)}, + RT::Type{<:Const}, + y::DuplicatedOrBatchDuplicated{<:StridedArray{T}}, + p::Const{<:AbstractFFTs.Plan{T}}, + x::DuplicatedOrBatchDuplicated{<:StridedArray{T}}, +) where {T} + val = func.val(y.val, p.val, x.val) + if x isa Duplicated && y isa Duplicated + dval = func.val(y.dval, p.val, x.dval) + elseif x isa Duplicated && y isa Duplicated + dval = map(y.dval, x.dval) do dy, dx + return func.val(dy, p.val, dx) + end + end + return nothing +end + +function EnzymeRules.forward( + func::Const{typeof(*)}, + RT::Type{ + <:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed} + }, + p::Const{<:AbstractFFTs.Plan}, + x::DuplicatedOrBatchDuplicated{<:StridedArray}, +) + RT <: Const && return func.val(p.val, x.val) + if x isa Duplicated + dval = func.val(p.val, x.dval) + RT <: DuplicatedNoNeed && return dval + val = func.val(p.val, x.val) + RT <: Duplicated && return Duplicated(val, dval) + else # x isa BatchDuplicated + dval = map(x.dval) do dx + return func.val(p.val, dx) + end + RT <: BatchDuplicatedNoNeed && return dval + val = func.val(p.val, x.val) + RT <: BatchDuplicated && return BatchDuplicated(val, dval) + end +end + end # module