Commit 78e401f

upgrade tests
1 parent 6ecd9d2

2 files changed: +52 -7 lines

src/bias_act.jl

Lines changed: 13 additions & 4 deletions

@@ -29,14 +29,23 @@ contains only `Ω` (the output) not `x`.
 It is intended mainly for Flux layers, in which the previous operation is
 known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer.
 """
-bias_act!(σ::Function, x::AbstractArray, b) = fast_act(σ, x).(x .+ b)  # fallback
-
 bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) =
     _fast_broadcast!(fast_act(σ, x)∘(+), x, b)  # works around a SIMD bug
 
-bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) =
-    (@assert !b "bias=true is not accepted; layer constructors should guarantee this"; x)
+function bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool)
+    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
+    _fast_broadcast!(fast_act(σ, x), x)
+end
 
+function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool)
+    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
+    x  # pass-through
+end
+
+function bias_act!(σ::Function, x::AbstractArray, b)
+    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
+    fast_act(σ, x).(x .+ b)  # fallback
+end
 
 function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
     biasgrad = if eltype(B) !== Bool
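Taken together, the src change swaps the single `@assert !b` for an explicit `error` in every method, so rejecting `b === true` no longer depends on assertions being enabled (Julia may elide `@assert` at some optimization levels), while `b === false` keeps meaning "no bias". A minimal usage sketch of the resulting dispatch, assuming NNlib's exported `bias_act!` and `relu` (the arrays here are illustrative):

```julia
using NNlib  # assumed to export bias_act! and relu

x = randn(Float32, 3, 4)
b = randn(Float32, 3)

bias_act!(relu, x, b)          # strided float array: computes relu.(x .+ b) in place
bias_act!(identity, x, false)  # identity with b = false: returns x itself, untouched
# bias_act!(relu, x, true)     # b = true now throws in every method
```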

test/bias_act.jl

Lines changed: 39 additions & 3 deletions

@@ -7,9 +7,33 @@ ACTIVATION_FUNCTIONS =
 @testset "bias_act!" begin
     x = randn(3,4)
     b = randn(3)
-    @test bias_act!(identity, copy(x), b) ≈ (x .+ b)
-    @test bias_act!(relu, copy(x), b) ≈ relu.(x .+ b)
-    @test bias_act!(tanh, copy(x), b) ≈ tanh.(x .+ b)
+    @test @inferred(bias_act!(identity, x, false)) === x  # pass-through
+    @test @inferred(bias_act!(identity, copy(x), b)) ≈ (x .+ b)
+    @test @inferred(bias_act!(relu, copy(x), b)) ≈ relu.(x .+ b)
+    @test @inferred(bias_act!(tanh, copy(x), b)) ≈ tanh.(x .+ b)
+    @test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x)
+
+    # Check that it does overwrite:
+    x32 = rand(Float32, 3, 4)
+    x32copy = copy(x32)
+    @test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b)
+    @test x32 ≈ cbrt.(x32copy .+ b)
+    x32 = rand(Float32, 3, 4)
+    x32copy = copy(x32)
+    @test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy)
+    @test x32 ≈ tanh.(x32copy)
+
+    # Check that it doesn't try to overwrite non-float arrays:
+    xint = rand(-3:3, 3, 4)
+    bint = rand(-2:2, 3)
+    @test bias_act!(identity, copy(xint), bint) ≈ xint .+ bint
+    @test bias_act!(tanh, copy(xint), bint) ≈ tanh.(xint .+ bint)
+    @test bias_act!(tanh, copy(xint), false) ≈ tanh.(xint)
+
+    # Reject bias===true so that Bool means one thing:
+    @test_throws Exception bias_act!(identity, rand(3), true)
+    @test_throws Exception bias_act!(cbrt, rand(3), true)
+    @test_throws Exception bias_act!(cbrt, rand(1:3, 3), true)
 
     @testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt],
                                                    ACTIVATION_FUNCTIONS,
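The overwrite checks above follow one pattern: save a copy of the buffer, call `bias_act!`, then compare both the return value and the original array against the expected broadcast; the integer-array checks confirm the non-mutating fallback is taken instead. The same pattern in isolation, as a sketch assuming NNlib's `bias_act!` (variable names are illustrative; `≈` rather than `==` because `fast_act` may substitute approximate activations):

```julia
using NNlib

x32 = rand(Float32, 3, 4)
saved = copy(x32)                    # keep the pre-call values
bias_act!(tanh, x32, false)          # strided float array: overwritten in place
x32 ≈ tanh.(saved)                   # true: the buffer now holds the result

xint = rand(-3:3, 3, 4)
yint = bias_act!(tanh, xint, false)  # Int array: non-mutating broadcast fallback
yint ≈ tanh.(xint)                   # true, while xint itself is left unmodified
```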
@@ -21,9 +45,21 @@ ACTIVATION_FUNCTIONS =
         @test bias_act!(fun, copy(x), false) ≈ fun.(x)
 
         gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)
+        gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps())
+        gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps())
+        if !(gx ≈ gxplus ≈ gxminus)
+            @warn "skipping gradient tests due to discontinuity" fun x b
+            continue
+        end
         @test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1]
 
         gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)
+        gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .+ eps())
+        gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())
+        if !(gx2 ≈ gx2plus ≈ gx2minus)
+            @warn "skipping gradient tests due to discontinuity" fun x
+            continue
+        end
         @test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1]
 
         gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)
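The new `gxplus`/`gxminus` probes guard against kinks: for activations like `relu`, forward-mode derivatives evaluated just left and right of a non-differentiable point disagree, so comparing Zygote against ForwardDiff there would be meaningless. A self-contained illustration of the same check, using only ForwardDiff (`f` is an illustrative stand-in for an activation):

```julia
using ForwardDiff

f(v) = max(v, zero(v))  # relu-like function with a kink at zero

x = 0.0
g      = ForwardDiff.derivative(f, x)
gplus  = ForwardDiff.derivative(f, x + eps())
gminus = ForwardDiff.derivative(f, x - eps())

# The one-sided slopes differ at the kink (0 on the left, 1 on the right),
# so the chained ≈ fails and the gradient comparison should be skipped:
if !(g ≈ gplus ≈ gminus)
    @warn "derivative not locally stable; skipping comparison" x
end
```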
