Skip to content

Commit 53c67ca

Browse files
authored
Add region(::Plan) for accessing transformed region
1 parent 4f630d6 commit 53c67ca

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AbstractFFTs"
22
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3-
version = "1.1.0"
3+
version = "1.2.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/definitions.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ size(p::Plan, d) = size(p)[d]
1515
ndims(p::Plan) = length(size(p))
1616
length(p::Plan) = prod(size(p))::Int
1717

18+
"""
19+
region(p::Plan)
20+
21+
Return an iterable of the dimensions that are transformed by the FFT plan `p`.
22+
23+
# Implementation
24+
25+
The default definition of `region` returns `p.region`.
26+
Hence this method should be implemented only for types of `Plan`s that do not store the transformed region in a field of name `region`.
27+
"""
28+
region(p::Plan) = p.region
29+
1830
fftfloat(x) = _fftfloat(float(x))
1931
_fftfloat(::Type{T}) where {T<:BlasReal} = T
2032
_fftfloat(::Type{Float16}) = Float32
@@ -243,6 +255,8 @@ ScaledPlan(p::ScaledPlan, α::Number) = ScaledPlan(p.p, p.scale * α)
243255

244256
size(p::ScaledPlan) = size(p.p)
245257

258+
region(p::ScaledPlan) = region(p.p)
259+
246260
show(io::IO, p::ScaledPlan) = print(io, p.scale, " * ", p.p)
247261
summary(p::ScaledPlan) = string(p.scale, " * ", summary(p.p))
248262

test/runtests.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,21 @@ end
6060
@test eltype(P) === ComplexF64
6161
@test P * x fftw_fft
6262
@test P \ (P * x) x
63+
@test AbstractFFTs.region(P) == dims
6364

6465
fftw_bfft = complex.(size(x, dims) .* x)
6566
@test AbstractFFTs.bfft(y, dims) fftw_bfft
6667
P = plan_bfft(x, dims)
6768
@test P * y fftw_bfft
6869
@test P \ (P * y) y
70+
@test AbstractFFTs.region(P) == dims
6971

7072
fftw_ifft = complex.(x)
7173
@test AbstractFFTs.ifft(y, dims) fftw_ifft
7274
P = plan_ifft(x, dims)
7375
@test P * y fftw_ifft
7476
@test P \ (P * y) y
77+
@test AbstractFFTs.region(P) == dims
7578

7679
# real FFT
7780
fftw_rfft = fftw_fft[
@@ -84,18 +87,21 @@ end
8487
@test eltype(P) === Int
8588
@test P * x fftw_rfft
8689
@test P \ (P * x) x
90+
@test AbstractFFTs.region(P) == dims
8791

8892
fftw_brfft = complex.(size(x, dims) .* x)
8993
@test AbstractFFTs.brfft(ry, size(x, dims), dims) fftw_brfft
9094
P = plan_brfft(ry, size(x, dims), dims)
9195
@test P * ry fftw_brfft
9296
@test P \ (P * ry) ry
97+
@test AbstractFFTs.region(P) == dims
9398

9499
fftw_irfft = complex.(x)
95100
@test AbstractFFTs.irfft(ry, size(x, dims), dims) fftw_irfft
96101
P = plan_irfft(ry, size(x, dims), dims)
97102
@test P * ry fftw_irfft
98103
@test P \ (P * ry) ry
104+
@test AbstractFFTs.region(P) == dims
99105
end
100106
end
101107

@@ -170,7 +176,7 @@ end
170176
# normalization should be inferable even if region is only inferred as ::Any,
171177
# need to wrap in another function to test this (note that p.region::Any for
172178
# p::TestPlan)
173-
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, p.region)
179+
f9(p::Plan{T}, sz) where {T} = AbstractFFTs.normalization(real(T), sz, region(p))
174180
@test @inferred(f9(plan_fft(zeros(10), 1), 10)) == 1/10
175181
end
176182

0 commit comments

Comments
 (0)