Commit ee78ce3

CarloLucibello, DhairyaLGandhi, and ToucheSir authored
fix cpu(x) for immutable arrays (#2117)
* fix cpu(x) for immutable arrays

* Update test/cuda/cuda.jl

  Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>

* Update test/cuda/cuda.jl

  Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>

* rrules for adapt

* do not unthunk if not needed

* add comment on adapt rrules

Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
1 parent 6f24d3a commit ee78ce3
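
A minimal sketch of what this fix enables (hypothetical REPL session; it assumes Flux with this patch, CUDA.jl, and a working GPU, and reuses the A2116 struct added in the new test below):

using Flux, CUDA

struct A2116            # plain isbits struct, as in the test for issue #2116
    x::Int
    y::Int
end

x = [A2116(1, 1), A2116(2, 2)]
xgpu = gpu(x)           # a CuVector{A2116}
xcpu = cpu(xgpu)        # a Vector{A2116}; per the commit title, cpu previously mishandled arrays of immutable structs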

File tree

3 files changed: +37 -8 lines

src/functor.jl

Lines changed: 20 additions & 3 deletions
@@ -121,14 +121,31 @@ adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
 adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
 adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
 
+# PIRACY, should be defined in CUDA.jl
 function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
-  Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),)
+  Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)))
 end
 
 function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
-  adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)),)
+  adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)))
 end
 
+# The following rrules for adapt are here to avoid double wrapping issues
+# as seen in https://github.com/FluxML/Flux.jl/pull/2117#discussion_r1027321801
+
+ChainRulesCore.rrule(::typeof(adapt), a::FluxCPUAdaptor, x::AnyCuArray) =
+  adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCUDAAdaptor(), unthunk(Δ)))
+
+ChainRulesCore.rrule(::typeof(adapt), a::FluxCPUAdaptor, x::AbstractArray) =
+  adapt(a, x), Δ -> (NoTangent(), NoTangent(), Δ)
+
+ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AnyCuArray) =
+  adapt(a, x), Δ -> (NoTangent(), NoTangent(), Δ)
+
+ChainRulesCore.rrule(::typeof(adapt), a::FluxCUDAAdaptor, x::AbstractArray) =
+  adapt(a, x), Δ -> (NoTangent(), NoTangent(), adapt(FluxCPUAdaptor(), unthunk(Δ)))
+
+
 # CPU/GPU movement conveniences
 
 """
@@ -154,7 +171,7 @@ julia> typeof(m_cpu.W)
 Matrix{Float32}
 ```
 """
-cpu(x) = fmap(x -> adapt(FluxCPUAdaptor(), x), x)
+cpu(x) = fmap(x -> adapt(FluxCPUAdaptor(), x), x, exclude = _isleaf)
 
 _isbitsarray(::AbstractArray{<:Number}) = true
 _isbitsarray(::AbstractArray{T}) where T = isbitstype(T)
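
For context on the new `exclude = _isleaf` keyword: a rough, hedged illustration of why arrays of immutable structs are treated as leaves, assuming `_isleaf` builds on the `_isbitsarray` helper shown above (its exact definition is not part of this diff) and reusing the A2116 struct from the tests below:

struct A2116            # the struct added in the tests below
    x::Int
    y::Int
end

isbitstype(A2116)                # true, so a Vector{A2116} satisfies _isbitsarray
isbitstype(CartesianIndex{1})    # also true, covering the CartesianIndex test case
# With exclude = _isleaf, fmap stops at such arrays and hands them to adapt whole,
# instead of recursing into their immutable elements.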

test/cuda/cuda.jl

Lines changed: 16 additions & 4 deletions
@@ -20,7 +20,7 @@ using SparseArrays: sparse, SparseMatrixCSC, AbstractSparseArray
 m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
 cm = gpu(m)
 
-@test all(p isa CuArray for p in params(cm))
+@test all(p isa CuArray for p in Flux.params(cm))
 @test cm(gpu(rand(10, 10))) isa CuArray{Float32,2}
 
 xs = rand(5, 5)
@@ -65,7 +65,7 @@ end
 end
 
 @testset "onehot forward map to broadcast" begin
-oa = OneHotArray(rand(1:10, 5, 5), 10) |> gpu
+oa = Flux.OneHotArray(rand(1:10, 5, 5), 10) |> gpu
 @test all(map(identity, oa) .== oa)
 @test all(map(x -> 2 * x, oa) .== 2 .* oa)
 end
@@ -110,14 +110,14 @@ end
 # This test should really not go through indirections and pull out Fills for efficiency
 # but we forcefully materialise. TODO: remove materialising CuArray here
 @test gradient(x -> sum(cpu(x)), ca)[1] isa CuArray # This involves FillArray, which should be GPU compatible
-@test gradient(x -> sum(cpu(x)), ca')[1] isa LinearAlgebra.Adjoint
+@test gradient(x -> sum(cpu(x)), ca')[1] isa CuArray
 
 # Even more trivial: no movement
 @test gradient(x -> sum(abs, cpu(x)), a)[1] isa Matrix
 @test gradient(x -> sum(abs, cpu(x)), a')[1] isa Matrix
 @test gradient(x -> sum(cpu(x)), a)[1] isa typeof(gradient(sum, a)[1]) # FillArray
 @test gradient(x -> sum(abs, gpu(x)), ca)[1] isa CuArray
-@test_skip gradient(x -> sum(abs, gpu(x)), ca')[1] isa CuArray # KernelError: passing and using non-bitstype argument
+@test gradient(x -> sum(abs, gpu(x)), ca')[1] isa CuArray
 
 # More complicated, Array * CuArray is an error
 g0 = gradient(x -> sum(abs, (a * (a * x))), a)[1]
@@ -165,4 +165,16 @@ end
 @test gpu(g2) isa CuArray
 @test gpu(g2) ≈ cu(Vector(g2))
 @test parent(gpu(g3)) isa CuArray
+
+
+#Issue #2116
+struct A2116
+  x::Int
+  y::Int
+end
+x = [A2116(1,1), A2116(2,2)]
+xgpu = gpu(x)
+@test xgpu isa CuVector{A2116}
+@test cpu(xgpu) isa Vector{A2116}
+@test cpu(gpu([CartesianIndex(1)])) isa Vector{CartesianIndex{1}}
 end
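
The gradient tests above exercise the adapt rrules added in src/functor.jl. A hedged sketch of the expected behaviour (assumes Flux with this patch, CUDA.jl, Zygote, and a working GPU; the expected types mirror the tests rather than a general guarantee):

using Flux, CUDA, Zygote

ca = cu(rand(Float32, 3, 3))

# Differentiating through device movement should hand back an array on the input's
# device, without the double wrapping mentioned in the source comment.
gradient(x -> sum(cpu(x)), ca)[1]         # expected: a CuArray
gradient(x -> sum(cpu(x)), ca')[1]        # expected: a CuArray (was a LinearAlgebra.Adjoint before this change)
gradient(x -> sum(abs, gpu(x)), ca')[1]   # expected: a CuArray (this case was previously @test_skip)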

test/cuda/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 using Flux, Test, CUDA
 using Zygote
 using Zygote: pullback
-using Random
+using Random, LinearAlgebra, Statistics
 
 @info "Testing GPU Support"
 CUDA.allowscalar(false)
