@@ -7,9 +7,33 @@ ACTIVATION_FUNCTIONS =
7
7
@testset " bias_act!" begin
8
8
x = randn (3 ,4 )
9
9
b = randn (3 )
10
- @test bias_act! (identity, copy (x), b) ≈ (x .+ b)
11
- @test bias_act! (relu, copy (x), b) ≈ relu .(x .+ b)
12
- @test bias_act! (tanh, copy (x), b) ≈ tanh .(x .+ b)
10
+ @test @inferred (bias_act! (identity, x, false )) === x # pass-through
11
+ @test @inferred (bias_act! (identity, copy (x), b)) ≈ (x .+ b)
12
+ @test @inferred (bias_act! (relu, copy (x), b)) ≈ relu .(x .+ b)
13
+ @test @inferred (bias_act! (tanh, copy (x), b)) ≈ tanh .(x .+ b)
14
+ @test @inferred (bias_act! (tanh, copy (x), false )) ≈ tanh .(x)
15
+
16
+ # Check that it does overwrite:
17
+ x32 = rand (Float32, 3 , 4 )
18
+ x32copy = copy (x32)
19
+ @test @inferred (bias_act! (cbrt, x32, b)) ≈ cbrt .(x32copy .+ b)
20
+ @test x32 ≈ cbrt .(x32copy .+ b)
21
+ x32 = rand (Float32, 3 , 4 )
22
+ x32copy = copy (x32)
23
+ @test @inferred (bias_act! (tanh, x32, false )) ≈ tanh .(x32copy)
24
+ @test x32 ≈ tanh .(x32copy)
25
+
26
+ # Check that it doesn't try to overwrite non-float arrays:
27
+ xint = rand (- 3 : 3 , 3 , 4 )
28
+ bint = rand (- 2 : 2 , 3 )
29
+ @test bias_act! (identity, copy (xint), bint) ≈ xint .+ bint
30
+ @test bias_act! (tanh, copy (xint), bint) ≈ tanh .(xint .+ bint)
31
+ @test bias_act! (tanh, copy (xint), false ) ≈ tanh .(xint)
32
+
33
+ # Reject bias===true so that Bool means one thing:
34
+ @test_throws Exception bias_act! (identity, rand (3 ), true )
35
+ @test_throws Exception bias_act! (cbrt, rand (3 ), true )
36
+ @test_throws Exception bias_act! (cbrt, rand (1 : 3 , 3 ), true )
13
37
14
38
@testset " gradient with $fun " for fun in vcat ([identity, tanh, cbrt],
15
39
ACTIVATION_FUNCTIONS,
@@ -21,9 +45,21 @@ ACTIVATION_FUNCTIONS =
21
45
@test bias_act! (fun, copy (x), false ) ≈ fun .(x)
22
46
23
47
gx = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), b)), x)
48
+ gxplus = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), b)), x .+ eps ())
49
+ gxminus = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), b)), x .- eps ())
50
+ if ! (gx ≈ gxplus ≈ gxminus)
51
+ @warn " skipping gradient tests due to discontinuity" fun x b
52
+ continue
53
+ end
24
54
@test gx ≈ Zygote. gradient (x -> sum (bias_act! (fun, copy (x), b)), x)[1 ]
25
55
26
56
gx2 = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), false )), x)
57
+ gx2plus = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), false )), x .- eps ())
58
+ gx2minus = ForwardDiff. gradient (x -> sum (bias_act! (fun, copy (x), false )), x .- eps ())
59
+ if ! (gx2 ≈ gx2plus ≈ gx2minus)
60
+ @warn " skipping gradient tests due to discontinuity" fun x
61
+ continue
62
+ end
27
63
@test gx2 ≈ Zygote. gradient (x -> sum (bias_act! (fun, copy (x), false )), x)[1 ]
28
64
29
65
gb = ForwardDiff. gradient (b -> sum (bias_act! (fun, copy (x), b)), b)
0 commit comments