Skip to content

Commit 377530c

Browse files
committed
Move TestUtils to a submodule
1 parent 45e0bc1 commit 377530c

File tree

5 files changed

+57
-53
lines changed

5 files changed

+57
-53
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ version = "1.2.1"
55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9+
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
810

911
[compat]
1012
ChainRulesCore = "1"

src/AbstractFFTs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ export fft, ifft, bfft, fft!, ifft!, bfft!,
99

1010
include("definitions.jl")
1111
include("chainrules.jl")
12+
include("TestUtils.jl")
1213

1314
end # module

test/AbstractFFTsTestUtils.jl renamed to src/TestUtils.jl

Lines changed: 1 addition & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license
22

3-
module AbstractFFTsTestUtils
3+
module TestUtils
44

55
export test_fft_backend
66

77
using AbstractFFTs
88
using AbstractFFTs: Plan
9-
using ChainRulesTestUtils
109

1110
using LinearAlgebra
1211
using Test
@@ -212,54 +211,6 @@ function test_fft_backend()
212211
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, fftdims(p))
213212
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
214213
end
215-
216-
@testset "ChainRules" begin
217-
@testset "shift functions" begin
218-
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
219-
for dims in ((), 1, 2, (1,2), 1:2)
220-
any(d > ndims(x) for d in dims) && continue
221-
222-
# type inference checks of `rrule` fail on old Julia versions
223-
# for higher-dimensional arrays:
224-
# https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016
225-
check_inferred = ndims(x) < 3 || VERSION >= v"1.6"
226-
227-
test_frule(AbstractFFTs.fftshift, x, dims)
228-
test_rrule(AbstractFFTs.fftshift, x, dims; check_inferred=check_inferred)
229-
230-
test_frule(AbstractFFTs.ifftshift, x, dims)
231-
test_rrule(AbstractFFTs.ifftshift, x, dims; check_inferred=check_inferred)
232-
end
233-
end
234-
end
235-
236-
@testset "fft" begin
237-
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
238-
N = ndims(x)
239-
complex_x = complex.(x)
240-
for dims in unique((1, 1:N, N))
241-
for f in (fft, ifft, bfft)
242-
test_frule(f, x, dims)
243-
test_rrule(f, x, dims)
244-
test_frule(f, complex_x, dims)
245-
test_rrule(f, complex_x, dims)
246-
end
247-
248-
test_frule(rfft, x, dims)
249-
test_rrule(rfft, x, dims)
250-
251-
for f in (irfft, brfft)
252-
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
253-
test_frule(f, x, d, dims)
254-
test_rrule(f, x, d, dims)
255-
test_frule(f, complex_x, d, dims)
256-
test_rrule(f, complex_x, d, dims)
257-
end
258-
end
259-
end
260-
end
261-
end
262-
end
263214
end
264215

265216
end

test/chainrules.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
@testset "ChainRules" begin
2+
@testset "shift functions" begin
3+
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
4+
for dims in ((), 1, 2, (1,2), 1:2)
5+
any(d > ndims(x) for d in dims) && continue
6+
7+
# type inference checks of `rrule` fail on old Julia versions
8+
# for higher-dimensional arrays:
9+
# https://github.com/JuliaMath/AbstractFFTs.jl/pull/58#issuecomment-916530016
10+
check_inferred = ndims(x) < 3 || VERSION >= v"1.6"
11+
12+
test_frule(AbstractFFTs.fftshift, x, dims)
13+
test_rrule(AbstractFFTs.fftshift, x, dims; check_inferred=check_inferred)
14+
15+
test_frule(AbstractFFTs.ifftshift, x, dims)
16+
test_rrule(AbstractFFTs.ifftshift, x, dims; check_inferred=check_inferred)
17+
end
18+
end
19+
end
20+
21+
@testset "fft" begin
22+
for x in (randn(3), randn(3, 4), randn(3, 4, 5))
23+
N = ndims(x)
24+
complex_x = complex.(x)
25+
for dims in unique((1, 1:N, N))
26+
for f in (fft, ifft, bfft)
27+
test_frule(f, x, dims)
28+
test_rrule(f, x, dims)
29+
test_frule(f, complex_x, dims)
30+
test_rrule(f, complex_x, dims)
31+
end
32+
33+
test_frule(rfft, x, dims)
34+
test_rrule(rfft, x, dims)
35+
36+
for f in (irfft, brfft)
37+
for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
38+
test_frule(f, x, d, dims)
39+
test_rrule(f, x, d, dims)
40+
test_frule(f, complex_x, d, dims)
41+
test_rrule(f, complex_x, d, dims)
42+
end
43+
end
44+
end
45+
end
46+
end
47+
end
48+

test/runtests.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
using Random
2+
using Test
3+
using AbstractFFTs
4+
using AbstractFFTs.TestUtils
5+
using ChainRulesTestUtils
26

37
Random.seed!(1234)
48

59
include("TestPlans.jl")
6-
include("AbstractFFTsTestUtils.jl")
710

811
using .TestPlans
9-
using .AbstractFFTsTestUtils
10-
1112
test_fft_backend()
1213

14+
include("chainrules.jl")

0 commit comments

Comments
 (0)