
Commit aa035e9

handle data movement with MLDataDevices.jl (#2492)
* removed Flux devices
* fix gpu extensions
* ported MPI extension
* docs
* docs
* skip enzyme tests
* fix docs
* more enzyme fixes
* fix metal
* fix gpu
* doc project
* fix buildkite preference
* fix docs
* fix docs
* fix docs
* fix docs
* some tests are broken
* cleanup
* fix tests
* buildkite
* rework rng_from_array
1 parent b0c6653 commit aa035e9

37 files changed: +400 -676 lines changed

.buildkite/pipeline.yml

Lines changed: 4 additions & 4 deletions
@@ -26,10 +26,10 @@ steps:
  # cuda: "*"
  # timeout_in_minutes: 60

- - label: "Metal with julia {{matrix.julia}}"
+ - label: "Metal with julia v1"
  plugins:
  - JuliaCI/julia#v1:
-     version: "{{matrix.julia}}"
+     version: "1"
  - JuliaCI/julia-test#v1:
  test_args: "--quickfail"
  - JuliaCI/julia-coverage#v1:
@@ -46,7 +46,7 @@ steps:
  using Pkg
  Pkg.resolve()'
  commands: |
-     printf "[Flux]\ngpu_backend = \"Metal\"" > LocalPreferences.toml
+     printf "[MLDataDevices]\ngpu_backend = \"Metal\"\n" > LocalPreferences.toml

  if: build.message !~ /\[skip tests\]/
  timeout_in_minutes: 60
@@ -74,7 +74,7 @@ steps:
  rocm: "*"
  rocmgpu: "*"
  commands: |
-     printf "[Flux]\ngpu_backend = \"AMDGPU\"" > LocalPreferences.toml
+     printf "[MLDataDevices]\ngpu_backend = \"AMDGPU\"\n" > LocalPreferences.toml
  timeout_in_minutes: 60
  env:
  JULIA_AMDGPU_CORE_MUST_LOAD: "1"
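For readers following along: the `LocalPreferences.toml` section that the CI writes by hand has moved from `[Flux]` to `[MLDataDevices]`. Outside of CI the same preference can be set programmatically; a minimal sketch, assuming `gpu_backend!` is re-exported by Flux from MLDataDevices as the updated GPU guide below suggests:

```julia
# Sketch: set the GPU backend preference that the pipeline step above writes manually.
# gpu_backend! records the choice in LocalPreferences.toml under the [MLDataDevices]
# section; it only takes effect after restarting Julia.
using Flux

gpu_backend!("Metal")   # or "CUDA", "AMDGPU", ...
```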

NEWS.md

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,9 @@

  See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.

+ ## v0.14.22
+ * Data movement between devices is now provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl).
+
  ## v0.14.18
  * Add [support for distributed data parallel training](https://github.com/FluxML/Flux.jl/pull/2446).
  * MPI and NCCL backend available with `FluxMPIExt` and `FluxMPINCCLExt` extensions respectively.
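For context, the release note added above boils down to the following user-facing workflow, a rough sketch distilled from the updated GPU guide in this commit (`gpu_device` and `cpu_device` come from MLDataDevices and are assumed to be re-exported by Flux):

```julia
using Flux

gdev = gpu_device()   # CUDA, AMDGPU or Metal device if one is functional, otherwise the CPU
cdev = cpu_device()

model = Dense(2 => 3) |> gdev    # move parameters to the selected device
x = rand(Float32, 2, 8) |> gdev  # move a data batch the same way
y = model(x)                     # runs on whatever device the model lives on

model = model |> cdev            # bring parameters back to host memory
```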

Project.toml

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
  Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
  Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
  LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+ MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
  MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
  MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
  NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -49,6 +50,7 @@ ChainRulesCore = "1.12"
  Compat = "4.10.0"
  Enzyme = "0.12, 0.13"
  Functors = "0.4"
+ MLDataDevices = "1.2.0"
  MLUtils = "0.4"
  MPI = "0.20.19"
  MacroTools = "0.5"

docs/Project.toml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
  Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
  Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
  JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
+ MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
  MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
  MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
  NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

docs/make.jl

Lines changed: 2 additions & 2 deletions
@@ -1,11 +1,11 @@
  using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
    OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics,
-   DataFrames, JLD2
+   DataFrames, JLD2, MLDataDevices

  DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)

  makedocs(
-   modules = [Flux, NNlib, Functors, MLUtils, Zygote, OneHotArrays, Optimisers, ChainRulesCore],
+   modules = [Flux, NNlib, Functors, MLUtils, Zygote, OneHotArrays, Optimisers, ChainRulesCore, MLDataDevices],
    sitename = "Flux",
    pages = [
      "Welcome" => "index.md",

docs/src/guide/gpu.md

Lines changed: 45 additions & 68 deletions
@@ -232,19 +232,17 @@ More information for conditional use of GPUs in CUDA.jl can be found in its [doc

  ## Using device objects

- As a more convenient syntax, Flux allows the usage of GPU `device` objects which can be used to easily transfer models to GPUs (and defaulting to using the CPU if no GPU backend is available). This syntax has a few advantages including automatic selection of the GPU backend and type stability of data movement. To do this, the [`Flux.get_device`](@ref) function can be used.
+ As a more convenient syntax, Flux allows the usage of GPU `device` objects which can be used to easily transfer models to GPUs (and defaulting to using the CPU if no GPU backend is available). This syntax has a few advantages including automatic selection of the GPU backend and type stability of data movement.
+ These features are provided by the [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl) package, which Flux uses internally and re-exports.

- `Flux.get_device` first checks for a GPU preference, and if possible returns a device for the preference backend. For instance, consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference):
+ A `device` object can be created using the [`gpu_device`](@ref MLDataDevices.gpu_device) function.
+ `gpu_device` first checks for a GPU preference, and if possible returns a device for the preference backend. For instance, consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference):

  ```julia-repl
  julia> using Flux, CUDA;

- julia> device = Flux.get_device(; verbose=true) # returns handle to an NVIDIA GPU
- [ Info: Using backend set in preferences: CUDA.
- (::Flux.FluxCUDADevice) (generic function with 1 method)
-
- julia> device.deviceID # check the id of the GPU
- CuDevice(0): NVIDIA GeForce GTX 1650
+ julia> device = gpu_device() # returns handle to an NVIDIA GPU if available
+ (::CUDADevice{Nothing}) (generic function with 4 methods)

  julia> model = Dense(2 => 3);

@@ -262,77 +260,57 @@ julia> model.weight
  -0.984794 -0.904345
  0.720379 -0.486398
  0.851011 -0.586942
-
  ```

- The device preference can also be set via the [`Flux.gpu_backend!`](@ref) function. For instance, below we first set our device preference to `"CPU"`:
+ The device preference can also be set via the [`gpu_backend!`](@ref MLDataDevices.gpu_backend!) function. For instance, below we first set our device preference to `"AMDGPU"`:

  ```julia-repl
- julia> using Flux; Flux.gpu_backend!("CPU")
- ┌ Info: New GPU backend set: CPU.
- └ Restart your Julia session for this change to take effect!
+ julia> gpu_backend!("AMDGPU")
+ [ Info: GPU backend has been set to AMDGPU. Restart Julia to use the new backend.
  ```
-
- Then, after restarting the Julia session, `Flux.get_device` returns a handle to the `"CPU"`:
+ If no functional GPU backend is available, the device will default to a CPU device.
+ You can also explicitly request a CPU device by calling the [`cpu_device`](@ref MLDataDevices.cpu_device) function.

  ```julia-repl
- julia> using Flux, CUDA; # even if CUDA is loaded, we'll still get a CPU device
-
- julia> device = Flux.get_device(; verbose=true) # get a CPU device
- [ Info: Using backend set in preferences: CPU.
- (::Flux.FluxCPUDevice) (generic function with 1 method)
+ julia> using Flux, MLDataDevices

- julia> model = Dense(2 => 3);
-
- julia> model = model |> device
- Dense(2 => 3) # 9 parameters
+ julia> cdev = cpu_device()
+ (::CPUDevice{Nothing}) (generic function with 4 methods)

- julia> model.weight # no change; model still lives on CPU
- 3×2 Matrix{Float32}:
- -0.942968 0.856258
- 0.440009 0.714106
- -0.419192 -0.471838
- ```
- Clearly, this means that the same code will work for any GPU backend and the CPU.
+ julia> gdev = gpu_device(force=true) # force GPU device, error if no GPU is available
+ (::CUDADevice{Nothing}) (generic function with 4 methods)

- If the preference backend isn't available or isn't functional, then [`Flux.get_device`](@ref) looks for a CUDA, AMDGPU or Metal backend, and returns a corresponding device (if the backend is available and functional). Otherwise, a CPU device is returned. In the below example, the GPU preference is `"CUDA"`:
+ julia> model = Dense(2 => 3); # model in CPU memory

- ```julia-repl
- julia> using Flux; # preference is CUDA, but CUDA.jl not loaded
+ julia> gmodel = model |> gdev; # transfer model to GPU

- julia> device = Flux.get_device(; verbose=true) # this will resort to automatic device selection
- [ Info: Using backend set in preferences: CUDA.
- ┌ Warning: Trying to use backend: CUDA but it's trigger package is not loaded.
- │ Please load the package and call this function again to respect the preferences backend.
- └ @ Flux ~/fluxml/Flux.jl/src/functor.jl:637
- [ Info: Using backend: CPU.
- (::Flux.FluxCPUDevice) (generic function with 1 method)
+ julia> cmodel = gmodel |> cdev; # transfer model back to CPU
  ```
- For detailed information about how the backend is selected, check the documentation for [`Flux.get_device`](@ref).

  ## Data movement across GPU devices

- Flux also supports getting handles to specific GPU devices, and transferring models from one GPU device to another GPU
- device from the same backend. Let's try it out for NVIDIA GPUs. First, we list all the available devices:
+ Flux also supports getting handles to specific GPU devices, and transferring models from one GPU device to another GPU device from the same backend. Let's try it out for NVIDIA GPUs. First, we list all the available devices:

  ```julia-repl
  julia> using Flux, CUDA;

  julia> CUDA.devices()
  CUDA.DeviceIterator() for 3 devices:
- 0. GeForce RTX 2080 Ti
- 1. GeForce RTX 2080 Ti
- 2. TITAN X (Pascal)
-
+ 0. NVIDIA TITAN RTX
+ 1. NVIDIA TITAN RTX
+ 2. NVIDIA TITAN RTX
  ```

  Then, let's select the device with id `0`:

  ```julia-repl
- julia> device0 = Flux.get_device("CUDA", 0) # the currently supported values for backend are "CUDA" and "AMDGPU"
- (::Flux.FluxCUDADevice) (generic function with 1 method)
+ julia> device0 = gpu_device(1)
+ (::CUDADevice{CuDevice}) (generic function with 4 methods)

+ julia> device0.device
+ CuDevice(0): NVIDIA TITAN RTX
  ```
+ Notice that indexing starts from `0` in the `CUDA.devices()` output, but `gpu_device` expects the device id starting from `1`.

  Then, let's move a simple dense layer to the GPU represented by `device0`:

@@ -343,27 +321,25 @@ Dense(2 => 3) # 9 parameters
  julia> dense_model = dense_model |> device0;

  julia> dense_model.weight
- 3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
- 0.695662 0.816299
- -0.204763 -0.10232
- -0.955829 0.538412
+ 3×2 CuArray{Float32, 2, CUDA.DeviceMemory}:
+ -0.142062 -0.131455
+ -0.828134 -1.06552
+ 0.608595 -1.05375

  julia> CUDA.device(dense_model.weight) # check the GPU to which dense_model is attached
- CuDevice(0): GeForce RTX 2080 Ti
-
+ CuDevice(0): NVIDIA TITAN RTX
  ```

  Next, we'll get a handle to the device with id `1`, and move `dense_model` to that device:

  ```julia-repl
- julia> device1 = Flux.get_device("CUDA", 1)
- (::Flux.FluxCUDADevice) (generic function with 1 method)
+ julia> device1 = gpu_device(2)
+ (::CUDADevice{CuDevice}) (generic function with 4 methods)

  julia> dense_model = dense_model |> device1; # don't directly print the model; see warning below

  julia> CUDA.device(dense_model.weight)
- CuDevice(1): GeForce RTX 2080 Ti
-
+ CuDevice(1): NVIDIA TITAN RTX
  ```

  Due to a limitation in `Metal.jl`, currently this kind of data movement across devices is only supported for `CUDA` and `AMDGPU` backends.
@@ -376,14 +352,15 @@ Due to a limitation in `Metal.jl`, currently this kind of data movement across d


  ```@docs
- Flux.AbstractDevice
- Flux.FluxCPUDevice
- Flux.FluxCUDADevice
- Flux.FluxAMDGPUDevice
- Flux.FluxMetalDevice
- Flux.supported_devices
- Flux.get_device
- Flux.gpu_backend!
+ MLDataDevices.cpu_device
+ MLDataDevices.default_device_rng
+ MLDataDevices.get_device
+ MLDataDevices.gpu_device
+ MLDataDevices.gpu_backend!
+ MLDataDevices.get_device_type
+ MLDataDevices.reset_gpu_device!
+ MLDataDevices.supported_gpu_backends
+ MLDataDevices.DeviceIterator
  ```

  ## Distributed data parallel training
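Two of the newly listed docstrings, `MLDataDevices.get_device` and `MLDataDevices.get_device_type`, are not exercised in the guide text above; a brief illustrative sketch of their typical use (assumed usage, not part of this diff):

```julia
using Flux, MLDataDevices

x = rand(Float32, 3, 4)
MLDataDevices.get_device(x)        # CPUDevice() for a plain Array
MLDataDevices.get_device_type(x)   # CPUDevice

gdev = gpu_device()
MLDataDevices.get_device(x |> gdev)  # reports the corresponding GPU device once moved
```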

docs/src/guide/models/recurrence.md

Lines changed: 2 additions & 2 deletions
@@ -71,7 +71,7 @@ julia> RNN(2, 5) # or equivalently RNN(2 => 5)
  Recur(
    RNNCell(2 => 5, tanh), # 45 parameters
  ) # Total: 4 trainable arrays, 45 parameters,
-   # plus 1 non-trainable, 5 parameters, summarysize 412 bytes.
+   # plus 1 non-trainable, 5 parameters, summarysize 404 bytes.
  ```

  Equivalent to the `RNN` stateful constructor, `LSTM` and `GRU` are also available.
@@ -86,7 +86,7 @@ Chain(
  ),
  Dense(5 => 1), # 6 parameters
  ) # Total: 6 trainable arrays, 51 parameters,
-   # plus 1 non-trainable, 5 parameters, summarysize 580 bytes.
+   # plus 1 non-trainable, 5 parameters, summarysize 540 bytes.
  ```
  In this example, each output has only one component.

docs/src/guide/saving.md

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@ julia> m = Chain(Dense(10 => 5, relu), Dense(5 => 2))
  Chain(
    Dense(10 => 5, relu), # 55 parameters
    Dense(5 => 2), # 12 parameters
- ) # Total: 4 arrays, 67 parameters, 524 bytes.
+ ) # Total: 4 arrays, 67 parameters, 476 bytes.

  julia> for epoch in 1:10
           # ... train model ...
@@ -131,7 +131,7 @@ julia> model
  Chain(
    Dense(10 => 5, relu), # 55 parameters
    Dense(5 => 2), # 12 parameters
- ) # Total: 4 arrays, 67 parameters, 524 bytes.
+ ) # Total: 4 arrays, 67 parameters, 476 bytes.
  ```
  !!! warning
      Saving models this way could lead to compatibility issues across julia versions

docs/src/reference/destructure.md

Lines changed: 2 additions & 1 deletion
@@ -94,4 +94,5 @@ Flux.loadmodel!
  Functors.KeyPath
  Functors.getkeypath
  Functors.haskeypath
- ```
+ Functors.setkeypath!
+ ```
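The docs list above gains `Functors.setkeypath!` alongside the existing `KeyPath` helpers. A small illustrative sketch of the trio on a generic nested structure (assumed usage, not taken from this diff):

```julia
using Functors

nested = (layers = (weights = [1.0, 2.0, 3.0], name = "first"),)
kp = KeyPath(:layers, :weights)

haskeypath(nested, kp)   # true
getkeypath(nested, kp)   # [1.0, 2.0, 3.0]

# Mutate a single element through the full key path; the leaf container
# (the Vector) must be mutable for setkeypath! to work in place.
setkeypath!(nested, KeyPath(:layers, :weights, 2), 20.0)
getkeypath(nested, kp)   # [1.0, 20.0, 3.0]
```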

ext/FluxAMDGPUExt/FluxAMDGPUExt.jl

Lines changed: 0 additions & 11 deletions
@@ -17,16 +17,6 @@ const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
  # Set to boolean on the first call to check_use_amdgpu
  const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)

- function (device::Flux.FluxAMDGPUDevice)(x)
-     if device.deviceID === nothing
-         Flux.gpu(Flux.FluxAMDGPUAdaptor(), x)
-     else
-         return Flux.gpu(Flux.FluxAMDGPUAdaptor(AMDGPU.device_id(device.deviceID) - 1), x) # subtracting 1, because device_id returns a positive integer
-     end
- end
- Flux._get_device_name(::Flux.FluxAMDGPUDevice) = "AMDGPU"
- Flux._isavailable(::Flux.FluxAMDGPUDevice) = true
- Flux._isfunctional(::Flux.FluxAMDGPUDevice) = AMDGPU.functional()

  function check_use_amdgpu()
      if !isnothing(USE_AMDGPU[])
@@ -55,7 +45,6 @@ include("conv.jl")

  function __init__()
      Flux.AMDGPU_LOADED[] = true
-     Flux.DEVICES[][Flux.GPU_BACKEND_ORDER["AMDGPU"]] = AMDGPU.functional() ? Flux.FluxAMDGPUDevice(AMDGPU.device()) : Flux.FluxAMDGPUDevice(nothing)
  end

  # TODO
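The methods deleted above were Flux's own plumbing for calling an AMDGPU device object on data and for registering the backend. After this commit that responsibility sits in MLDataDevices, so the user-facing call becomes roughly the following sketch (assumes a functional AMDGPU setup):

```julia
using Flux, AMDGPU   # loading AMDGPU makes the AMDGPU backend selectable

dev = gpu_device()   # an MLDataDevices.AMDGPUDevice when AMDGPU is functional and preferred
x = rand(Float32, 4, 4)
x_gpu = x |> dev     # previously routed through Flux.FluxAMDGPUAdaptor, as removed above
```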
