Skip to content

Commit 00736a8

Browse files
committed
Fixes to test suite to support CUDA arrays
1 parent 511d56b commit 00736a8

File tree

2 files changed

+39
-31
lines changed

2 files changed

+39
-31
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "1.5.0"
55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
89
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
910

1011
[weakdeps]

ext/AbstractFFTsTestExt.jl

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using AbstractFFTs
66
using AbstractFFTs: TestUtils
77
using AbstractFFTs.LinearAlgebra
88
using Test
9+
import Random
910

1011
# Ground truth x_fft computed using FFTW library
1112
const TEST_CASES = (
@@ -52,7 +53,8 @@ const TEST_CASES = (
5253
)
5354

5455

55-
function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false)
56+
function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray;
57+
inplace_plan=false, copy_input=false, test_wrappers=true)
5658
_copy = copy_input ? copy : identity
5759
@test size(P) == size(x)
5860
if !inplace_plan
@@ -61,7 +63,9 @@ function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transform
6163
_x_out = similar(P * _copy(x))
6264
@test mul!(_x_out, P, _copy(x)) x_transformed
6365
@test _x_out x_transformed
64-
@test P * view(_copy(x), axes(x)...) x_transformed # test view input
66+
if test_wrappers
67+
@test P * view(_copy(x), axes(x)...) x_transformed # test view input
68+
end
6569
else
6670
_x = copy(x)
6771
@test P * _copy(_x) x_transformed
@@ -71,9 +75,10 @@ function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transform
7175
end
7276
end
7377

