Skip to content

Commit 27c37a0

Browse files
For conversion of scale vector in adjoint (#105)
* For conversion of scale vector in adjoint It always defines an `Array` which can fail on the GPU. This forces it to be the right type. One could also use `adapt` here, but since the element type promotion would have to occur anyways in the subsequent broadcast it seems you might as well convert all at once. * Update AbstractFFTsChainRulesCoreExt.jl
1 parent 3a3f0e4 commit 27c37a0

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

ext/AbstractFFTsChainRulesCoreExt.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
3737

3838
project_x = ChainRulesCore.ProjectTo(x)
3939
function rfft_pullback(ȳ)
40-
= project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
40+
ybar = ChainRulesCore.unthunk(ȳ)
41+
_scale = convert(typeof(ybar),scale)
42+
= project_x(brfft(ybar ./ _scale, d, dims))
4143
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
4244
end
4345
return y, rfft_pullback
@@ -79,7 +81,9 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
7981

8082
project_x = ChainRulesCore.ProjectTo(x)
8183
function irfft_pullback(ȳ)
82-
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
84+
ybar = ChainRulesCore.unthunk(ȳ)
85+
_scale = convert(typeof(ybar),scale)
86+
= project_x(_scale .* rfft(real.(ybar), dims))
8387
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
8488
end
8589
return y, irfft_pullback

0 commit comments

Comments
 (0)