
Commit 0de0ec5

committed
sometimes-in-place bias_act
1 parent 5f63dbf commit 0de0ec5

File tree

4 files changed: +330 -0 lines changed

src/NNlib.jl

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
 include("conv_bias_act.jl")
 export conv_bias_act, conv_bias_act!

+include("bias_act.jl")
+export bias_act!
+
 include("fold.jl")

 include("ctc.jl")

src/bias_act.jl

Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
using NNlib: fast_act, tanh_fast
using ChainRulesCore

const RCR = RuleConfig{>:HasReverseMode}

# This just saves typing `only.(only.(` many times:
@inline only_derivative(y, f::F, x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x)))

# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
struct NotaNumber <: Real end
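#=
Rough sketch of the `NotaNumber` trick, assuming NNlib's rules are loaded:
if an activation's `derivatives_given_output` rule reads only the output `Ω`,
then passing a type with no arithmetic methods as `x` still infers to a
concrete type, while a function with no such rule infers to `Union{}`:

julia> isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{Float32, typeof(relu), NotaNumber}))
true

julia> f = x -> 1/(x^2 + 2);  # an activation NNlib knows nothing about

julia> isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{Float64, typeof(f), NotaNumber}))
false
=#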
"""
    bias_act!(σ, x, b)

This is equivalent to `σ.(x .+ b)`, but faster because
it will overwrite `x` to save memory (when possible) and
replace `sigmoid` & `tanh` with `sigmoid_fast` & `tanh_fast`.

The best case requires `x isa StridedArray{<:AbstractFloat}`,
and that the activation has a method of `derivatives_given_output`
which does not need the input at all (such as `relu`, `tanh`).

!!! warning
    This is not safe to use if `x` is still needed for the gradient
    of some other function. Incorrect use will give silently wrong answers.
"""
bias_act!(σ::Function, x::AbstractArray, b) = fast_act(σ, x).(x .+ b)  # fallback

bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) =
    fast_broadcast_plus!(fast_act(σ, x), x, b)  # hand-written version below.

bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) =
    (@assert !b "bias=true is not accepted; layer constructors should guarantee this"; x)

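#=
A quick usage sketch (assumed REPL session, names chosen for illustration):

julia> using NNlib

julia> w, x, b = rand(Float32, 5, 7), rand(Float32, 7, 3), rand(Float32, 5);

julia> y = bias_act!(relu, w * x, b);   # safe: `w * x` is a fresh array, free to overwrite

julia> y ≈ relu.(w * x .+ b)
true

julia> bias_act!(relu, w * x, false) ≈ relu.(w * x)   # `b = false` means "no bias"
true
=#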
function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
    if eltype(B) !== Bool
        b_dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
        size_b = size(b)
    end

    # Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ
    if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber}))
        Ω = bias_act!(σ, x, b)  # now x === Ω, when x isa StridedArray{<:AbstractFloat}
        function bias_act!_fastback(Δ)
            # Tempting to overwrite x again, but only safe if you call pullback at most once,
            # TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340
            # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592
            dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ)
            db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
            return (NoTangent(), NoTangent(), dx, db)
        end
        return Ω, bias_act!_fastback

    # # Slower path: can't overwrite x, but can use derivatives_given_output
    # # This case is WRONG and tests fail, but not sure why
    # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
    #     Ω2 = fast_act(σ, x).(x) .+ b
    #     @show σ b
    #     function bias_act!_back2(Δ)
    #         dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
    #         db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
    #         return (NoTangent(), NoTangent(), dx, db)
    #     end
    #     return Ω2, bias_act!_back2

    # Fallback path: let AD handle the broadcast
    else
        Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b))
        @inline function bias_act!_slowback(Δ)
            _, _, dx = back(Δ)
            db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
            return (NoTangent(), NoTangent(), dx, db)
        end
        return Ω3, bias_act!_slowback
    end
end
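#=
Rough sanity check (sketch) that the fast path above and a plain broadcast
agree on gradients under Zygote:

julia> using NNlib, Zygote

julia> x, b = randn(Float32, 3, 4), randn(Float32, 3);

julia> g1 = Zygote.gradient((x, b) -> sum(bias_act!(relu, copy(x), b)), x, b);

julia> g2 = Zygote.gradient((x, b) -> sum(relu.(x .+ b)), x, b);

julia> all(map(≈, g1, g2))
true
=#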
# Two easy cases
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B}
    b_dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
    size_b = size(b)
    function bias_act!_idback(Δ)
        dx = unthunk(Δ)
        db = reshape(sum(dx; dims = b_dims), size_b)
        return (NoTangent(), NoTangent(), dx, db)
    end
    return bias_act!(identity, x, b), bias_act!_idback
end
function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}
    bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())
    return x, bias_act!_trivial
end


"""
    NNlib.fast_broadcast_plus!(f, x, b)

This is equivalent to `x .= f.(x .+ b)`, but works around
an issue with broadcasting that prevents SIMD in such cases.

That can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.

Also has an `rrule` to prevent mutation inside 2nd-order AD.

!!! warning
    Does not allow for derivatives with respect to `f`.
"""
function fast_broadcast_plus!(f::F, x::Array{<:AbstractFloat}, b) where {F<:Function}
    if b === false
        @simd ivdep for I in eachindex(x)
            @inbounds x[I] = f(x[I])
        end
    else
        xplus = Broadcast.instantiate(Broadcast.broadcasted(+, x, b))
        @simd ivdep for I in eachindex(xplus)
            @inbounds x[I] = f(xplus[I])
        end
    end
    return x
end
function fast_broadcast_plus!(f::F, x::StridedArray{<:AbstractFloat}, b) where {F<:Function}
    # CuArray has its own broadcasting.
    x .= f.(x .+ b)
    return x
end
function fast_broadcast_plus!(f::F, x::AbstractArray, b) where {F<:Function}
    # Don't try to write into weird arrays
    return f.(x .+ b)
end

function rrule(cfg::RCR, ::typeof(fast_broadcast_plus!), f::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
    rrule_via_ad(cfg, broadcast, (x,b) -> f.(x .+ b), x, b)
end
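#=
Sketch of what the hand-written loops compute (illustrative only):

julia> x, b = rand(Float32, 4, 3), rand(Float32, 4);

julia> y = copy(x);

julia> NNlib.fast_broadcast_plus!(relu, y, b);

julia> y ≈ relu.(x .+ b)   # same values as the plain broadcast, but written in place
true
=#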

# """
#     add_act(σ, x, y...)
#     add_act!(σ, x, y, z...)

# Equivalent to `σ.(x .+ y .+ z)`. The mutating method `add_act!`
# """
# add_act(σ::Function, x::AbstractArray, yz::AbstractArray...) = σ.(.+(x, yz...))  # fused


# function ChainRulesCore.rrule(::typeof(add_act), σ::F, x::AbstractArray, yz::AbstractArray...) where {F,T,N}
#     if isconcretetype(Core.Compiler._return_type(
#             derivatives_given_output, Tuple{T, F, NotaNumber}))

#     end


# bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) =
#     # b ? (x .= fast_act(σ, x).(x .+ b)) : (x .= fast_act(σ, x).(x))
#     (@assert !b "bias=true is not accepted"; (x .= fast_act(σ, x).(x)))


# using NNlib, BenchmarkTools

#=

## M1 mac, 1.10

julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);

julia> @btime bias_act!(relu, $w, $b);
  min 19.500 μs, mean 21.375 μs (0 allocations)

julia> @btime relu.($w .+ $b);
  min 17.208 μs, mean 62.826 μs (2 allocations, 390.67 KiB)

julia> @btime bias_act!(tanh, $w, $b);
  min 63.792 μs, mean 65.052 μs (0 allocations)

julia> @btime tanh_fast.($w .+ $b);
  min 63.583 μs, mean 102.004 μs (2 allocations, 390.67 KiB)

julia> using Zygote

julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
  min 145.166 μs, mean 150.785 μs (51 allocations, 2.18 KiB)

julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
  min 165.583 μs, mean 314.267 μs (32 allocations, 1.15 MiB)

julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
  min 191.917 μs, mean 195.956 μs (51 allocations, 2.18 KiB)

julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
  min 209.458 μs, mean 338.652 μs (32 allocations, 1.15 MiB)


## Cyclops

julia> using CUDA  # 10x bigger

julia> cw, cb = CUDA.rand(Float32, 100, 100_00), CUDA.rand(Float32, 100);

julia> @btime CUDA.@sync bias_act!(relu, $cw, $cb);
  22.546 μs (27 allocations: 1.45 KiB)

julia> @btime CUDA.@sync relu.($cw .+ $cb);  # faster, that's odd?
  31.282 μs (38 allocations: 1.81 KiB)

julia> @btime CUDA.@sync bias_act!(tanh, $cw, $cb);
  27.030 μs (27 allocations: 1.45 KiB)

julia> @btime CUDA.@sync tanh_fast.($cw .+ $cb);
  36.421 μs (38 allocations: 1.81 KiB)

julia> using Zygote

julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(relu, w, b)), $cw, $cb);
  204.507 μs (382 allocations: 18.15 KiB)

julia> @btime CUDA.@sync gradient((w,b) -> sum(relu.(w .+ b)), $cw, $cb);
  204.458 μs (409 allocations: 19.19 KiB)

julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(tanh, w, b)), $cw, $cb);
  224.545 μs (382 allocations: 18.15 KiB)

julia> @btime CUDA.@sync gradient((w,b) -> sum(tanh_fast.(w .+ b)), $cw, $cb);
  204.793 μs (411 allocations: 19.30 KiB)

=#

#=

(jl_fuwIi8) pkg> add https://github.com/mcabbott/NNlib.jl/tree/bias_act_23

julia> using NNlib, Zygote, BenchmarkTools

julia> w, b, x = rand(Float32, 50, 50), rand(Float32, 50), randn(Float32, 50, 100);

julia> @btime bias_act!(relu, $w * $x, $b);
  min 5.243 μs, mean 8.600 μs (2 allocations, 19.61 KiB)

julia> @btime relu.($w * $x .+ $b);
  min 5.160 μs, mean 10.863 μs (4 allocations, 39.22 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
  min 21.042 μs, mean 40.476 μs (43 allocations, 89.83 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
  min 21.542 μs, mean 43.947 μs (41 allocations, 128.91 KiB)

julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
  min 14.708 μs, mean 26.450 μs (28 allocations, 69.41 KiB)

julia> @btime gradient(x -> sum(abs2, x), $x);
  min 1.938 μs, mean 4.160 μs (2 allocations, 19.61 KiB)


# Cyclops

julia> @btime bias_act!(relu, $w * $x, $b);
  24.786 μs (2 allocations: 19.61 KiB)

julia> @btime relu.($w * $x .+ $b);
  25.501 μs (4 allocations: 39.22 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
  91.847 μs (43 allocations: 89.83 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
  98.054 μs (41 allocations: 128.91 KiB)

julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
  80.464 μs (28 allocations: 69.41 KiB)

julia> @btime gradient(x -> sum(abs2, x), $x);
  4.604 μs (2 allocations: 19.61 KiB)

julia> @time using CUDA; @time cu(ones(3)) .+ 1;

julia> w, b, x = CUDA.rand(Float32, 1000, 1000), CUDA.rand(Float32, 1000), CUDA.rand(Float32, 1000, 1000);

=#

test/bias_act.jl

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
using NNlib, Zygote, Test
using Zygote: ForwardDiff

ACTIVATION_FUNCTIONS =
    [@eval($a) for a in NNlib.ACTIVATIONS]

@testset "bias_act!" begin
    x = randn(3,4)
    b = randn(3)
    @test bias_act!(identity, copy(x), b) ≈ (x .+ b)
    @test bias_act!(relu, copy(x), b) ≈ relu.(x .+ b)
    @test bias_act!(tanh, copy(x), b) ≈ tanh.(x .+ b)

    @testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt],
                                                  ACTIVATION_FUNCTIONS,
                                                  [x->x, x -> 1/(x^2+2), x -> leakyrelu(x, 0.33)])
        # Only some of these go the fast path, `cbrt` is an example of a function NNlib knows nothing about.
        fun == rrelu && continue  # this one is randomised!

        @test bias_act!(fun, copy(x), b) ≈ fun.(x .+ b)
        @test bias_act!(fun, copy(x), false) ≈ fun.(x)

        gx = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)
        @test gx ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), b)), x)[1]

        gx2 = ForwardDiff.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)
        @test gx2 ≈ Zygote.gradient(x -> sum(bias_act!(fun, copy(x), false)), x)[1]

        gb = ForwardDiff.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)
        @test gb ≈ Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b)[1]

        @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), false) == (nothing,)
        @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b .> 0) == (nothing,)
    end

    @testset "gradient for fast_broadcast_plus!" begin
        # Gradient definition is just to disable mutation inside 2nd order AD
        gx = ForwardDiff.gradient(x -> sum(NNlib.fast_broadcast_plus!(cbrt, copy(x), b)), x)
        @test gx ≈ Zygote.gradient(x -> sum(NNlib.fast_broadcast_plus!(cbrt, copy(x), b)), x)[1]
    end
end
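#=
Sketch of the docstring's warning in practice (illustrative, not part of the tests):
if `x` is still needed by the pullback of another term, overwriting it corrupts
that gradient. For example,

julia> Zygote.gradient(x -> sum(abs2, x) + sum(bias_act!(relu, x, b)), x)

may be silently wrong, because the pullback of `sum(abs2, x)` reads `x` after
`bias_act!` has overwritten it; writing `bias_act!(relu, copy(x), b)`, as the
tests above do, avoids this.
=#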

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ include("test_utils.jl")
 
 @testset "Activation Functions" begin
     include("activations.jl")
+    include("bias_act.jl")
 end
 
 @testset "Batched Multiplication" begin

0 commit comments