Skip to content

Commit a39d91b

Browse files
committed
Refactor to use 'cpu' & 'device' functions
1 parent f30fb37 commit a39d91b

File tree

2 files changed

+33
-32
lines changed

2 files changed

+33
-32
lines changed

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function nnlib_testsuite(Backend; skip_tests = Set{String}())
3535
end
3636

3737
@testset "NNlib.jl" verbose=true begin
38-
@testset "Test Suite" begin
38+
@testset verbose=true "Test Suite" begin
3939
@testset "CPU" begin
4040
nnlib_testsuite(CPU)
4141
end

test/upsample.jl

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
11
function upsample_testsuite(Backend)
2-
cpu, backend = CPU(), Backend()
2+
cpu(x) = adapt(CPU(), x)
3+
device(x) = adapt(Backend(), x)
4+
gradtest_fn = KernelAbstractions.isgpu(Backend()) ? gputest : gradtest
35
T = Float32 # TODO test against all supported eltypes for each backend.
46
atol = T == Float32 ? 1e-3 : 1e-6
5-
gradtest_fn = backend == CPU() ? gradtest : gputest
67

78
@testset "upsample_nearest, integer scale via reshape" begin
8-
x = adapt(backend, reshape(T[1 2; 3 4], (2,2,1,1)))
9-
@test adapt(cpu, upsample_nearest(x, (3,3)))[1,:] == [1,1,1, 2,2,2]
9+
x = device(reshape(T[1 2; 3 4], (2,2,1,1)))
10+
@test cpu(upsample_nearest(x, (3,3)))[1,:] == [1,1,1, 2,2,2]
1011

1112
y = upsample_nearest(x, (2,3))
1213
@test size(y) == (4,6,1,1)
1314
y2 = upsample_nearest(x, size=(4,6))
14-
@test adapt(cpu, y) adapt(cpu, y2)
15+
@test cpu(y) cpu(y2)
1516

16-
@test adapt(cpu, ∇upsample_nearest(y, (2,3)))[:, :, 1, 1] == [6 12; 18 24]
17+
@test cpu(∇upsample_nearest(y, (2,3)))[:, :, 1, 1] == [6 12; 18 24]
1718
gradtest_fn(
1819
x -> upsample_nearest(x, (2,3)),
19-
adapt(backend, rand(T, 2,2,1,1)); atol)
20+
device(rand(T, 2,2,1,1)); atol)
2021
gradtest_fn(
2122
x -> upsample_nearest(x, size=(4,6)),
22-
adapt(backend, rand(T, 2,2,1,1)); atol)
23+
device(rand(T, 2,2,1,1)); atol)
2324

2425
@test_throws ArgumentError ∇upsample_nearest(y, (2,4))
2526
@test_throws ArgumentError upsample_nearest(x, (1,2,3,4,5))
@@ -33,9 +34,9 @@ function upsample_testsuite(Backend)
3334
y = collect(1:1//3:4)
3435
y = hcat(y,y,y)[:,:,:]
3536

36-
xd = adapt(backend, x)
37-
@test y adapt(cpu, upsample_linear(xd, 2.5))
38-
@test y adapt(cpu, upsample_linear(xd; size=10))
37+
xd = device(x)
38+
@test y cpu(upsample_linear(xd, 2.5))
39+
@test y cpu(upsample_linear(xd; size=10))
3940
gradtest_fn(x -> upsample_linear(x, 2.5), xd; atol)
4041
end
4142

@@ -56,18 +57,18 @@ function upsample_testsuite(Backend)
5657
y_true = cat(y_true, y_true; dims=3)
5758
y_true = cat(y_true, y_true; dims=4)
5859

59-
xd = adapt(backend, x)
60+
xd = device(x)
6061
y = upsample_bilinear(xd, (3, 2))
6162
@test size(y) == size(y_true)
6263
@test eltype(y) == Float32
63-
@test adapt(cpu, y) y_true
64+
@test cpu(y) y_true
6465

6566
gradtest_fn(x -> upsample_bilinear(x, (3, 2)), xd; atol)
6667

6768
# additional grad check, also compliant with pytorch
6869
o = ones(Float32,6,4,2,1)
6970
grad_true = 6*ones(Float32,2,2,2,1)
70-
@test adapt(cpu, ∇upsample_bilinear(adapt(backend, o); size = (2,2))) grad_true
71+
@test cpu(∇upsample_bilinear(device(o); size = (2,2))) grad_true
7172

7273
# CPU only tests.
7374

@@ -110,7 +111,7 @@ function upsample_testsuite(Backend)
110111
y_true[:,:,4,:,:] .= 2.5
111112
y_true[:,:,5,:,:] .= 3.
112113

113-
xd = adapt(backend, x)
114+
xd = device(x)
114115
y = upsample_trilinear(xd; size=(5,5,5))
115116

116117
@test size(y) == size(y_true)
@@ -122,9 +123,9 @@ function upsample_testsuite(Backend)
122123
atol=(T == Float32) ? 1e-2 : 1e-5)
123124

