Skip to content

Commit a377cb0

Browse files
Merge pull request #2034 from chengchingwen/master
fix array constructor rrule
2 parents c4837f7 + 1f31336 commit a377cb0

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/functor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
115115
adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
116116
adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
117117

118-
function ChainRulesCore.rrule(::typeof(Array), x::CUDA.CuArray)
118+
function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
119119
Array(x), d -> (NoTangent(), CUDA.cu(d),)
120120
end
121121

test/utils.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,5 +765,16 @@ end
765765
g = Flux.Zygote.ForwardDiff.gradient(pv -> loss(data, 1, pv), pvec)
766766
@test g Flux.Zygote.gradient(pv -> loss(data, 1, pv), pvec)[1]
767767
end
768+
end
768769

770+
@testset "Rrule" begin
771+
@testset "issue 2033" begin
772+
if CUDA.functional()
773+
struct Wrapped{T}
774+
x::T
775+
end
776+
y, _ = Flux.pullback(Wrapped, cu(randn(3,3)))
777+
@test y isa Wrapped{<:CuArray}
778+
end
779+
end
769780
end

0 commit comments

Comments
 (0)