@@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T
12
12
13
13
# size(p) should return the size of the input array for p
14
14
size (p:: Plan , d) = size (p)[d]
15
+ output_size (p:: Plan , d) = output_size (p)[d]
15
16
ndims (p:: Plan ) = length (size (p))
16
17
length (p:: Plan ) = prod (size (p)):: Int
17
18
@@ -254,6 +255,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
254
255
ScaledPlan (p:: ScaledPlan , α:: Number ) = ScaledPlan (p. p, p. scale * α)
255
256
256
257
size (p:: ScaledPlan ) = size (p. p)
258
+ output_size (p:: ScaledPlan ) = size (p)
257
259
258
260
region (p:: ScaledPlan ) = region (p. p)
259
261
@@ -301,9 +303,12 @@ for f in (:brfft, :irfft)
301
303
end
302
304
303
305
for f in (:brfft , :irfft )
306
+ pf = Symbol (" plan_" , f)
304
307
@eval begin
305
308
$ f (x:: AbstractArray{<:Real} , d:: Integer , region= 1 : ndims (x)) = $ f (complexfloat (x), d, region)
309
+ $ pf (x:: AbstractArray{<:Real} , d:: Integer , region; kws... ) = $ pf (complexfloat (x), d, region; kws... )
306
310
$ f (x:: AbstractArray{<:Complex{<:Union{Integer,Rational}}} , d:: Integer , region= 1 : ndims (x)) = $ f (complexfloat (x), d, region)
311
+ $ pf (x:: AbstractArray{<:Complex{<:Union{Integer,Rational}}} , d:: Integer , region; kws... ) = $ pf (complexfloat (x), d, region; kws... )
307
312
end
308
313
end
309
314
@@ -343,6 +348,16 @@ function brfft_output_size(sz::Dims{N}, d::Integer, region) where {N}
343
348
return ntuple (i -> i == d1 ? d : sz[i], Val (N))
344
349
end
345
350
351
+ function output_size (p:: Plan )
352
+ if projection_style (p) == :none
353
+ return size (p)
354
+ elseif projection_style (p) == :real
355
+ return rfft_output_size (size (p), region (p))
356
+ elseif projection_style (p) == :real_inv
357
+ return brfft_output_size (size (p), irfft_dim (p), region (p))
358
+ end
359
+ end
360
+
346
361
plan_irfft (x:: AbstractArray{Complex{T}} , d:: Integer , region; kws... ) where {T} =
347
362
ScaledPlan (plan_brfft (x, d, region; kws... ),
348
363
normalization (T, brfft_output_size (x, d, region), region))
@@ -575,3 +590,58 @@ Pre-plan an optimized real-input unnormalized transform, similar to
575
590
the same as for [`brfft`](@ref).
576
591
"""
577
592
plan_brfft
593
+
594
+ # #############################################################################
595
+
596
+ region (p:: Plan ) = p. region
597
+ region (p:: ScaledPlan ) = region (p. p)
598
+
599
+ # Projection style (:none, :real, or :real_inv) to handle real FFTs
600
+ function projection_style end
601
+ # Length of halved dimension, needed only for irfft
602
+ function irfft_dim end
603
+
604
+ mutable struct AdjointPlan{T,P} <: Plan{T}
605
+ p:: P
606
+ pinv:: Plan
607
+ AdjointPlan {T,P} (p) where {T,P} = new (p)
608
+ # always have adjoint inside scaled
609
+ AdjointPlan {T,P} (p:: P ) where {T,P<: ScaledPlan{T} } = ScaledPlan {T} (AdjointPlan {T} (p. p), p. scale)
610
+ AdjointPlan {T,P} (p:: AdjointPlan{T} ) where {T,P} = new (p. p)
611
+ end
612
+
613
+ AdjointPlan {T} (p:: P ) where {T,P} = AdjointPlan {T,P} (p)
614
+ AdjointPlan (p:: Plan{T} ) where {T} = AdjointPlan {T} (p)
615
+ Base. adjoint (p:: Plan{T} ) where {T} = AdjointPlan {T} (p)
616
+
617
+ size (p:: AdjointPlan ) = output_size (p)
618
+ output_size (p:: AdjointPlan ) = size (p)
619
+
620
+ function Base.:* (p:: AdjointPlan{T} , x:: AbstractArray ) where {T}
621
+ dims = region (p. p)
622
+ halfdim = first (dims)
623
+ d = size (p. p, halfdim)
624
+ n = output_size (p. p, halfdim)
625
+ if projection_style (p. p) == :none
626
+ N = normalization (T, size (p. p), dims)
627
+ return 1 / N * (p. p \ x)
628
+ elseif projection_style (p. p) == :real
629
+ N = normalization (T, size (p. p), dims)
630
+ scale = reshape (
631
+ [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
632
+ ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
633
+ )
634
+ return 1 / N * (p. p \ (x ./ scale))
635
+ elseif projection_style (p. p) == :real_inv
636
+ N = normalization (real (T), output_size (p. p), dims)
637
+ scale = reshape (
638
+ [(i == 1 || (i == d && 2 * (i - 1 )) == n) ? 1 : 2 for i in 1 : d],
639
+ ntuple (i -> i == first (dims) ? d : 1 , Val (ndims (x)))
640
+ )
641
+ return 1 / N * scale .* (p. p \ x)
642
+ else
643
+ error (" plan must define a valid projection style" )
644
+ end
645
+ end
646
+
647
+ plan_inv (p:: AdjointPlan ) = AdjointPlan (plan_inv (p. p))
0 commit comments