@@ -12,6 +12,7 @@ eltype(::Type{<:Plan{T}}) where {T} = T
12
12
13
13
# size(p) should return the size of the input array for p
14
14
size (p:: Plan , d) = size (p)[d]
15
+ output_size (p:: Plan , d) = output_size (p)[d]
15
16
ndims (p:: Plan ) = length (size (p))
16
17
length (p:: Plan ) = prod (size (p)):: Int
17
18
@@ -254,6 +255,7 @@ ScaledPlan(p::Plan{T}, scale::Number) where {T} = ScaledPlan{T}(p, scale)
254
255
ScaledPlan (p:: ScaledPlan , α:: Number ) = ScaledPlan (p. p, p. scale * α)
255
256
256
257
size (p:: ScaledPlan ) = size (p. p)
258
+ output_size (p:: ScaledPlan ) = output_size (p. p)
257
259
258
260
fftdims (p:: ScaledPlan ) = fftdims (p. p)
259
261
@@ -575,3 +577,67 @@ Pre-plan an optimized real-input unnormalized transform, similar to
575
577
the same as for [`brfft`](@ref).
576
578
"""
577
579
plan_brfft
580
+
581
+ # #############################################################################
582
+
583
+ struct NoProjectionStyle end
584
+ struct RealProjectionStyle end
585
+ struct RealInverseProjectionStyle end
586
+ const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}
587
+
588
+ function irfft_dim end
589
+
590
+ output_size (p:: Plan ) = _output_size (p, ProjectionStyle (p))
591
+ _output_size (p:: Plan , :: NoProjectionStyle ) = size (p)
592
+ _output_size (p:: Plan , :: RealProjectionStyle ) = rfft_output_size (size (p), region (p))
593
+ _output_size (p:: Plan , :: RealInverseProjectionStyle ) = brfft_output_size (size (p), irfft_dim (p), region (p))
594
+
595
+ mutable struct AdjointPlan{T,P} <: Plan{T}
596
+ p:: P
597
+ pinv:: Plan
598
+ AdjointPlan {T,P} (p) where {T,P} = new (p)
599
+ end
600
+
601
+ Base. adjoint (p:: Plan{T} ) where {T} = AdjointPlan {T, typeof(p)} (p)
602
+ Base. adjoint (p:: AdjointPlan{T} ) where {T} = p. p
603
+ # always have AdjointPlan inside ScaledPlan.
604
+ Base. adjoint (p:: ScaledPlan{T} ) where {T} = ScaledPlan {T} (p. p' , p. scale)
605
+
606
+ size (p:: AdjointPlan ) = output_size (p. p)
607
+ output_size (p:: AdjointPlan ) = size (p. p)
608
+
609
+ Base.:* (p:: AdjointPlan , x:: AbstractArray ) = _mul (p, x, ProjectionStyle (p. p))
610
+
611
+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: NoProjectionStyle ) where {T}
612
+ dims = region (p. p)
613
+ N = normalization (T, size (p. p), dims)
614
+ return 1 / N * (p. p \ x)
615
+ end
616
+
617
+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: RealProjectionStyle ) where {T}
618
+ dims = region (p. p)
619
+ N = normalization (T, size (p. p), dims)
620
+ halfdim = first (dims)
621
+ d = size (p. p, halfdim)
622
+ n = output_size (p. p, halfdim)
623
+ scale = reshape (
624
+ [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
625
+ ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
626
+ )
627
+ return 1 / N * (p. p \ (x ./ scale))
628
+ end
629
+
630
+ function _mul (p:: AdjointPlan{T} , x:: AbstractArray , :: RealInverseProjectionStyle ) where {T}
631
+ dims = region (p. p)
632
+ N = normalization (real (T), output_size (p. p), dims)
633
+ halfdim = first (dims)
634
+ n = size (p. p, halfdim)
635
+ d = output_size (p. p, halfdim)
636
+ scale = reshape (
637
+ [(i == 1 || (i == n && 2 * (i - 1 )) == d) ? 1 : 2 for i in 1 : n],
638
+ ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x)))
639
+ )
640
+ return 1 / N * scale .* (p. p \ x)
641
+ end
642
+
643
+ plan_inv (p:: AdjointPlan ) = AdjointPlan (plan_inv (p. p))
0 commit comments