-
Notifications
You must be signed in to change notification settings - Fork 54
chain rules for DCT #273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
chain rules for DCT #273
Changes from 7 commits
7abaeeb
1a9cd68
d87c550
1f321f8
8da2495
5a975ab
39c7b6d
10ac31c
3ee37e7
74b8c75
a31b52d
3d6e160
d0ef838
6b8d178
608aa02
26df888
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
module FFTWChainRulesCoreExt | ||
|
||
using FFTW | ||
using FFTW: r2r | ||
using ChainRulesCore | ||
|
||
# DCT | ||
|
||
function ChainRulesCore.frule(Δ, ::typeof(dct), x::AbstractArray, region...) | ||
vpuri3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Δx = Δ[2] | ||
y = dct(x, region...) | ||
Δy = dct(Δx, region...) | ||
return y, Δy | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(dct), x::AbstractArray, region...) | ||
vpuri3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
y = dct(x, region...) | ||
project_x = ProjectTo(x) | ||
|
||
function dct_pullback(ȳ) | ||
f̄ = NoTangent() | ||
x̄ = project_x(idct(unthunk(ȳ), region...)) | ||
r̄ = NoTangent() | ||
|
||
if isempty(region) | ||
return f̄, x̄ | ||
else | ||
return f̄, x̄, r̄ | ||
end | ||
end | ||
|
||
return y, dct_pullback | ||
end | ||
|
||
# IDCT | ||
|
||
function ChainRulesCore.frule(Δ, ::typeof(idct), x::AbstractArray, region...) | ||
Δx = Δ[2] | ||
y = idct(x, region...) | ||
Δy = idct(Δx, region...) | ||
return y, Δy | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(idct), x::AbstractArray, region...) | ||
y = idct(x, region...) | ||
project_x = ProjectTo(x) | ||
|
||
function idct_pullback(ȳ) | ||
f̄ = NoTangent() | ||
x̄ = project_x(dct(unthunk(ȳ), region...)) | ||
r̄ = NoTangent() | ||
|
||
if isempty(region) | ||
return f̄, x̄ | ||
else | ||
return f̄, x̄, r̄ | ||
end | ||
end | ||
|
||
return y, idct_pullback | ||
end | ||
|
||
# R2R | ||
|
||
function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, region...) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The R2R transforms are not unitary. There is some scaling involved that depends on the kind of R2R transform. Because it looks like an involved task, I chose to skip that for now. I am happy to look into that in a separate PR |
||
Δx = Δ[2] | ||
y = r2r(x, region...) | ||
Δy = r2r(Δx, region...) | ||
return y, Δy | ||
end | ||
|
||
end # module |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,6 @@ | ||
# A bug in Julia 1.6.0's Pkg causes Preferences to be dropped during `Pkg.test()`, so we work around | ||
# it by explicitly creating a `test/Project.toml` which will correctly communicate any preferences | ||
# through to the child Julia process. X-ref: https://github.com/JuliaLang/Pkg.jl/issues/2500 | ||
|
||
[deps] | ||
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" |
Uh oh!
There was an error while loading. Please reload this page.