Skip to content

Commit b3c2a09

Browse files
committed
Implement chain rules for FFT plans
1 parent 09fd6c2 commit b3c2a09

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

src/chainrules.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,20 @@ function ChainRulesCore.rrule(::typeof(ifftshift), x::AbstractArray, dims)
150150
end
151151
return y, ifftshift_pullback
152152
end
153+
154+
# plans
155+
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::Plan, x::AbstractArray)
156+
y = P * x
157+
Δy = P * Δx
158+
return y, Δy
159+
end
160+
function ChainRulesCore.rrule(::typeof(*), P::Plan, x::AbstractArray)
161+
y = P * x
162+
project_x = ChainRulesCore.ProjectTo(x)
163+
Pt = P'
164+
function mul_plan_pullback(ȳ)
165+
= project_x(Pt * ȳ)
166+
return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), x̄
167+
end
168+
return y, mul_plan_pullback
169+
end

0 commit comments

Comments
 (0)