Skip to content

Commit 22dc932

Browse files
committed
Fixes to test suite to support CUDA arrays
1 parent fae1170 commit 22dc932

File tree

2 files changed

+39
-31
lines changed

2 files changed

+39
-31
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.5.0"
55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
89
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
910

1011
[weakdeps]

ext/AbstractFFTsTestExt.jl

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using AbstractFFTs
66
using AbstractFFTs: TestUtils
77
using AbstractFFTs.LinearAlgebra
88
using Test
9+
import Random
910

1011
# Ground truth x_fft computed using FFTW library
1112
const TEST_CASES = (
@@ -52,15 +53,18 @@ const TEST_CASES = (
5253
)
5354

5455

55-
function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false)
56+
function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray;
57+
inplace_plan=false, copy_input=false, test_wrappers=true)
5658
_copy = copy_input ? copy : identity
5759
if !inplace_plan
5860
@test P * _copy(x) x_transformed
5961
@test P \ (P * _copy(x)) x
6062
_x_out = similar(P * _copy(x))
6163
@test mul!(_x_out, P, _copy(x)) x_transformed
6264
@test _x_out x_transformed
63-
@test P * view(_copy(x), axes(x)...) x_transformed # test view input
65+
if test_wrappers
66+
@test P * view(_copy(x), axes(x)...) x_transformed # test view input
67+
end
6468
else
6569
_x = copy(x)
6670
@test P * _copy(_x) x_transformed
@@ -70,9 +74,10 @@ function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transform
7074
end
7175
end
7276

73-
function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false)
77+
function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray;
78+
real_plan=false, copy_input=false, test_wrappers=true)
7479
_copy = copy_input ? copy : identity
75-
y = rand(eltype(P * _copy(x)), size(P * _copy(x)))
80+
y = Random.rand!(P * _copy(x))
7681
# test basic properties
7782
@test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110)
7883
@test (P')' === P # test adjoint of adjoint
@@ -86,11 +91,13 @@ function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; rea
8691
@test _component_dot(y, P * _copy(x)) _component_dot(P' * _copy(y), x)
8792
@test _component_dot(x, P \ _copy(y)) _component_dot(P' \ _copy(x), y)
8893
end
89-
@test P' * view(_copy(y), axes(y)...) P' * _copy(y) # test view input (AbstractFFTs.jl#112)
94+
if test_wrappers
95+
@test P' * view(_copy(y), axes(y)...) P' * _copy(y) # test view input (AbstractFFTs.jl#112)
96+
end
9097
@test_throws MethodError mul!(x, P', y)
9198
end
9299

93-
function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
100+
function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true, test_wrappers=true)
94101
@testset "correctness of fft, bfft, ifft" begin
95102
for test_case in TEST_CASES
96103
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
@@ -110,18 +117,18 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
110117
for P in (plan_fft(similar(x_complexf), dims),
111118
(_inv(plan_ifft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
112119
@test eltype(P) <: Complex
113-
@test fftdims(P) == dims
114-
TestUtils.test_plan(P, x_complexf, x_fft)
120+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
121+
TestUtils.test_plan(P, x_complexf, x_fft; test_wrappers=test_wrappers)
115122
if test_adjoint
116123
@test fftdims(P') == fftdims(P)
117-
TestUtils.test_plan_adjoint(P, x_complexf)
124+
TestUtils.test_plan_adjoint(P, x_complexf, test_wrappers=test_wrappers)
118125
end
119126
end
120127
if test_inplace
121128
# test IIP plans
122129
for P in (plan_fft!(similar(x_complexf), dims),
123130
(_inv(plan_ifft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
124-
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true)
131+
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true, test_wrappers=test_wrappers)
125132
end
126133
end
127134

@@ -136,17 +143,17 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
136143
# test OOP plans. Just 1 plan to test, but we use a for loop for consistent style
137144
for P in (plan_bfft(similar(x_fft), dims),)
138145
@test eltype(P) <: Complex
139-
@test fftdims(P) == dims
140-
TestUtils.test_plan(P, x_fft, x_scaled)
146+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
147+
TestUtils.test_plan(P, x_fft, x_scaled; test_wrappers=test_wrappers)
141148
if test_adjoint
142-
TestUtils.test_plan_adjoint(P, x_fft)
149+
TestUtils.test_plan_adjoint(P, x_fft, test_wrappers=test_wrappers)
143150
end
144151
end
145152
# test IIP plans
146153
for P in (plan_bfft!(similar(x_fft), dims),)
147154
@test eltype(P) <: Complex
148-
@test fftdims(P) == dims
149-
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true)
155+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
156+
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true, test_wrappers=test_wrappers)
150157
end
151158

152159
# IFFT
@@ -160,33 +167,33 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
160167
for P in (plan_ifft(similar(x_complexf), dims),
161168
(_inv(plan_fft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
162169
@test eltype(P) <: Complex
163-
@test fftdims(P) == dims
164-
TestUtils.test_plan(P, x_fft, x)
170+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
171+
TestUtils.test_plan(P, x_fft, x; test_wrappers=test_wrappers)
165172
if test_adjoint
166-
TestUtils.test_plan_adjoint(P, x_fft)
173+
TestUtils.test_plan_adjoint(P, x_fft; test_wrappers=test_wrappers)
167174
end
168175
end
169176
# test IIP plans
170177
if test_inplace
171178
for P in (plan_ifft!(similar(x_complexf), dims),
172179
(_inv(plan_fft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
173180
@test eltype(P) <: Complex
174-
@test fftdims(P) == dims
175-
TestUtils.test_plan(P, x_fft, x; inplace_plan=true)
181+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
182+
TestUtils.test_plan(P, x_fft, x; inplace_plan=true, test_wrappers=test_wrappers)
176183
end
177184
end
178185
end
179186
end
180187
end
181188

182-
function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
189+
function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false, test_wrappers=true)
183190
@testset "correctness of rfft, brfft, irfft" begin
184191
for test_case in TEST_CASES
185192
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
186193
x = convert(ArrayType, _x) # dummy array that will be passed to plans
187194
x_real = float.(x) # for testing mutating real FFTs
188195
x_fft = convert(ArrayType, _x_fft)
189-
x_rfft = collect(selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1)))
196+
x_rfft = convert(ArrayType, collect(selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1))))
190197

191198
if !(eltype(x) <: Real)
192199
continue
@@ -197,10 +204,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
197204
for P in (plan_rfft(similar(x_real), dims),
198205
(_inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
199206
@test eltype(P) <: Real
200-
@test fftdims(P) == dims
201-
TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input)
207+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
208+
TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input, test_wrappers=test_wrappers)
202209
if test_adjoint
203-
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input)
210+
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input, test_wrappers=test_wrappers)
204211
end
205212
end
206213

@@ -209,10 +216,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
209216
@test brfft(x_rfft, size(x, first(dims)), dims) x_scaled
210217
for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),)
211218
@test eltype(P) <: Complex
212-
@test fftdims(P) == dims
213-
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input)
219+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
220+
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input, test_wrappers=test_wrappers)
214221
if test_adjoint
215-
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
222+
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input, test_wrappers=test_wrappers)
216223
end
217224
end
218225

@@ -221,10 +228,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
221228
for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims),
222229
(_inv(plan_rfft(similar(x_real), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
223230
@test eltype(P) <: Complex
224-
@test fftdims(P) == dims
225-
TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input)
231+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
232+
TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input, test_wrappers=test_wrappers)
226233
if test_adjoint
227-
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
234+
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input, test_wrappers=test_wrappers)
228235
end
229236
end
230237
end

0 commit comments

Comments
 (0)