@@ -345,16 +345,6 @@ function brfft_output_size(sz::Dims{N}, d::Integer, region) where {N}
345
345
return ntuple (i -> i == d1 ? d : sz[i], Val (N))
346
346
end
347
347
348
- function output_size (p:: Plan )
349
- if projection_style (p) == :none
350
- return size (p)
351
- elseif projection_style (p) == :real
352
- return rfft_output_size (size (p), region (p))
353
- elseif projection_style (p) == :real_inv
354
- return brfft_output_size (size (p), irfft_dim (p), region (p))
355
- end
356
- end
357
-
358
348
plan_irfft (x:: AbstractArray{Complex{T}} , d:: Integer , region; kws... ) where {T} =
359
349
ScaledPlan (plan_brfft (x, d, region; kws... ),
360
350
normalization (T, brfft_output_size (x, d, region), region))
@@ -590,11 +580,19 @@ plan_brfft
590
580
591
581
# #############################################################################
592
582
593
- # Projection style (:none, :real, or :real_inv) to handle real FFTs
594
- function projection_style end
595
- # Length of halved dimension, needed only for irfft
583
+ struct NoProjectionStyle end
584
+ struct RealProjectionStyle end
585
+ struct RealInverseProjectionStyle end
586
+ const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}
587
+
596
588
function irfft_dim end
597
589
590
+ ProjectionStyle (p:: Plan ) = error (" No projection style defined for plan" )
591
+ output_size (p:: Plan ) = _output_size (p, ProjectionStyle (p))
592
+ _output_size (p:: Plan , :: NoProjectionStyle ) = size (p)
593
+ _output_size (p:: Plan , :: RealProjectionStyle ) = rfft_output_size (size (p), region (p))
594
+ _output_size (p:: Plan , :: RealInverseProjectionStyle ) = brfft_output_size (size (p), irfft_dim (p), region (p))
595
+
598
596
mutable struct AdjointPlan{T,P} <: Plan{T}
599
597
p:: P
600
598
pinv:: Plan
@@ -611,31 +609,38 @@ Base.adjoint(p::Plan{T}) where {T} = AdjointPlan{T}(p)
611
609
size (p:: AdjointPlan ) = output_size (p)
612
610
output_size (p:: AdjointPlan ) = size (p)
613
611
614
- function Base.:* (p:: AdjointPlan{T} , x:: AbstractArray ) where {T}
612
+ Base.:* (p:: AdjointPlan , x:: AbstractArray ) = _mul (p, x, ProjectionStyle (p. p))
613
+
614
+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: NoProjectionStyle ) where {T}
615
+ dims = region (p. p)
616
+ N = normalization (T, size (p. p), dims)
617
+ return 1 / N * (p. p \ x)
618
+ end
619
+
620
+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: RealProjectionStyle ) where {T}
615
621
dims = region (p. p)
622
+ N = normalization (T, size (p. p), dims)
616
623
halfdim = first (dims)
617
624
d = size (p. p, halfdim)
618
625
n = output_size (p. p, halfdim)
619
- if projection_style (p. p) == :none
620
- N = normalization (T, size (p. p), dims)
621
- return 1 / N * (p. p \ x)
622
- elseif projection_style (p. p) == :real
623
- N = normalization (T, size (p. p), dims)
624
- scale = reshape (
625
- [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
626
- ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
627
- )
628
- return 1 / N * (p. p \ (x ./ scale))
629
- elseif projection_style (p. p) == :real_inv
630
- N = normalization (real (T), output_size (p. p), dims)
631
- scale = reshape (
632
- [(i == 1 || (i == d && 2 * (i - 1 )) == n) ? 1 : 2 for i in 1 : d],
633
- ntuple (i -> i == first (dims) ? d : 1 , Val (ndims (x)))
634
- )
635
- return 1 / N * scale .* (p. p \ x)
636
- else
637
- error (" plan must define a valid projection style" )
638
- end
626
+ scale = reshape (
627
+ [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
628
+ ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
629
+ )
630
+ return 1 / N * (p. p \ (x ./ scale))
631
+ end
632
+
633
+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: RealInverseProjectionStyle ) where {T}
634
+ dims = region (p. p)
635
+ N = normalization (real (T), output_size (p. p), dims)
636
+ halfdim = first (dims)
637
+ n = size (p. p, halfdim)
638
+ d = output_size (p. p, halfdim)
639
+ scale = reshape (
640
+ [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
641
+ ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
642
+ )
643
+ return 1 / N * scale .* (p. p \ x)
639
644
end
640
645
641
646
plan_inv (p:: AdjointPlan ) = AdjointPlan (plan_inv (p. p))
0 commit comments