74-
function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false)
78+
function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray;
79+
real_plan=false, copy_input=false, test_wrappers=true)
7580
_copy = copy_input ? copy : identity
76-
y = rand(eltype(P * _copy(x)), size(P * _copy(x)))
81+
y = Random.rand!(P * _copy(x))
7782
# test basic properties
7883
@test eltype(P') === eltype(y)
7984
@test (P')' === P # test adjoint of adjoint
@@ -87,11 +92,13 @@ function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; rea
8792
@test _component_dot(y, P * _copy(x)) _component_dot(P' * _copy(y), x)
8893
@test _component_dot(x, P \ _copy(y)) _component_dot(P' \ _copy(x), y)
8994
end
90-
@test P' * view(_copy(y), axes(y)...) P' * _copy(y) # test view input (AbstractFFTs.jl#112)
95+
if test_wrappers
96+
@test P' * view(_copy(y), axes(y)...) P' * _copy(y) # test view input (AbstractFFTs.jl#112)
97+
end
9198
@test_throws MethodError mul!(x, P', y)
9299
end
93100

94-
function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
101+
function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true, test_wrappers=true)
95102
@testset "correctness of fft, bfft, ifft" begin
96103
for test_case in TEST_CASES
97104
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
@@ -111,18 +118,18 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
111118
for P in (plan_fft(similar(x_complexf), dims),
112119
(_inv(plan_ifft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
113120
@test eltype(P) <: Complex
114-
@test fftdims(P) == dims
115-
TestUtils.test_plan(P, x_complexf, x_fft)
121+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
122+
TestUtils.test_plan(P, x_complexf, x_fft; test_wrappers=test_wrappers)
116123
if test_adjoint
117124
@test fftdims(P') == fftdims(P)
118-
TestUtils.test_plan_adjoint(P, x_complexf)
125+
TestUtils.test_plan_adjoint(P, x_complexf, test_wrappers=test_wrappers)
119126
end
120127
end
121128
if test_inplace
122129
# test IIP plans
123130
for P in (plan_fft!(similar(x_complexf), dims),
124131
(_inv(plan_ifft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
125-
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true)
132+
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true, test_wrappers=test_wrappers)
126133
end
127134
end
128135

@@ -137,17 +144,17 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
137144
# test OOP plans. Just 1 plan to test, but we use a for loop for consistent style
138145
for P in (plan_bfft(similar(x_fft), dims),)
139146
@test eltype(P) <: Complex
140-
@test fftdims(P) == dims
141-
TestUtils.test_plan(P, x_fft, x_scaled)
147+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
148+
TestUtils.test_plan(P, x_fft, x_scaled; test_wrappers=test_wrappers)
142149
if test_adjoint
143-
TestUtils.test_plan_adjoint(P, x_fft)
150+
TestUtils.test_plan_adjoint(P, x_fft, test_wrappers=test_wrappers)
144151
end
145152
end
146153
# test IIP plans
147154
for P in (plan_bfft!(similar(x_fft), dims),)
148155
@test eltype(P) <: Complex
149-
@test fftdims(P) == dims
150-
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true)
156+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
157+
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true, test_wrappers=test_wrappers)
151158
end
152159

153160
# IFFT
@@ -161,33 +168,33 @@ function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_ad
161168
for P in (plan_ifft(similar(x_complexf), dims),
162169
(_inv(plan_fft(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
163170
@test eltype(P) <: Complex
164-
@test fftdims(P) == dims
165-
TestUtils.test_plan(P, x_fft, x)
171+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
172+
TestUtils.test_plan(P, x_fft, x; test_wrappers=test_wrappers)
166173
if test_adjoint
167-
TestUtils.test_plan_adjoint(P, x_fft)
174+
TestUtils.test_plan_adjoint(P, x_fft; test_wrappers=test_wrappers)
168175
end
169176
end
170177
# test IIP plans
171178
if test_inplace
172179
for P in (plan_ifft!(similar(x_complexf), dims),
173180
(_inv(plan_fft!(similar(x_complexf), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
174181
@test eltype(P) <: Complex
175-
@test fftdims(P) == dims
176-
TestUtils.test_plan(P, x_fft, x; inplace_plan=true)
182+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
183+
TestUtils.test_plan(P, x_fft, x; inplace_plan=true, test_wrappers=test_wrappers)
177184
end
178185
end
179186
end
180187
end
181188
end
182189

183-
function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
190+
function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false, test_wrappers=true)
184191
@testset "correctness of rfft, brfft, irfft" begin
185192
for test_case in TEST_CASES
186193
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
187194
x = convert(ArrayType, _x) # dummy array that will be passed to plans
188195
x_real = float.(x) # for testing mutating real FFTs
189196
x_fft = convert(ArrayType, _x_fft)
190-
x_rfft = collect(selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1)))
197+
x_rfft = convert(ArrayType, collect(selectdim(x_fft, first(dims), 1:(size(x_fft, first(dims)) ÷ 2 + 1))))
191198

192199
if !(eltype(x) <: Real)
193200
continue
@@ -198,10 +205,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
198205
for P in (plan_rfft(similar(x_real), dims),
199206
(_inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
200207
@test eltype(P) <: Real
201-
@test fftdims(P) == dims
202-
TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input)
208+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
209+
TestUtils.test_plan(P, x_real, x_rfft; copy_input=copy_input, test_wrappers=test_wrappers)
203210
if test_adjoint
204-
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input)
211+
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input=copy_input, test_wrappers=test_wrappers)
205212
end
206213
end
207214

@@ -210,10 +217,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
210217
@test brfft(x_rfft, size(x, first(dims)), dims) x_scaled
211218
for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),)
212219
@test eltype(P) <: Complex
213-
@test fftdims(P) == dims
214-
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input)
220+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
221+
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input=copy_input, test_wrappers=test_wrappers)
215222
if test_adjoint
216-
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
223+
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input, test_wrappers=test_wrappers)
217224
end
218225
end
219226

@@ -222,10 +229,10 @@ function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input
222229
for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims),
223230
(_inv(plan_rfft(similar(x_real), dims)) for _inv in (inv, AbstractFFTs.plan_inv))...)
224231
@test eltype(P) <: Complex
225-
@test fftdims(P) == dims
226-
TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input)
232+
@test collect(fftdims(P))[:] == collect(dims)[:] # compare as iterables
233+
TestUtils.test_plan(P, x_rfft, x; copy_input=copy_input, test_wrappers=test_wrappers)
227234
if test_adjoint
228-
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input)
235+
TestUtils.test_plan_adjoint(P, x_rfft; real_plan=true, copy_input=copy_input, test_wrappers=test_wrappers)
229236
end
230237
end
231238
end

0 commit comments

Comments
 (0)