124125
# This test only works when `align_corners=false`.
125-
o = adapt(backend, ones(Float32,8,8,8,1,1))
126+
o = device(ones(Float32,8,8,8,1,1))
126127
grad_true = 8 * ones(Float32,4,4,4,1,1)
127-
@test adapt(cpu, ∇upsample_trilinear(o; size=(4,4,4), align_corners=false)) grad_true
128+
@test cpu(∇upsample_trilinear(o; size=(4,4,4), align_corners=false)) grad_true
128129
end
129130

130131
@testset "pixel_shuffle" begin
@@ -147,9 +148,9 @@ function upsample_testsuite(Backend)
147148
2 10 4 12
148149
6 14 8 16][:,:,:,:]
149150

150-
y = pixel_shuffle(adapt(backend, x), 2)
151+
y = pixel_shuffle(device(x), 2)
151152
@test size(y) == size(y_true)
152-
@test y_true == adapt(cpu, y)
153+
@test y_true == cpu(y)
153154

154155
x = reshape(1:32, (2, 2, 8, 1))
155156
y_true = zeros(Int, 4, 4, 2, 1)
@@ -163,28 +164,28 @@ function upsample_testsuite(Backend)
163164
18 26 20 28
164165
22 30 24 32]
165166

166-
y = pixel_shuffle(adapt(backend, x), 2)
167+
y = pixel_shuffle(device(x), 2)
167168
@test size(y) == size(y_true)
168-
@test y_true == adapt(cpu, y)
169+
@test y_true == cpu(y)
169170

170171
x = reshape(1:4*3*27*2, (4,3,27,2))
171-
y = pixel_shuffle(adapt(backend, x), 3)
172+
y = pixel_shuffle(device(x), 3)
172173
@test size(y) == (12, 9, 3, 2)
173174

174175
# batch dimension is preserved
175176
x1 = x[:,:,:,[1]]
176177
x2 = x[:,:,:,[2]]
177-
y1 = pixel_shuffle(adapt(backend, x1), 3)
178-
y2 = pixel_shuffle(adapt(backend, x2), 3)
179-
@test adapt(cpu, cat(y1, y2, dims=4)) == adapt(cpu, y)
178+
y1 = pixel_shuffle(device(x1), 3)
179+
y2 = pixel_shuffle(device(x2), 3)
180+
@test cpu(cat(y1, y2, dims=4)) == cpu(y)
180181

181182
for d in [1, 2, 3]
182183
r = rand(1:5)
183184
n = rand(1:5)
184185
c = rand(1:5)
185186
insize = rand(1:5, d)
186187
x = rand(insize..., r^d*c, n)
187-
xd = adapt(backend, x)
188+
xd = device(x)
188189

189190
y = pixel_shuffle(xd, r)
190191
@test size(y) == ((r .* insize)..., c, n)
@@ -195,19 +196,19 @@ function upsample_testsuite(Backend)
195196
@testset "Complex-valued upsample" begin
196197
for (d, method) in zip([1, 2, 3], [upsample_linear, upsample_bilinear, upsample_trilinear])
197198
for (k, interp) in zip((2, ntuple(_ -> 2, d)), [method, upsample_nearest])
198-
x = adapt(backend, randn(Complex{Float32}, (4,8,12)[1:d]..., 1, 1))
199+
x = device(randn(Complex{Float32}, (4,8,12)[1:d]..., 1, 1))
199200

200201
upsize = (8, 16, 24)[1:d]
201202
xup = interp(x, k)
202203
@test size(xup)[1:d] == upsize
203-
@test adapt(cpu, real(xup)) == adapt(cpu, interp(real(x), k))
204-
@test adapt(cpu, imag(xup)) == adapt(cpu, interp(imag(x), k))
204+
@test cpu(real(xup)) == cpu(interp(real(x), k))
205+
@test cpu(imag(xup)) == cpu(interp(imag(x), k))
205206

206207
upsize = (8,24,48)[1:d]
207208
xup = interp(x; size=upsize)
208209
@test size(xup)[1:d] == upsize
209-
@test adapt(cpu, real(xup)) == adapt(cpu, interp(real(x), size=upsize))
210-
@test adapt(cpu, imag(xup)) == adapt(cpu, interp(imag(x), size=upsize))
210+
@test cpu(real(xup)) == cpu(interp(real(x), size=upsize))
211+
@test cpu(imag(xup)) == cpu(interp(imag(x), size=upsize))
211212
end
212213
end
213214
end

0 commit comments

Comments
 (0)