Skip to content

Commit 2fe82a8

Browse files
fix: make gpu(x) return unmodified x when GPU backends aren't loaded (#2295)
* fix: make gpu return unmodified input when gpu isn't available * add tests * fix
1 parent fb507aa commit 2fe82a8

File tree

3 files changed

+10
-0
lines changed

3 files changed

+10
-0
lines changed

src/functor.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ function gpu(::FluxCUDAAdaptor, x)
341341
`CUDA.jl` must be loaded to access it.
342342
Add `using CUDA` or `import CUDA` to your code.
343343
""" maxlog=1
344+
return x
344345
end
345346
end
346347

@@ -361,6 +362,7 @@ function gpu(::FluxAMDAdaptor, x)
361362
`AMDGPU.jl` must be loaded to access it.
362363
Add `using AMDGPU` or `import AMDGPU` to your code.
363364
""" maxlog=1
365+
return x
364366
end
365367
end
366368

@@ -380,6 +382,7 @@ function gpu(::FluxMetalAdaptor, x)
380382
The Metal functionality is being called but
381383
`Metal.jl` must be loaded to access it.
382384
""" maxlog=1
385+
return x
383386
end
384387
end
385388

test/functors.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
x = rand(Float32, 10, 10)
2+
if !(Flux.CUDA_LOADED[] || Flux.AMD_LOADED[] || Flux.METAL_LOADED[])
3+
@test x === gpu(x)
4+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ Random.seed!(0)
5656
include("outputsize.jl")
5757
end
5858

59+
@testset "functors" begin
60+
include("functors.jl")
61+
end
5962

6063
if get(ENV, "FLUX_TEST_CUDA", "false") == "true"
6164
using CUDA

0 commit comments

Comments
 (0)