Skip to content

Commit 0ff965a

Browse files
committed
Add tests
1 parent 7e8f4e0 commit 0ff965a

File tree

1 file changed

+78
-1
lines changed

1 file changed

+78
-1
lines changed

test/runtests.jl

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,81 @@
1+
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license
2+
13
using AbstractFFTs
24
using Base.Test
35

4-
# TODO
6+
import AbstractFFTs: Plan, plan_fft, plan_inv, plan_bfft
7+
import Base: A_mul_B!, *
8+
9+
mutable struct TestPlan{T} <: Plan{T}
10+
region
11+
pinv::Plan{T}
12+
TestPlan{T}(region) where {T} = new{T}(region)
13+
end
14+
15+
mutable struct InverseTestPlan{T} <: Plan{T}
16+
region
17+
pinv::Plan{T}
18+
InverseTestPlan{T}(region) where {T} = new{T}(region)
19+
end
20+
21+
AbstractFFTs.plan_fft(x::Vector{T}, region; kwargs...) where {T} = TestPlan{T}(region)
22+
AbstractFFTs.plan_bfft(x::Vector{T}, region; kwargs...) where {T} = InverseTestPlan{T}(region)
23+
AbstractFFTs.plan_inv(p::TestPlan{T}) where {T} = InverseTestPlan{T}
24+
25+
# Just a helper function since forward and backward are nearly identical
26+
function dft!(y::Vector, x::Vector, sign::Int)
27+
n = length(x)
28+
length(y) == n || throw(DimensionMismatch())
29+
fill!(y, zero(complex(float(eltype(x)))))
30+
c = sign * 2π / n
31+
@inbounds for j = 0:n-1, k = 0:n-1
32+
y[k+1] += x[j+1] * cis(c*j*k)
33+
end
34+
return y
35+
end
36+
37+
Base.A_mul_B!(y::Vector, p::TestPlan, x::Vector) = dft!(y, x, -1)
38+
Base.A_mul_B!(y::Vector, p::InverseTestPlan, x::Vector) = dft!(y, x, 1)
39+
40+
Base.:*(p::TestPlan, x::Vector) = A_mul_B!(copy(x), p, x)
41+
Base.:*(p::InverseTestPlan, x::Vector) = A_mul_B!(copy(x), p, x)
42+
43+
@testset "Custom Plan" begin
44+
x = AbstractFFTs.fft(collect(1:8))
45+
# Result computed using FFTW
46+
fftw_fft = [36.0 + 0.0im,
47+
-4.0 + 9.65685424949238im,
48+
-4.0 + 4.0im,
49+
-4.0 + 1.6568542494923806im,
50+
-4.0 + 0.0im,
51+
-4.0 - 1.6568542494923806im,
52+
-4.0 - 4.0im,
53+
-4.0 - 9.65685424949238im]
54+
@test x fftw_fft
55+
56+
fftw_bfft = [Complex{Float64}(8i, 0) for i in 1:8]
57+
@test AbstractFFTs.bfft(x) fftw_bfft
58+
59+
fftw_ifft = [Complex{Float64}(i, 0) for i in 1:8]
60+
@test AbstractFFTs.ifft(x) fftw_ifft
61+
end
62+
63+
@testset "Shift functions" begin
64+
@test AbstractFFTs.fftshift([1 2 3]) == [3 1 2]
65+
@test AbstractFFTs.fftshift([1, 2, 3]) == [3, 1, 2]
66+
@test AbstractFFTs.fftshift([1 2 3; 4 5 6]) == [6 4 5; 3 1 2]
67+
68+
@test AbstractFFTs.fftshift([1 2 3; 4 5 6], 1) == [4 5 6; 1 2 3]
69+
@test AbstractFFTs.fftshift([1 2 3; 4 5 6], ()) == [1 2 3; 4 5 6]
70+
@test AbstractFFTs.fftshift([1 2 3; 4 5 6], (1,2)) == [6 4 5; 3 1 2]
71+
@test AbstractFFTs.fftshift([1 2 3; 4 5 6], 1:2) == [6 4 5; 3 1 2]
72+
73+
@test AbstractFFTs.ifftshift([1 2 3]) == [2 3 1]
74+
@test AbstractFFTs.ifftshift([1, 2, 3]) == [2, 3, 1]
75+
@test AbstractFFTs.ifftshift([1 2 3; 4 5 6]) == [5 6 4; 2 3 1]
76+
77+
@test AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1) == [4 5 6; 1 2 3]
78+
@test AbstractFFTs.ifftshift([1 2 3; 4 5 6], ()) == [1 2 3; 4 5 6]
79+
@test AbstractFFTs.ifftshift([1 2 3; 4 5 6], (1,2)) == [5 6 4; 2 3 1]
80+
@test AbstractFFTs.ifftshift([1 2 3; 4 5 6], 1:2) == [5 6 4; 2 3 1]
81+
end

0 commit comments

Comments
 (0)