Commit 90a0043

Add bias_act! (#457)
* sometimes-in-place bias_act
* update after dropout PR
* add to docs
* also fix two unrelated docstrings which just told you what the function was called without explaining anything
* tidy & un-comment
* comment out 2nd path again
* add Returns for 1.6
* upgrade tests
* more tests
* skip hardσ tests
* Update test/bias_act.jl
1 parent: 2b548b6 · commit: 90a0043

7 files changed: +287 additions, -25 deletions


docs/src/reference.md

Lines changed: 2 additions & 0 deletions
@@ -75,6 +75,7 @@ pad_zeros
 ## Convolution
 
 `Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally.
+
 `NNlib.conv` supports complex datatypes on CPU and CUDA devices.
 
 !!! AMDGPU MIOpen supports only cross-correlation (flipkernel=true).
@@ -152,4 +153,5 @@ ctc_loss
 logsumexp
 NNlib.glu
 NNlib.within_gradient
+bias_act!
 ```

src/NNlib.jl

Lines changed: 3 additions & 0 deletions
@@ -72,6 +72,9 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
 include("conv_bias_act.jl")
 export conv_bias_act, conv_bias_act!
 
+include("bias_act.jl")
+export bias_act!
+
 include("fold.jl")
 
 include("ctc.jl")

src/bias_act.jl

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@

using NNlib: fast_act, tanh_fast
using ChainRulesCore

const RCR = RuleConfig{>:HasReverseMode}

# This just saves typing `only.(only.(` many times:
@inline only_derivative(y, f::F, x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x)))

# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
struct NotaNumber <: Real end

"""
    bias_act!(σ, x, b)

This is equivalent to `x .= σ.(x .+ b)`, also replacing `sigmoid` & `tanh`
with `sigmoid_fast` & `tanh_fast`.
It will only overwrite `x` when `x isa StridedArray{<:AbstractFloat}`.

When used within a gradient, it will overwrite only when `σ` has
a method of `derivatives_given_output` which does not need the input at all.
Such methods are defined by e.g. `@scalar_rule relu(x) Ω > 0`, where the derivative
contains only `Ω` (the output) not `x`.

!!! warning
    This is not safe to use if `x` is still needed for the gradient
    of some other function. Incorrect use will give silently wrong answers.
    It is intended mainly for Flux layers, in which the previous operation is
    known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer.
"""
bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) =
    _fast_broadcast!(fast_act(σ, x)∘(+), x, b)  # works around a SIMD bug

function bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool)
    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
    _fast_broadcast!(fast_act(σ, x), x)
end

function bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool)
    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
    x  # pass-through
end

function bias_act!(σ::Function, x::AbstractArray, b)
    b === true && error("bias=true is not accepted; layer constructors should guarantee this")
    fast_act(σ, x).(x .+ b)  # fallback
end

function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
    biasgrad = if eltype(B) !== Bool
        # Summing over ndims(x)+1 is a trick to make b_dims type-stable
        dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
        _biasgrad(dx) = reshape(sum(dx; dims), size(b))
    else
        Returns(NoTangent())
    end

    # Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ
    if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber}))
        Ω = bias_act!(σ, x, b)  # now x === Ω, when x isa StridedArray{<:AbstractFloat}
        function bias_act!_fastback(Δ)
            # Tempting to overwrite x again, but only safe if you call pullback at most once,
            # TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340
            # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592
            dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ)
            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
        end
        return Ω, bias_act!_fastback

    # # Slower path: can't overwrite x, but can use derivatives_given_output
    # # This case is WRONG and tests fail, but not sure why
    # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
    #     Ω2 = fast_act(σ, x).(x) .+ b
    #     @show σ b
    #     function bias_act!_back2(Δ)
    #         dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
    #         return (NoTangent(), NoTangent(), dx, biasgrad(dx))
    #     end
    #     return Ω2, bias_act!_back2

    # Fallback path: let AD handle the broadcast
    else
        Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b))
        @inline function bias_act!_slowback(Δ)
            _, _, dx = back(Δ)
            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
        end
        return Ω3, bias_act!_slowback
    end
end

# Two easy cases with identity
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B}
    dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
    biasgrad(dx) = reshape(sum(dx; dims), size(b))
    function bias_act!_idback(Δ)
        dx = unthunk(Δ)
        return (NoTangent(), NoTangent(), dx, biasgrad(dx))
    end
    return bias_act!(identity, x, b), bias_act!_idback
end
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}
    bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())
    return x, bias_act!_trivial
end
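For orientation, here is a small usage sketch (not part of the commit; the arrays, sizes and values are invented for illustration) of how the new function is meant to be called, following the docstring above:

using NNlib

W, x, b = randn(Float32, 4, 3), randn(Float32, 3, 5), randn(Float32, 4)

y = W * x                        # freshly allocated, so it is safe to overwrite
z = bias_act!(relu, y, b)        # same result as relu.(W*x .+ b), but reuses y's memory
@assert z === y                  # y was overwritten in place

bias_act!(tanh, copy(y), false)  # `false` means "no bias"; tanh is swapped for tanh_fast
# bias_act!(tanh, y, true)       # deliberately an error: `true` is rejected

The rrule above chooses between its fast and fallback paths by asking, via `Core.Compiler._return_type`, whether `derivatives_given_output` for `σ` can be evaluated without the input: activations such as `relu` qualify and are overwritten in place even inside a gradient, while a function NNlib knows nothing about, such as `cbrt`, falls back to `rrule_via_ad`.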

src/dropout.jl

Lines changed: 0 additions & 21 deletions
@@ -125,27 +125,6 @@ end
 # and seems about as fast, but needs a method `copy(::CUDA.RNG)` and careful checking.
 # https://github.com/FluxML/NNlib.jl/pull/454#issuecomment-1369357402
 
-"""
-    _fast_broadcast!(f, x, y, z...)
-
-This does `x .= f.(x, y, z...)`, but works around
-an issue with broadcasting that prevents SIMD in such cases.
-Can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.
-
-Not intended for general use. Does not check sizes!
-"""
-function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
-    bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
-    @simd ivdep for I in eachindex(bc)
-        @inbounds x[I] = bc[I]
-    end
-    return x
-end
-function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}
-    # CUDA does not suffer from this bug
-    broadcast!(f, x, x, yz...)
-end
-
 
 """
     _rng_from_array(x)

src/utils.jl

Lines changed: 60 additions & 4 deletions
@@ -53,15 +53,21 @@ ChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), No
 """
     safe_div(x, y)
 
-Safely divide `x` by `y`. If `y` is zero, return `x` directly.
+Returns `x/y` unless `y==0`, in which case it just returns `x`.
+(Used internally by `scatter`.)
 """
 safe_div(x, y) = ifelse(iszero(y), x, x/y)
 
 """
     maximum_dims(dims)
 
-Return the maximum value for each dimension. An array of dimensions `dims` is accepted.
-The maximum of each dimension in the element is computed.
+Given an array of `CartesianIndex{N}` or `NTuple{N,Int}`,
+returns a tuple containing the maximum of all the 1st entries,
+all the 2nd entries, and so on up to `N`.
+
+Given an array of integers, returns `(maximum(dims),)`.
+
+(These arguments are what [`scatter`](@ref NNlib.scatter) understands.)
 """
 maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), )
 maximum_dims(dims::AbstractArray{NTuple{N, T}}) where {N,T} = ntuple(i -> maximum(x->x[i], dims), N)
@@ -105,4 +111,54 @@ function reverse_indices(idx::AbstractArray{<:Any,N}) where N
     return reverse_indices!(rev, idx)
 end
 
-unsqueeze(x) = reshape(x, 1, size(x)...)
+unsqueeze(x) = reshape(x, 1, size(x)...)
+
+
+"""
+    _fast_broadcast!(f, x, y, z...)
+
+This does `x .= f.(x, y, z...)`, but works around
+an issue with broadcasting that prevents SIMD in such cases.
+Can perhaps be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.
+
+Has an `rrule` to avoid mutation within derivatives.
+
+!!! warning
+    Not intended for general use.
+    Uses `@inbounds` but does not check sizes!
+    Assumes that `f` has no derivative!
+"""
+function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
+    bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
+    @simd ivdep for I in eachindex(bc)
+        @inbounds x[I] = bc[I]
+    end
+    return x
+end
+function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}
+    # CUDA does not suffer from this bug
+    broadcast!(f, x, x, yz...)
+end
+
+function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!), f::F, x::AbstractArray, ys...) where {F<:Function}
+    rrule_via_ad(cfg, broadcast, f, x, ys...)
+end
+
+# Could get this from Compat.jl instead
+# https://github.com/JuliaLang/julia/pull/39794
+if VERSION < v"1.7.0-DEV.793"
+    struct Returns{V} <: Function
+        value::V
+        Returns{V}(value) where {V} = new{V}(value)
+        Returns(value) = new{Core.Typeof(value)}(value)
+    end
+
+    (obj::Returns)(args...; kw...) = obj.value
+    function Base.show(io::IO, obj::Returns)
+        show(io, typeof(obj))
+        print(io, "(")
+        show(io, obj.value)
+        print(io, ")")
+    end
+end
+
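To make the new docstrings concrete, a short sketch (not part of the diff; the inputs are invented) exercising these helpers with the argument types their methods above accept:

using NNlib

NNlib.safe_div(6.0, 2.0)                    # 3.0
NNlib.safe_div(6.0, 0.0)                    # 6.0  (y == 0, so x is returned unchanged)

NNlib.maximum_dims([(1,5), (3,2), (2,4)])   # (3, 5): per-position maxima of the tuples
NNlib.maximum_dims([4, 1, 7])               # (7,)

x = rand(4); y = rand(4)
NNlib._fast_broadcast!(+, x, y)             # in-place x .= x .+ y, written to allow SIMD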

test/bias_act.jl

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@

using NNlib, Zygote, ChainRulesCore, Test
using Zygote: ForwardDiff

ACTIVATION_FUNCTIONS =
    [@eval($a) for a in NNlib.ACTIVATIONS]

@testset "bias_act!" begin
    x = randn(3,4)
    b = randn(3)
    @test @inferred(bias_act!(identity, x, false)) === x  # pass-through
    @test @inferred(bias_act!(identity, copy(x), b)) ≈ (x .+ b)
    @test @inferred(bias_act!(relu, copy(x), b)) ≈ relu.(x .+ b)
    @test @inferred(bias_act!(tanh, copy(x), b)) ≈ tanh.(x .+ b)
    @test @inferred(bias_act!(tanh, copy(x), false)) ≈ tanh.(x)

    # Check that it does overwrite:
    x32 = rand(Float32, 3, 4); x32copy = copy(x32)
    @test @inferred(bias_act!(cbrt, x32, b)) ≈ cbrt.(x32copy .+ b)
    @test x32 ≈ cbrt.(x32copy .+ b)

    x32 = rand(Float32, 3, 4); x32copy = copy(x32)  # without bias
    @test @inferred(bias_act!(tanh, x32, false)) ≈ tanh.(x32copy)
    @test x32 ≈ tanh.(x32copy)

    x32 = rand(Float32, 3, 4); x32copy = copy(x32)  # now check gradient rule
    y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, b)
    @test y ≈ x32 ≈ relu.(x32copy .+ b)

    x32 = rand(Float32, 3, 4); x32copy = copy(x32)  # without bias
    y, back = rrule(Zygote.ZygoteRuleConfig(), bias_act!, relu, x32, false)
    @test y ≈ x32 ≈ relu.(x32copy)

    # Check that it doesn't try to overwrite non-float arrays:
    xint = rand(-3:3, 3, 4)
    bint = rand(-2:2, 3)
    @test bias_act!(identity, copy(xint), bint) ≈ xint .+ bint
    @test bias_act!(tanh, copy(xint), bint) ≈ tanh.(xint .+ bint)
    @test bias_act!(tanh, copy(xint), false) ≈ tanh.(xint)

    # Reject bias===true so that Bool means one thing:
    @test_throws Exception bias_act!(identity, rand(3), true)
    @test_throws Exception bias_act!(cbrt, rand(3), true)
    @test_throws Exception bias_act!(cbrt, rand(1:3, 3), true)

    @testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt],
                                                  ACTIVATION_FUNCTIONS,
                                                  [x->x, x -> 1/(x^2+2), x -> leakyrelu(x, 0.33)])
        # Only some of these go the fast path, `cbrt` is an example of a function NNlib knows nothing about.
        fun == rrelu && continue  # this one is randomised!
        fun == hardσ && continue  # this one has heisenbugs, not solved by discontinuity-avoidance code below

        @test bias_act!(fun, copy(x), b) ≈ fun.(x .+ b)
        @test bias_act!(fun, copy(x), false) ≈ fun.(x)

        gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)
        gxplus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .+ eps())
        gxminus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x .- eps())
        if !(gx ≈ gxplus ≈ gxminus)
            @warn "skipping gradient tests due to discontinuity" fun x b
            continue
        end
        @test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1]

        gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)
        gx2plus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())
        gx2minus = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x .- eps())
        if !(gx2 ≈ gx2plus ≈ gx2minus)
            @warn "skipping gradient tests due to discontinuity" fun x
            continue
        end
        @test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1]

        gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)
        @test gb ≈ Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)[1]

        @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), false) == (nothing,)
        @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b .> 0) == (nothing,)
    end

    @testset "gradient for fast_broadcast!" begin
        # Gradient definition is just to disable mutation inside 2nd order AD
        gx = ForwardDiff.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)
        @test gx ≈ Zygote.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)[1]

        # relu should take the fast path
        g2 = ForwardDiff.gradient(x) do x
            sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
        end
        @test_skip gx ≈ Zygote.gradient(x) do x  # Here global variable b causes an error
            sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
        end
        # Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).
        # [5] (::typeof(∂(accum_global)))(Δ::Nothing)
        @test g2 ≈ Zygote.gradient(x, b) do x, b
            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(relu, copy(x), b)), x, b)[1])
        end[1]

        g3 = ForwardDiff.gradient(x) do x
            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])
        end
        @test g3 ≈ Zygote.gradient(x, b) do x, b
            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])
        end[1]

        # Anon function sure to take the generic path
        g4 = ForwardDiff.gradient(x) do x
            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])
        end
        @test g4 ≈ Zygote.gradient(x, b) do x, b
            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])
        end[1]
    end
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -127,6 +127,7 @@ end
 
 @testset "Activation Functions" begin
     include("activations.jl")
+    include("bias_act.jl")
 end
 
 @testset "Attention" begin
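For anyone wanting to try this locally, a minimal sketch (assuming a checkout of NNlib.jl with Zygote and the other test dependencies installed):

# From the root of an NNlib.jl checkout, run only the new testset:
include("test/bias_act.jl")     # the file itself loads NNlib, Zygote, ChainRulesCore, Test

# Or run the whole suite, which now includes bias_act.jl under "Activation Functions":
# import Pkg; Pkg.test("NNlib")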
