Skip to content

Commit c332e89

Browse files
authored
Throw error for in-place plans
1 parent 266c88f commit c332e89

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

ext/AbstractFFTsChainRulesCoreExt.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,18 @@ end
161161

162162
# plans
163163
function ChainRulesCore.frule((_, _, Δx), ::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
164-
y = P * x
164+
y = P * x
165+
if Base.mightalias(y, x)
166+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
167+
end
165168
Δy = P * Δx
166169
return y, Δy
167170
end
168171
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArray)
169172
y = P * x
173+
if Base.mightalias(y, x)
174+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
175+
end
170176
project_x = ChainRulesCore.ProjectTo(x)
171177
Pt = P'
172178
function mul_plan_pullback(ȳ)
@@ -177,12 +183,18 @@ function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.Plan, x::AbstractArra
177183
end
178184

179185
function ChainRulesCore.frule((_, ΔP, Δx), ::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
180-
y = P * x
186+
y = P * x
187+
if Base.mightalias(y, x)
188+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
189+
end
181190
Δy = P * Δx .+ (ΔP.scale / P.scale) .* y
182191
return y, Δy
183192
end
184193
function ChainRulesCore.rrule(::typeof(*), P::AbstractFFTs.ScaledPlan, x::AbstractArray)
185194
y = P * x
195+
if Base.mightalias(y, x)
196+
throw(ArgumentError("differentiation rules are not supported for in-place plans"))
197+
end
186198
Pt = P'
187199
scale = P.scale
188200
project_x = ChainRulesCore.ProjectTo(x)

src/definitions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where
638638
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
639639
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
640640
)
641-
return p.p \ (x ./ scale)
641+
return p.p \ (x ./ convert(typeof(x), scale))
642642
end
643643

644644
function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
@@ -651,7 +651,7 @@ function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle)
651651
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
652652
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
653653
)
654-
return scale ./ N .* (p.p \ x)
654+
return convert(typeof(x), scale) ./ N .* (p.p \ x)
655655
end
656656

657657
# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).

0 commit comments

Comments
 (0)