Skip to content

Commit 483f20c

Browse files
committed
more tests
1 parent 78e401f commit 483f20c

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

test/bias_act.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NNlib, Zygote, Test
1+
using NNlib, Zygote, ChainRulesCore, Test
22
using Zygote: ForwardDiff
33

44
ACTIVATION_FUNCTIONS =
@@ -14,14 +14,21 @@ ACTIVATION_FUNCTIONS =
1414
@test @inferred(bias_act!(tanh, copy(x), false)) tanh.(x)
1515

1616
# Check that it does overwrite:
17-
x32 = rand(Float32, 3, 4)
18-
x32copy = copy(x32)
17+
x32 = rand(Float32, 3, 4); x32copy = copy(x32)
1918
@test @inferred(bias_act!(cbrt, x32, b)) cbrt.(x32copy .+ b)
20-
@test x32 cbrt.(x32copy .+ b)
21-
x32 = rand(Float32, 3, 4)
22-
x32copy = copy(x32)
19+
@test x32 cbrt.(x32copy .+ b)
20+
21+
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias
2322
@test @inferred(bias_act!(tanh, x32, false)) tanh.(x32copy)
24-
@test x32 tanh.(x32copy)
23+
@test x32 tanh.(x32copy)
24+
25+
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # now check gradient rule
26+
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b)
27+
@test y x32 relu.(x32copy .+ b)
28+
29+
x32 = rand(Float32, 3, 4); x32copy = copy(x32) # without bias
30+
y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false)
31+
@test y x32 relu.(x32copy)
2532

2633
# Check that it doesn't try to overwrite non-float arrays:
2734
xint = rand(-3:3, 3, 4)
@@ -78,7 +85,7 @@ ACTIVATION_FUNCTIONS =
7885
g2 = ForwardDiff.gradient(x) do x
7986
sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
8087
end
81-
@test_broken gx Zygote.gradient(x) do x
88+
@test_skip gx Zygote.gradient(x) do x # Here global variable b causes an error
8289
sum(abs2,Zygote. gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
8390
end
8491
# Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).

0 commit comments

Comments
 (0)