1
1
using NNlib: DenseConvDims
2
2
3
3
@testset " convolution" begin
4
- a, b, c = rand (Float64, 10 , 10 , 3 , 1 ), rand (Float64, 2 , 2 , 3 , 4 ), rand (Float64, 9 , 9 , 4 , 1 )
4
+ @testset " $T " for T in (Float64, ComplexF64)
5
+ a, b, c = rand (T, 10 , 10 , 3 , 1 ), rand (T, 2 , 2 , 3 , 4 ), rand (T, 9 , 9 , 4 , 1 )
5
6
da, db, dc = CuArray (a), CuArray (b), CuArray (c)
6
7
cdims = DenseConvDims (a, b)
7
8
@test NNlib. conv (a, b, cdims) ≈ collect (NNlib. conv (da, db, cdims))
8
9
@test ∇conv_data (c, b, cdims) ≈ collect (∇conv_data (dc, db, cdims))
9
10
@test ∇conv_filter (a, c, cdims) ≈ collect (∇conv_filter (da, dc, cdims))
10
11
12
+ if T <: Complex
13
+ @testset " mixed real and complex" begin
14
+ @test NNlib. conv (real (a), b, cdims) ≈ collect (NNlib. conv (real (da), db, cdims))
15
+ @test NNlib. conv (a, real (b), cdims) ≈ collect (NNlib. conv (da, real (db), cdims))
16
+ @test ∇conv_data (c, real (b), cdims) ≈ collect (∇conv_data (dc, real (db), cdims))
17
+ @test ∇conv_filter (real (a), c, cdims) ≈ collect (∇conv_filter (real (da), dc, cdims))
18
+ end
19
+ end
20
+
11
21
# Test Conv Bias Activation
12
- bias = rand (Float64 , 1 , 1 , 4 , 1 )
22
+ bias = rand (T , 1 , 1 , 4 , 1 )
13
23
dbias = CuArray (bias)
14
- @test conv_bias_act (a, b, cdims, bias, NNlib. relu) ≈ collect (conv_bias_act (da, db, cdims, dbias, NNlib. relu))
24
+ act = T <: Complex ? abs2 : NNlib. relu
25
+ @test conv_bias_act (a, b, cdims, bias, act) ≈ collect (conv_bias_act (da, db, cdims, dbias, act))
15
26
@test conv_bias_act (a, b, cdims, bias, identity) ≈ collect (conv_bias_act (da, db, cdims, dbias, identity))
16
27
17
28
# Test for agreement between CPU NNlib and CuDNN versions, across a variety of kwargs
@@ -26,16 +37,20 @@ using NNlib: DenseConvDims
26
37
C_out = 4
27
38
batch_size = 1
28
39
29
- for groups in (1 , 2 , 4 ), num_spatial_dims in (1 , 2 , 3 )
40
+ # we use this activation for the gpu tests
41
+ # as we can't take gradients of complex quantities
42
+ act = T <: Complex ? x-> abs2 (x) : identity
43
+ @testset " groups=$groups , num_spatial_dims=$num_spatial_dims " for groups in (1 , 2 , 4 ), num_spatial_dims in (1 , 2 , 3 )
30
44
# Make `C_in = C_out` when using grouped convolution.
31
45
C_in = groups == 1 ? C_in_ : C_out
32
46
# Initialize data we'll run our tests over
33
- x = rand (Float64 , fill (8 , num_spatial_dims)... , C_in, batch_size)
34
- w = rand (Float64 , fill (2 , num_spatial_dims)... , C_in ÷ groups, C_out)
47
+ x = rand (T , fill (8 , num_spatial_dims)... , C_in, batch_size)
48
+ w = rand (T , fill (2 , num_spatial_dims)... , C_in ÷ groups, C_out)
35
49
36
- for opts in options
50
+ @testset " opts # $i " for (i, opts) in enumerate ( options)
37
51
opts[:groups ] = groups
38
-
52
+
53
+
39
54
if :padding in keys (opts)
40
55
padding = opts[:padding ]
41
56
if 1 < length (padding) && length (padding) != 2 num_spatial_dims
@@ -47,18 +62,56 @@ using NNlib: DenseConvDims
47
62
y = NNlib. conv (x, w, cdims)
48
63
49
64
# Test that basic convolution is equivalent across GPU/CPU
50
- gputest ((x, w) -> NNlib. conv (x, w, cdims), x, w)
51
- gputest ((y, w) -> NNlib.∇conv_data (y, w, cdims), y, w)
52
- gputest ((x, y) -> NNlib.∇conv_filter (x, y, cdims), x, y, checkgrad= false ) # TODO fix grad
65
+ @testset " cpu==gpu" begin
66
+ @testset " conv" begin
67
+ gputest ((x, w) -> act .(NNlib. conv (x, w, cdims)), x, w)
68
+ if T <: Complex
69
+ gputest ((x, w) -> act .(NNlib. conv (x, w, cdims)), real (x), w)
70
+ gputest ((x, w) -> act .(NNlib. conv (x, w, cdims)), x, real (w))
71
+ end
72
+ end
73
+ @testset " ∇conv_data" begin
74
+ gputest ((y, w) -> act .(NNlib.∇conv_data (y, w, cdims)), y, w)
75
+ if T <: Complex
76
+ gputest ((y, w) -> act .(NNlib.∇conv_data (y, w, cdims)), y, real (w))
77
+ end
78
+ end
79
+ @testset " ∇conv_filter" begin
80
+ gputest ((x, y) -> act .(NNlib.∇conv_filter (x, y, cdims)), x, y)
81
+ if T <: Complex
82
+ gputest ((x, y) -> act .(NNlib.∇conv_filter (x, y, cdims)), real (x), y)
83
+ end
84
+ end
85
+ end
53
86
54
87
# Scaling factors
55
- gputest ((x, w) -> NNlib. conv (x, w, cdims; alpha= 2.0 ), x, w, checkgrad= false ) # TODO
56
- gputest ((y, w) -> NNlib.∇conv_data (y, w, cdims; alpha= 2.0 ), y, w, checkgrad= false ) # TODO
57
- gputest ((x, y) -> NNlib.∇conv_filter (x, y, cdims; alpha= 2.0 ), x, y, checkgrad= false ) # TODO
88
+ @testset " scale-alpha" begin
89
+ gputest ((x, w) -> act .(NNlib. conv (x, w, cdims; alpha= T (2.0 ))), x, w, checkgrad= false ) # TODO
90
+ gputest ((y, w) -> act .(NNlib.∇conv_data (y, w, cdims; alpha= T (2.0 ))), y, w, checkgrad= false ) # TODO
91
+ gputest ((x, y) -> act .(NNlib.∇conv_filter (x, y, cdims; alpha= T (2.0 ))), x, y, checkgrad= false ) # TODO
92
+
93
+ if T <: Complex
94
+ gputest ((x, w) -> act .(NNlib. conv (x, w, cdims; alpha= T (2.0 ))), real (x), w, checkgrad= false )
95
+ gputest ((x, w) -> act .(NNlib. conv (x, w, cdims; alpha= T (2.0 ))), x, real (w), checkgrad= false ) # TODO
96
+ gputest ((y, w) -> act .(NNlib.∇conv_data (y, w, cdims; alpha= T (2.0 ))), y, real (w), checkgrad= false ) # TODO
97
+ gputest ((x, y) -> act .(NNlib.∇conv_filter (x, y, cdims; alpha= T (2.0 ))), real (x), y, checkgrad= false ) # TODO
98
+ end
99
+ end
100
+
101
+ @testset " scale-beta" begin
102
+ gputest ((y, x, w) -> act .(NNlib. conv! (copy (y), x, w, cdims; beta= T (2.0 ))), y, x, w, checkgrad= false , broken= false )
103
+ gputest ((w, x, y) -> act .(NNlib.∇conv_filter! (copy (w), x, y, cdims; beta= T (2.0 ))), w, x, y, checkgrad= false , broken= false )
104
+ gputest ((x, y, w) -> act .(NNlib.∇conv_data! (copy (x), y, w, cdims; beta= T (2.0 ))), x, y, w, checkgrad= false , broken= true )
105
+
106
+ if T <: Complex
107
+ gputest ((y, x, w) -> act .(NNlib. conv! (copy (y), x, w, cdims; beta= T (2.0 ))), y, real (x), w, checkgrad= false )
108
+ gputest ((y, x, w) -> act .(NNlib. conv! (copy (y), x, w, cdims; beta= T (2.0 ))), y, x, real (w), checkgrad= false )
109
+ gputest ((x, y, w) -> act .(NNlib.∇conv_data! (copy (x), y, w, cdims; beta= T (2.0 ))), x, y, real (w), checkgrad= false )
110
+ gputest ((w, x, y) -> act .(NNlib.∇conv_filter! (copy (w), x, y, cdims; beta= T (2.0 ))), w, real (x), y, checkgrad= false )
111
+ end
112
+ end
58
113
59
- gputest ((y, x, w) -> NNlib. conv! (copy (y), x, w, cdims; beta= 2.0 ), y, x, w, checkgrad= false ) # TODO
60
- # @test_broken gputest((x, y, w) -> NNlib.∇conv_data!(copy(x), y, w, cdims; beta=2.0), x, y, w, checkgrad=false) #TODO
61
- gputest ((w, x, y) -> NNlib.∇conv_filter! (copy (w), x, y, cdims; beta= 2.0 ), w, x, y, checkgrad= false ) # TODO
62
114
end
63
115
end
64
116
end
117
+ end
0 commit comments