Skip to content

Commit fae1170

Browse files
Remove arrays of scaling factors (#116)
* Remove arrays of scaling factors * Fix typo * Test view inputs in interface tests (#117) --------- Co-authored-by: Gaurav Arya <gauravarya272@gmail.com>
1 parent ee9f1b8 commit fae1170

File tree

3 files changed

+53
-27
lines changed

3 files changed

+53
-27
lines changed

ext/AbstractFFTsChainRulesCoreExt.jl

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,20 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
3030
halfdim = first(dims)
3131
d = size(x, halfdim)
3232
n = size(y, halfdim)
33-
scale = reshape(
34-
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
35-
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
36-
)
3733

3834
project_x = ChainRulesCore.ProjectTo(x)
3935
function rfft_pullback(ȳ)
4036
ybar = ChainRulesCore.unthunk(ȳ)
41-
_scale = convert(typeof(ybar),scale)
42-
= project_x(brfft(ybar ./ _scale, d, dims))
37+
ybar_scaled = map(ybar, CartesianIndices(ybar)) do ybar_j, j
38+
i = j[halfdim]
39+
ybar_scaled_j = if i == 1 || (i == n && 2 * (i - 1) == d)
40+
ybar_j
41+
else
42+
ybar_j / 2
43+
end
44+
return ybar_scaled_j
45+
end
46+
= project_x(brfft(ybar_scaled, d, dims))
4347
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
4448
end
4549
return y, rfft_pullback
@@ -74,16 +78,20 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
7478
n = size(x, halfdim)
7579
invN = AbstractFFTs.normalization(y, dims)
7680
twoinvN = 2 * invN
77-
scale = reshape(
78-
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
79-
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
80-
)
8181

8282
project_x = ChainRulesCore.ProjectTo(x)
8383
function irfft_pullback(ȳ)
8484
ybar = ChainRulesCore.unthunk(ȳ)
85-
_scale = convert(typeof(ybar),scale)
86-
= project_x(_scale .* rfft(real.(ybar), dims))
85+
x̄_scaled = rfft(real.(ybar), dims)
86+
= project_x(map(x̄_scaled, CartesianIndices(x̄_scaled)) do x̄_scaled_j, j
87+
i = j[halfdim]
88+
x̄_j = if i == 1 || (i == n && 2 * (i - 1) == d)
89+
invN * x̄_scaled_j
90+
else
91+
twoinvN * x̄_scaled_j
92+
end
93+
return x̄_j
94+
end)
8795
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
8896
end
8997
return y, irfft_pullback
@@ -115,14 +123,19 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
115123
# compute scaling factors
116124
halfdim = first(dims)
117125
n = size(x, halfdim)
118-
scale = reshape(
119-
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
120-
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
121-
)
122126

123127
project_x = ChainRulesCore.ProjectTo(x)
124128
function brfft_pullback(ȳ)
125-
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
129+
x̄_scaled = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)
130+
= project_x(map(x̄_scaled, CartesianIndices(x̄_scaled)) do x̄_scaled_j, j
131+
i = j[halfdim]
132+
x̄_j = if i == 1 || (i == n && 2 * (i - 1) == d)
133+
x̄_scaled_j
134+
else
135+
2 * x̄_scaled_j
136+
end
137+
return x̄_j
138+
end)
126139
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
127140
end
128141
return y, brfft_pullback

ext/AbstractFFTsTestExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transform
6060
_x_out = similar(P * _copy(x))
6161
@test mul!(_x_out, P, _copy(x)) x_transformed
6262
@test _x_out x_transformed
63+
@test P * view(_copy(x), axes(x)...) x_transformed # test view input
6364
else
6465
_x = copy(x)
6566
@test P * _copy(_x) x_transformed
@@ -85,6 +86,7 @@ function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; rea
8586
@test _component_dot(y, P * _copy(x)) _component_dot(P' * _copy(y), x)
8687
@test _component_dot(x, P \ _copy(y)) _component_dot(P' \ _copy(x), y)
8788
end
89+
@test P' * view(_copy(y), axes(y)...) P' * _copy(y) # test view input (AbstractFFTs.jl#112)
8890
@test_throws MethodError mul!(x, P', y)
8991
end
9092

src/definitions.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -702,11 +702,16 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<
702702
halfdim = first(dims)
703703
d = size(p, halfdim)
704704
n = output_size(p, halfdim)
705-
scale = reshape(
706-
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
707-
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
708-
)
709-
return p \ (x ./ convert(typeof(x), scale))
705+
y = map(x, CartesianIndices(x)) do xj, j
706+
i = j[halfdim]
707+
yj = if i == 1 || (i == n && 2 * (i - 1) == d)
708+
xj / N
709+
else
710+
xj / (2 * N)
711+
end
712+
return yj
713+
end
714+
return p \ y
710715
end
711716

712717
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
@@ -715,11 +720,17 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T
715720
halfdim = first(dims)
716721
n = size(p, halfdim)
717722
d = output_size(p, halfdim)
718-
scale = reshape(
719-
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
720-
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
721-
)
722-
return (convert(typeof(x), scale) ./ N) .* (p \ x)
723+
y = p \ x
724+
z = map(y, CartesianIndices(y)) do yj, j
725+
i = j[halfdim]
726+
zj = if i == 1 || (i == n && 2 * (i - 1) == d)
727+
yj / N
728+
else
729+
2 * yj / N
730+
end
731+
return zj
732+
end
733+
return z
723734
end
724735

725736
adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x

0 commit comments

Comments
 (0)