@@ -30,16 +30,20 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
30
30
halfdim = first (dims)
31
31
d = size (x, halfdim)
32
32
n = size (y, halfdim)
33
- scale = reshape (
34
- [i == 1 || (i == n && 2 * (i - 1 ) == d) ? 1 : 2 for i in 1 : n],
35
- ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x))),
36
- )
37
33
38
34
project_x = ChainRulesCore. ProjectTo (x)
39
35
function rfft_pullback (ȳ)
40
36
ybar = ChainRulesCore. unthunk (ȳ)
41
- _scale = convert (typeof (ybar),scale)
42
- x̄ = project_x (brfft (ybar ./ _scale, d, dims))
37
+ ybar_scaled = map (ybar, CartesianIndices (ybar)) do ybar_j, j
38
+ i = j[halfdim]
39
+ ybar_scaled_j = if i == 1 || (i == n && 2 * (i - 1 ) == d)
40
+ ybar_j
41
+ else
42
+ ybar_j / 2
43
+ end
44
+ return ybar_scaled_j
45
+ end
46
+ x̄ = project_x (brfft (ybar_scaled, d, dims))
43
47
return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
44
48
end
45
49
return y, rfft_pullback
@@ -74,16 +78,20 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
74
78
n = size (x, halfdim)
75
79
invN = AbstractFFTs. normalization (y, dims)
76
80
twoinvN = 2 * invN
77
- scale = reshape (
78
- [i == 1 || (i == n && 2 * (i - 1 ) == d) ? invN : twoinvN for i in 1 : n],
79
- ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x))),
80
- )
81
81
82
82
project_x = ChainRulesCore. ProjectTo (x)
83
83
function irfft_pullback (ȳ)
84
84
ybar = ChainRulesCore. unthunk (ȳ)
85
- _scale = convert (typeof (ybar),scale)
86
- x̄ = project_x (_scale .* rfft (real .(ybar), dims))
85
+ x̄_scaled = rfft (real .(ybar), dims)
86
+ x̄ = project_x (map (x̄_scaled, CartesianIndices (x̄_scaled)) do x̄_scaled_j, j
87
+ i = j[halfdim]
88
+ x̄_j = if i == 1 || (i == n && 2 * (i - 1 ) == d)
89
+ invN * x̄_scaled_j
90
+ else
91
+ twoinvN * x̄_scaled_j
92
+ end
93
+ return x̄_j
94
+ end )
87
95
return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent ()
88
96
end
89
97
return y, irfft_pullback
@@ -115,14 +123,19 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
115
123
# compute scaling factors
116
124
halfdim = first (dims)
117
125
n = size (x, halfdim)
118
- scale = reshape (
119
- [i == 1 || (i == n && 2 * (i - 1 ) == d) ? 1 : 2 for i in 1 : n],
120
- ntuple (i -> i == first (dims) ? n : 1 , Val (ndims (x))),
121
- )
122
126
123
127
project_x = ChainRulesCore. ProjectTo (x)
124
128
function brfft_pullback (ȳ)
125
- x̄ = project_x (scale .* rfft (real .(ChainRulesCore. unthunk (ȳ)), dims))
129
+ x̄_scaled = rfft (real .(ChainRulesCore. unthunk (ȳ)), dims)
130
+ x̄ = project_x (map (x̄_scaled, CartesianIndices (x̄_scaled)) do x̄_scaled_j, j
131
+ i = j[halfdim]
132
+ x̄_j = if i == 1 || (i == n && 2 * (i - 1 ) == d)
133
+ x̄_scaled_j
134
+ else
135
+ 2 * x̄_scaled_j
136
+ end
137
+ return x̄_j
138
+ end )
126
139
return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent ()
127
140
end
128
141
return y, brfft_pullback
0 commit comments