|
3 | 3 | using AbstractFFTs
|
4 | 4 | using AbstractFFTs: Plan
|
5 | 5 | using ChainRulesTestUtils
|
| 6 | +using ChainRulesCore: NoTangent |
6 | 7 |
|
7 | 8 | using LinearAlgebra
|
8 | 9 | using Random
|
|
197 | 198 | @test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
|
198 | 199 | end
|
199 | 200 |
|
| 201 | +@testset "adjoint" begin |
| 202 | + @testset "complex fft adjoint" begin |
| 203 | + for x in (randn(3), randn(3, 4), randn(3, 4, 5)) |
| 204 | + N = ndims(x) |
| 205 | + y = randn(size(x)) |
| 206 | + for dims in unique((1, 1:N, N)) |
| 207 | + P = plan_fft(x, dims) |
| 208 | + @test dot(y, P * x) ≈ dot(P' * y, x) |
| 209 | + @test_broken dot(y, P \ x) ≈ dot(P' \ y, x) |
| 210 | + Pinv = plan_ifft(x) |
| 211 | + @test dot(x, Pinv * y) ≈ dot(Pinv' * x, y) |
| 212 | + @test_broken dot(x, Pinv \ y) ≈ dot(Pinv' \ x, y) |
| 213 | + end |
| 214 | + end |
| 215 | + end |
| 216 | + @testset "real fft adjoint" begin |
| 217 | + for x in (randn(3), randn(4), randn(3, 4), randn(3, 4, 5)) # test odd and even lengths |
| 218 | + N = ndims(x) |
| 219 | + for dims in unique((1, 1:N, N)) |
| 220 | + P = plan_rfft(similar(x), dims) |
| 221 | + y_real = randn(size(P * x)) |
| 222 | + y_imag = randn(size(P * x)) |
| 223 | + y = y_real .+ y_imag .* im |
| 224 | + @test dot(y_real, real.(P * x)) + dot(y_imag, imag.(P * x)) ≈ dot(P' * y, x) |
| 225 | + @test_broken dot(y_real, real.(P \ x)) + dot(y_imag, imag.(P \ x)) ≈ dot(P' * y, x) |
| 226 | + Pinv = plan_irfft(similar(y), size(x)[first(dims)], dims) |
| 227 | + @test dot(x, Pinv * y) ≈ dot(y_real, real.(Pinv' * x)) + dot(y_imag, imag.(Pinv' * x)) |
| 228 | + @test_broken dot(x, Pinv \ y) ≈ dot(y_real, real.(Pinv' \ x)) + dot(y_imag, imag.(Pinv' \ x)) |
| 229 | + end |
| 230 | + end |
| 231 | + end |
| 232 | +end |
| 233 | + |
200 | 234 | @testset "ChainRules" begin
|
201 | 235 | @testset "shift functions" begin
|
202 | 236 | for x in (randn(3), randn(3, 4), randn(3, 4, 5))
|
|
218 | 252 | end
|
219 | 253 |
|
220 | 254 | @testset "fft" begin
|
221 |
| - for x in (randn(3), randn(3, 4), randn(3, 4, 5)) |
| 255 | + for x in (randn(2), randn(2, 3), randn(3, 4, 5)) |
222 | 256 | N = ndims(x)
|
223 | 257 | complex_x = complex.(x)
|
224 | 258 | for dims in unique((1, 1:N, N))
|
|
229 | 263 | test_rrule(f, complex_x, dims)
|
230 | 264 | end
|
231 | 265 |
|
232 |
| - test_frule(rfft, x, dims) |
233 |
| - test_rrule(rfft, x, dims) |
| 266 | + for pf in (plan_fft, plan_ifft, plan_bfft) |
| 267 | + test_frule(*, pf(x, dims) ⊢ NoTangent(), x) |
| 268 | + test_rrule(*, pf(x, dims) ⊢ NoTangent(), x) |
| 269 | + test_frule(*, pf(complex_x, dims) ⊢ NoTangent(), complex_x) |
| 270 | + test_rrule(*, pf(complex_x, dims) ⊢ NoTangent(), complex_x) |
| 271 | + end |
234 | 272 |
|
235 | 273 | for f in (irfft, brfft)
|
236 | 274 | for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2)
|
|
240 | 278 | test_rrule(f, complex_x, d, dims)
|
241 | 279 | end
|
242 | 280 | end
|
| 281 | + |
| 282 | + for pf in (plan_irfft, plan_brfft) |
| 283 | + for d in (2 * size(x, first(dims)) - 1, 2 * size(x, first(dims)) - 2) |
| 284 | + test_frule(*, pf(complex_x, d, dims) ⊢ NoTangent(), complex_x) |
| 285 | + test_rrule(*, pf(complex_x, d, dims) ⊢ NoTangent(), complex_x) |
| 286 | + end |
| 287 | + end |
| 288 | + |
243 | 289 | end
|
244 | 290 | end
|
245 | 291 | end
|
|
0 commit comments