
Commit 35b893a

simplify test machinery (#2498)
* simplify test machinery
1 parent 09a16ee commit 35b893a

File tree

19 files changed: +153 additions, -273 deletions

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -10,3 +10,4 @@ Manifest.toml
 LocalPreferences.toml
 .DS_Store
 docs/mymodel.bson
+prova.jl

src/distributed/public_api.jl

Lines changed: 1 addition & 1 deletion

@@ -132,7 +132,7 @@ Backend Agnostic API to perform an allreduce operation on the given buffer `send
 workers.
 """
 function allreduce!(backend::AbstractFluxDistributedBackend, sendrecvbuf, op::F) where {F}
-    return __allreduce!(backend, sendrecvbuf, op, get_device())
+    return __allreduce!(backend, sendrecvbuf, op, gpu_device())
 end
 
 function allreduce!(
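For context, `gpu_device()` is the device getter from MLDataDevices.jl that Flux now calls in place of the removed `Flux.get_device()`. A minimal usage sketch, with illustrative variable names not taken from this diff:

    dev = gpu_device()              # functional GPU device if one is loaded, otherwise a CPU fallback
    buf = rand(Float32, 16) |> dev  # arrays (and models) are moved by applying the device to them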

test/Project.toml

Lines changed: 1 addition & 1 deletion

@@ -21,4 +21,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
 FiniteDifferences = "0.12"
 Tracker = "0.2.33"
-Enzyme = "0.12.4"
+Enzyme = "0.13"

test/ext_amdgpu/basic.jl

Lines changed: 21 additions & 20 deletions

@@ -19,26 +19,27 @@ end
 end
 
 @testset "Chain of Dense layers" begin
-    m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32
+    m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
     x = rand(Float32, 10, 10)
-    gpu_autodiff_test(m, x)
+    test_gradients(m, x, test_gpu=true, compare_finite_diff=false)
 end
 
 @testset "Convolution" begin
     for conv_type in (Conv, ConvTranspose), nd in 1:3
-        m = conv_type(tuple(fill(2, nd)...), 3 => 4) |> f32
+        m = conv_type(tuple(fill(2, nd)...), 3 => 4)
         x = rand(Float32, fill(10, nd)..., 3, 5)
 
+        md, xd = Flux.gpu.((m, x))
+        y = m(x)
         # Ensure outputs are the same.
-        gpu_autodiff_test(m, x; atol=1f-3, checkgrad=false)
+        @test collect(md(xd)) ≈ y atol=1f-3
 
         # Gradients are flipped as well.
-        md, xd = Flux.gpu.((m, x))
-        gs = gradient(m -> sum(m(x)), m)
-        gsd = gradient(m -> sum(m(xd)), md)
+        gs = gradient(m -> sum(m(x)), m)[1]
+        gsd = gradient(m -> sum(m(xd)), md)[1]
 
         dims = ntuple(i -> i, ndims(m.weight) - 2)
-        @test reverse(gs[1].weight; dims) ≈ Array(gsd[1].weight) atol=1f-2
+        @test reverse(gs.weight; dims) ≈ Array(gsd.weight) atol=1f-2
 
         # Movement back to CPU flips weights back.
         mh = Flux.cpu(md)

@@ -52,10 +53,10 @@ end
        x = rand(Float32, fill(10, nd)..., 3, 5) |> gpu
 
        pad = ntuple(i -> i, nd)
-       m = conv_type(kernel, 3 => 4, pad=pad) |> f32 |> gpu
+       m = conv_type(kernel, 3 => 4, pad=pad) |> gpu
 
        expanded_pad = ntuple(i -> pad[(i - 1) ÷ 2 + 1], 2 * nd)
-       m_expanded = conv_type(kernel, 3 => 4, pad=expanded_pad) |> f32 |> gpu
+       m_expanded = conv_type(kernel, 3 => 4, pad=expanded_pad) |> gpu
 
        @test size(m(x)) == size(m_expanded(x))
    end

@@ -74,25 +75,25 @@ end
 end
 
 @testset "Chain(Conv)" begin
-    m = Chain(Conv((3, 3), 3 => 3)) |> f32
-    x = rand(Float32, 10, 10, 3, 2)
-    gpu_autodiff_test(m, x; atol=1f-3, checkgrad=false)
+    m = Chain(Conv((3, 3), 3 => 3))
+    x = rand(Float32, 5, 5, 3, 2)
+    test_gradients(m, x, test_gpu=true, compare_finite_diff=false, test_grad_f=false)
 
     md = m |> gpu |> cpu
     @test md[1].weight ≈ m[1].weight atol=1f-3
 
-    m = Chain(ConvTranspose((3, 3), 3 => 3)) |> f32
-    x = rand(Float32, 10, 10, 3, 2)
-    gpu_autodiff_test(m, x; atol=1f-3, checkgrad=false)
+    m = Chain(ConvTranspose((3, 3), 3 => 3))
+    x = rand(Float32, 5, 5, 3, 2)
+    test_gradients(m, x, test_gpu=true, compare_finite_diff=false, test_grad_f=false)
 
     md = m |> gpu |> cpu
     @test md[1].weight ≈ m[1].weight atol=1f-3
 end
 
 @testset "Cross-correlation" begin
-    m = CrossCor((2, 2), 3 => 4) |> f32
-    x = rand(Float32, 10, 10, 3, 2)
-    gpu_autodiff_test(m, x; atol=1f-3)
+    m = CrossCor((2, 2), 3 => 4)
+    x = rand(Float32, 5, 5, 3, 2)
+    test_gradients(m, x, test_gpu=true, compare_finite_diff=false)
 end
 
 @testset "Restructure" begin

@@ -132,7 +133,7 @@ end
     bn = BatchNorm(3, σ)
     for nd in 1:3
         x = rand(Float32, fill(2, nd - 1)..., 3, 4)
-        gpu_autodiff_test(bn, x; atol=1f-3, allow_nothing=true)
+        test_gradients(bn, x; test_gpu=true, compare_finite_diff=false)
     end
 end
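Note: `test_gradients` is a helper from the package's rewritten test utilities, not a public Flux API. A sketch of its call pattern as used throughout this diff, with the keywords taken from the calls above and the helper assumed to be in scope:

    m = Chain(Dense(10 => 5, tanh), Dense(5 => 2), softmax)
    x = rand(Float32, 10, 10)
    # check Zygote gradients, mirror them on the GPU, skip the finite-difference comparison
    test_gradients(m, x; test_gpu=true, compare_finite_diff=false)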

test/ext_amdgpu/get_devices.jl

Lines changed: 2 additions & 2 deletions

@@ -17,9 +17,9 @@ x = randn(Float32, 5, 5)
 cx = x |> amdgpu_device
 @test cx isa AMDGPU.ROCArray
 
-# moving models to specific NVIDIA devices
+# moving models to specific AMDGPU devices
 for id in 0:(length(AMDGPU.devices()) - 1)
-    current_amdgpu_device = Flux.get_device("AMDGPU", id)
+    current_amdgpu_device = gpu_device(id+1)
 
     global dense_model = dense_model |> current_amdgpu_device
     @test dense_model.weight isa AMDGPU.ROCArray
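Note the index convention: the removed `Flux.get_device("AMDGPU", id)` took a 0-based ordinal, while `gpu_device(id+1)` takes a 1-based one. A minimal sketch, assuming at least one visible GPU:

    dev1 = gpu_device(1)                 # first visible device (old id 0)
    dense_model = Dense(2 => 3) |> dev1  # parameters now live on that device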

test/ext_amdgpu/runtests.jl

Lines changed: 0 additions & 3 deletions

@@ -2,9 +2,6 @@
 @assert AMDGPU.functional()
 AMDGPU.allowscalar(false)
 
-include("../test_utils.jl")
-include("test_utils.jl")
-
 @testset "get_devices" begin
     include("get_devices.jl")
 end

test/ext_amdgpu/test_utils.jl

Lines changed: 0 additions & 15 deletions
This file was deleted.

test/ext_cuda/get_devices.jl

Lines changed: 6 additions & 3 deletions

@@ -8,9 +8,6 @@ dense_model = Dense(2 => 3) # initially lives on CPU
 weight = copy(dense_model.weight) # store the weight
 bias = copy(dense_model.bias) # store the bias
 
-cuda_device = Flux.get_device()
-
-@test typeof(cuda_device) <: Flux.CUDADevice
 
 # correctness of data transfer
 x = randn(5, 5)

@@ -30,6 +27,12 @@ for id in 0:(length(CUDA.devices()) - 1)
     @test isequal(Flux.cpu(dense_model.weight), weight)
     @test isequal(Flux.cpu(dense_model.bias), bias)
 end
+
+# gpu_device remembers the last device selected
+# Therefore, we need to reset it to the current cuda device
+@test gpu_device().device.handle == length(CUDA.devices()) - 1
+gpu_device(CUDA.device().handle + 1)
+
 # finally move to CPU, and see if things work
 cdev = cpu_device()
 dense_model = cdev(dense_model)
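The added lines are needed because `gpu_device()` caches the most recently selected device, as the new comment notes. A sketch of the behaviour being relied on, assuming two CUDA devices are visible:

    gpu_device(2)                             # select the second device (1-based index)
    gpu_device().device.handle == 1           # the cached selection reports its 0-based handle
    gpu_device(CUDA.device().handle + 1)      # re-point the cache at the currently active device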

test/ext_cuda/layers.jl

Lines changed: 32 additions & 94 deletions

@@ -10,73 +10,23 @@
     @test gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
 end
 
-# TODO: These layers get into scalar indexing issues.
-const BROKEN_LAYERS = Union{}
 
-const ACTIVATIONS = [identity, relu, tanh,
-                     sigmoid, exp, softplus,
-                     elu, selu]
+const ACTIVATIONS = [identity, tanh]
 
-function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; test_cpu = true, test_mode = false)
-    isnothing(x_cpu) && error("Missing input to test the layers against.")
+function gpu_gradtest(name::String, layers::Vector, x_cpu, args...;
+                      test_mode=false, test_grad_x=true,
+                      atol=1e-4, rtol=1e-4)
     @testset "$name GPU grad tests" begin
         for layer in layers
             @testset "$layer Layer GPU grad test" begin
 
                 # compute output and grad of parameters
                 l_cpu = layer(args...)
-                l_gpu = l_cpu |> gpu
                 if test_mode
                     testmode!(l_cpu)
-                    testmode!(l_gpu)
                 end
 
-                ps_cpu = Flux.params(l_cpu)
-                y_cpu, back_cpu = pullback(() -> sum(l_cpu(x_cpu)), ps_cpu)
-                gs_cpu = back_cpu(1f0)
-
-                x_gpu = gpu(x_cpu)
-                ps_gpu = Flux.params(l_gpu)
-
-                if typeof(l_gpu) <: BROKEN_LAYERS
-                    @test_broken gradient(() -> sum(l_gpu(x_gpu)), ps_gpu) isa Flux.Zygote.Grads
-                else
-                    y_gpu, back_gpu = pullback(() -> sum(l_gpu(x_gpu)), ps_gpu)
-                    gs_gpu = back_gpu(1f0) # TODO many layers error out when backprop int 1, should fix
-
-                    # compute grad of input
-                    xg_cpu = gradient(x -> sum(l_cpu(x)), x_cpu)[1]
-                    xg_gpu = gradient(x -> sum(l_gpu(x)), x_gpu)[1]
-
-                    # test
-                    if test_cpu
-                        if layer === GroupedConvTranspose
-                            @test y_gpu ≈ y_cpu rtol=1f-2 atol=1f-3
-                        else
-                            @test y_gpu ≈ y_cpu rtol=1f-3 atol=1f-3
-                        end
-                        if isnothing(xg_cpu)
-                            @test isnothing(xg_gpu)
-                        else
-                            if layer === GroupedConvTranspose
-                                @test Array(xg_gpu) ≈ xg_cpu rtol = 2f-2 atol = 1f-3
-                            else
-                                @test Array(xg_gpu) ≈ xg_cpu rtol = 1f-3 atol = 1f-3
-                            end
-                        end
-                    end
-                    @test gs_gpu isa Flux.Zygote.Grads
-                    for (p_cpu, p_gpu) in zip(ps_cpu, ps_gpu)
-                        if isnothing(gs_cpu[p_cpu])
-                            @test isnothing(gs_gpu[p_gpu])
-                        else
-                            @test gs_gpu[p_gpu] isa CuArray
-                            if test_cpu
-                                @test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
-                            end
-                        end
-                    end
-                end
+                test_gradients(l_cpu, x_cpu; test_gpu=true, compare_finite_diff=false, test_grad_x, atol, rtol)
             end
         end
     end

@@ -97,23 +47,24 @@ for act in ACTIVATIONS
                    ConvTranspose, ConvTransposeNoBias,
                    CrossCor, CrossCorNoBias,
                    DepthwiseConv, DepthwiseConvNoBias]
-    gpu_gradtest("Convolution with $act", conv_layers, r, (2,2), 1=>3, act, test_cpu = false)
+    gpu_gradtest("Convolution with $act", conv_layers, r, (2,2), 1=>3, act)
 
     groupedconv = [GroupedConv, GroupedConvTranspose]
-    gpu_gradtest("GroupedConvolution with $act", groupedconv, rand(Float32, 28, 28, 100, 2), (3,3), 100 => 25, act, test_cpu = true)
+    gpu_gradtest("GroupedConvolution with $act", groupedconv, rand(Float32, 28, 28, 100, 2), (3,3), 100 => 25, act)
 
     batch_norm = [BatchNorm, BatchNormNoTrackStats]
-    gpu_gradtest("BatchNorm 1 with $act", batch_norm, rand(Float32, 28,28,3,4), 3, act, test_cpu = false) #TODO fix errors
-    gpu_gradtest("BatchNorm 2 with $act", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = true)
+    gpu_gradtest("BatchNorm 1 with $act", batch_norm, rand(Float32, 28,28,3,4), 3, act, atol=1e-3)
+    gpu_gradtest("BatchNorm 2 with $act", batch_norm, rand(Float32, 5,4), 5, act, atol=1e-3)
 
     batch_norm = [BatchNormNoTrackStats]
-    gpu_gradtest("BatchNorm 3 with $act (test mode)", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = true, test_mode = true)
+    gpu_gradtest("BatchNorm 3 with $act (test mode)", batch_norm, rand(Float32, 5,4), 5, act,
+                 test_mode=true, atol=1e-3)
 
     instancenorm = [InstanceNorm]
-    gpu_gradtest("InstanceNorm with $act", instancenorm, r, 1, act, test_cpu = false)
+    gpu_gradtest("InstanceNorm with $act", instancenorm, r, 1, act)
 
     groupnorm = [GroupNorm]
-    gpu_gradtest("GroupNorm with $act", groupnorm, rand(Float32, 28,28,3,1), 3, 1, act, test_cpu = false)
+    gpu_gradtest("GroupNorm with $act", groupnorm, rand(Float32, 28,28,3,1), 3, 1, act)
 end
 
 r = rand(Float32, 28, 28, 1, 1)

@@ -122,13 +73,13 @@ pooling_layers = [MaxPool, MeanPool]
 gpu_gradtest("Pooling", pooling_layers, r, (2,2))
 
 adaptive_pooling_layers = [AdaptiveMaxPool, AdaptiveMeanPool]
-gpu_gradtest("AdaptivePooling", adaptive_pooling_layers, r, (7,7), test_cpu = false)
+gpu_gradtest("AdaptivePooling", adaptive_pooling_layers, r, (7,7))
 
 dropout_layers = [Dropout, AlphaDropout]
-gpu_gradtest("Dropout", dropout_layers, r, 0.5f0; test_cpu = false) # dropout is not deterministic
+gpu_gradtest("Dropout", dropout_layers, r, 1e-6) # dropout is not deterministic
 
 layer_norm = [LayerNorm]
-gpu_gradtest("LayerNorm 1", layer_norm, rand(Float32, 28,28,3,4), 28, test_cpu = false) #TODO fix errors
+gpu_gradtest("LayerNorm 1", layer_norm, rand(Float32, 28,28,3,4), 28)
 gpu_gradtest("LayerNorm 2", layer_norm, rand(Float32, 5,4), 5)
 
 upsample = [x -> Upsample(scale=x)]

@@ -140,32 +91,27 @@ gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
 gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
 
 embedding = [Flux.Embedding]
-gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
-gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
-gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
-gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
-gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
-gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
-gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)
+gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding integer index", embedding, 1, 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2, test_grad_x=false)
+gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2, test_grad_x=false)
 
 @testset "function layers" begin
-    x = rand(Float32, 3,3)
-    gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
-    gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=2)), x)
-    gpu_autodiff_test(x -> sum(Flux.normalise(x)), x)
+    x = rand(Float32, 3, 3)
+    test_gradients(x -> sum(Flux.normalise(x; dims=1)), x, test_gpu=true, compare_finite_diff=false)
+    test_gradients(x -> sum(Flux.normalise(x; dims=2)), x, test_gpu=true, compare_finite_diff=false)
+    test_gradients(x -> sum(Flux.normalise(x)), x, test_gpu=true, compare_finite_diff=false)
 end
 
 @testset "Zeros mapped for $cl" for cl in (Conv, ConvTranspose, CrossCor, DepthwiseConv)
     l = cl((2,2), 1=>3, bias = false) |> gpu
     ip = zeros(Float32, 28,28,1,1) |> gpu
-    if typeof(l) <: BROKEN_LAYERS
-        @test_broken sum(l(ip)) ≈ 0.f0
-        @test_broken gradient(() -> sum(l(ip)), Flux.params(l)) isa Flux.Zygote.Grads
-    else
-        @test sum(l(ip)) ≈ 0.f0
-        gs = gradient(() -> sum(l(ip)), Flux.params(l))
-        @test l.bias ∉ gs.params
-    end
+    @test sum(l(ip)) ≈ 0.f0
+    gs = gradient(() -> sum(l(ip)), Flux.params(l))
+    @test l.bias ∉ gs.params
 end
 
 @testset "Dense without bias" begin

@@ -366,14 +312,6 @@ end
     @test Array(y_gpu) ≈ y_cpu atol=1e-4
     @test Array(α_gpu) ≈ α_cpu atol=1e-4
 
-    gm_cpu, gx_cpu = gradient(mha_cpu, x_cpu) do mha, x
-        y, α = mha(x)
-        return sum(y.^2) + sum(α.^2)
-    end
-    gm_gpu, gx_gpu = gradient(mha_gpu, x_gpu) do mha, x
-        y, α = mha(x)
-        return sum(y.^2) + sum(α.^2)
-    end
-    check_grad(gm_gpu, gm_cpu)
-    check_grad(gx_gpu, gx_cpu)
+    test_gradients(mha_cpu, x_cpu, loss = o -> sum(o[1].^2) + sum(o[2].^2),
+                   test_gpu=true, compare_finite_diff=false)
 end
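For layers that return a tuple, such as the MultiHeadAttention tested above, the call passes a `loss` closure so the helper can reduce the output to a scalar before differentiating. A sketch under the assumption that `test_gradients` applies the closure on both the CPU and GPU paths (layer and input sizes are illustrative):

    mha = MultiHeadAttention(8)   # returns (output, attention_scores)
    x = rand(Float32, 8, 10, 2)   # (features, sequence length, batch)
    test_gradients(mha, x; loss = o -> sum(o[1].^2) + sum(o[2].^2),
                   test_gpu=true, compare_finite_diff=false)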

test/ext_cuda/losses.jl

Lines changed: 4 additions & 3 deletions

@@ -27,11 +27,12 @@ y = [1 0 0 0 1
 @test focal_loss(x, y) ≈ focal_loss(gpu(x), gpu(y))
 
 @testset "GPU: $loss" for loss in ALL_LOSSES
-    x = rand(Float32, 3,4)
-    y = rand(Float32, 3,4)
+    # let's stay far from the boundaries to avoid problems with finite differences gradients
+    x = 0.1f0 .+ 0.8f0 .* rand(Float32, 3, 4)
+    y = 0.1f0 .+ 0.8f0 .* rand(Float32, 3, 4)
     @test loss(x, y) ≈ loss(gpu(x), gpu(y))
 
-    gpu_autodiff_test(loss, x, y)
+    test_gradients(loss, x, y, test_gpu=true, test_grad_f=false, compare_finite_diff=false)
 
     # Float16 tests
     @test loss(f16(x), f16(y)) ≈ loss(gpu(f16(x)), gpu(f16(y)))
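A quick illustration of what the shifted inputs guarantee: every entry lies in [0.1, 0.9), so small perturbations (for example from finite differences) stay inside the valid (0, 1) range of probability-like losses instead of landing on a boundary.

    x = 0.1f0 .+ 0.8f0 .* rand(Float32, 3, 4)
    @assert all(0.1f0 .<= x .< 0.9f0)   # safely inside the open unit interval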
