|
| 1 | +# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license |
| 2 | + |
1 | 3 | using AbstractFFTs
|
2 | 4 | using Base.Test
|
3 | 5 |
|
4 |
| -# TODO |
| 6 | +import AbstractFFTs: Plan, plan_fft, plan_inv, plan_bfft |
| 7 | +import Base: A_mul_B!, * |
| 8 | + |
| 9 | +mutable struct TestPlan{T} <: Plan{T} |
| 10 | + region |
| 11 | + pinv::Plan{T} |
| 12 | + TestPlan{T}(region) where {T} = new{T}(region) |
| 13 | +end |
| 14 | + |
| 15 | +mutable struct InverseTestPlan{T} <: Plan{T} |
| 16 | + region |
| 17 | + pinv::Plan{T} |
| 18 | + InverseTestPlan{T}(region) where {T} = new{T}(region) |
| 19 | +end |
| 20 | + |
| 21 | +AbstractFFTs.plan_fft(x::Vector{T}, region; kwargs...) where {T} = TestPlan{T}(region) |
| 22 | +AbstractFFTs.plan_bfft(x::Vector{T}, region; kwargs...) where {T} = InverseTestPlan{T}(region) |
| 23 | +AbstractFFTs.plan_inv(p::TestPlan{T}) where {T} = InverseTestPlan{T} |
| 24 | + |
| 25 | +# Just a helper function since forward and backward are nearly identical |
| 26 | +function dft!(y::Vector, x::Vector, sign::Int) |
| 27 | + n = length(x) |
| 28 | + length(y) == n || throw(DimensionMismatch()) |
| 29 | + fill!(y, zero(complex(float(eltype(x))))) |
| 30 | + c = sign * 2π / n |
| 31 | + @inbounds for j = 0:n-1, k = 0:n-1 |
| 32 | + y[k+1] += x[j+1] * cis(c*j*k) |
| 33 | + end |
| 34 | + return y |
| 35 | +end |
| 36 | + |
| 37 | +Base.A_mul_B!(y::Vector, p::TestPlan, x::Vector) = dft!(y, x, -1) |
| 38 | +Base.A_mul_B!(y::Vector, p::InverseTestPlan, x::Vector) = dft!(y, x, 1) |
| 39 | + |
| 40 | +Base.:*(p::TestPlan, x::Vector) = A_mul_B!(copy(x), p, x) |
| 41 | +Base.:*(p::InverseTestPlan, x::Vector) = A_mul_B!(copy(x), p, x) |
| 42 | + |
| 43 | +@testset "Custom Plan" begin |
| 44 | + x = AbstractFFTs.fft(collect(1:8)) |
| 45 | + # Result computed using FFTW |
| 46 | + fftw_fft = [36.0 + 0.0im, |
| 47 | + -4.0 + 9.65685424949238im, |
| 48 | + -4.0 + 4.0im, |
| 49 | + -4.0 + 1.6568542494923806im, |
| 50 | + -4.0 + 0.0im, |
| 51 | + -4.0 - 1.6568542494923806im, |
| 52 | + -4.0 - 4.0im, |
| 53 | + -4.0 - 9.65685424949238im] |
| 54 | + @test x ≈ fftw_fft |
| 55 | + |
| 56 | + fftw_bfft = [Complex{Float64}(8i, 0) for i in 1:8] |
| 57 | + @test AbstractFFTs.bfft(x) ≈ fftw_bfft |
| 58 | + |
| 59 | + fftw_ifft = [Complex{Float64}(i, 0) for i in 1:8] |
| 60 | + @test AbstractFFTs.ifft(x) ≈ fftw_ifft |
| 61 | +end |
| 62 | + |
| 63 | +@testset "Shift functions" begin |
| 64 | + @test AbstractFFTs.fftshift([1 2 3]) == [3 1 2] |
| 65 | + @test AbstractFFTs.fftshift([1, 2, 3]) == [3, 1, 2] |
| 66 | + @test AbstractFFTs.fftshift([1 2 3; 4 5 6]) == [6 4 5; 3 1 2] |
| 67 | + |
| 68 | + @test AbstractFFTs.fftshift([1 2 3; 4 5 6], 1) == [4 5 6; 1 2 3] |
| 69 | + @test AbstractFFTs.fftshift([1 2 3; 4 5 6], ()) == [1 2 3; 4 5 6] |
| 70 | + @test AbstractFFTs.fftshift([1 2 3; 4 5 6], (1,2)) == [6 4 5; 3 1 2] |
| 71 | + @test AbstractFFTs.fftshift([1 2 3; 4 5 6], 1:2) == [6 4 5; 3 1 2] |
| 72 | + |
| 73 | + @test AbstractFFTs.ifftshift([1 2 3]) == [2 3 1] |
| 74 | + @test AbstractFFTs.ifftshift([1, 2, 3]) == [2, 3, 1] |
| 75 | + @test AbstractFFTs.ifftshift([1 2 3; 4 5 6]) == [5 6 4; 2 3 1] |
| 76 | + |
| 77 | + @test AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1) == [4 5 6; 1 2 3] |
| 78 | + @test AbstractFFTs.ifftshift([1 2 3; 4 5 6], ()) == [1 2 3; 4 5 6] |
| 79 | + @test AbstractFFTs.ifftshift([1 2 3; 4 5 6], (1,2)) == [5 6 4; 2 3 1] |
| 80 | + @test AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1:2) == [5 6 4; 2 3 1] |
| 81 | +end |
0 commit comments