
Commit 795a2d7

update after dropout PR
1 parent 6ae7ffa commit 795a2d7

File tree: 4 files changed, +75 -223 lines changed

src/bias_act.jl

Lines changed: 12 additions & 196 deletions
@@ -8,28 +8,31 @@ const RCR = RuleConfig{>:HasReverseMode}
 @inline only_derivative(y,f::F,x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x)))
 
 # This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)`
-# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
+# is independent of `x`, as `_return_type` says `Union{}` when calling is an error.
 struct NotaNumber <: Real end
 
 """
     bias_act!(σ, x, b)
 
-This is equivalent to `σ.(x .+ b)`, but faster because
-it will overwrite `x` to save memory (when possible) and
-replace `sigmoid` & `tanh` with `sigmoid_fast` & `tanh_fast`.
+This is equivalent to `x .= σ.(x .+ b)`, also replacing `sigmoid` & `tanh`
+with `sigmoid_fast` & `tanh_fast`.
+It will only overwrite `x` when `x isa StridedArray{<:AbstractFloat}`.
 
-The best case requires `x isa StridedArray{<:AbstractFloat}`,
-and that the activation has a method of `derivatives_given_output`
-which does not need the input at all (such as `relu`, `tanh`).
+When used within a gradient, it will overwrite only when `σ` has
+a method of `derivatives_given_output` which does not need the input at all.
+Such methods are defined by e.g. `@scalar_rule relu(x) Ω > 0` where the derivative
+contains only `Ω` (the output) not `x`.
 
 !!! warning
     This is not safe to use if `x` is still needed for the gradient
     of some other function. Incorrect use will give silently wrong answers.
+    It is intended mainly for Flux layers, in which the previous operation is
+    known to be safe, e.g. `bias_act!(σ, weight * input, bias)` for a `Dense` layer.
 """
 bias_act!(σ::Function, x::AbstractArray, b) = fast_act(σ, x).(x .+ b)  # fallback
 
 bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) =
-    fast_broadcast_plus!(fast_act(σ, x), x, b)  # hand-written version below.
+    _fast_broadcast!(fast_act(σ, x)∘(+), x, b)  # works around a SIMD bug
 
 bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) =
     (@assert !b "bias=true is not accepted; layer constructors should guarantee this"; x)
@@ -89,197 +92,10 @@ function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArr
     end
     return bias_act!(identity, x, b), bias_act!_idback
 end
+
 function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}
     bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())
     return x, bias_act!_trivial
 end
 
 
-"""
-    NNlib.fast_broadcast_plus!(f, x, b)
-
-This is equivalent to `x .= f.(x .+ b)`, but works around
-an issue with broadcasting that prevents SIMD in such cases.
-
-That can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.
-
-Also has an `rrule` to prevent mutation inside 2nd-order AD.
-
-!!! warning
-    Does not allow for derivatives with respect to `f`.
-"""
-function fast_broadcast_plus!(f::F, x::Array{<:AbstractFloat}, b) where {F<:Function}
-    if b === false
-        @simd ivdep for I in eachindex(x)
-            @inbounds x[I] = f(x[I])
-        end
-    else
-        xplus = Broadcast.instantiate(Broadcast.broadcasted(+, x, b))
-        @simd ivdep for I in eachindex(xplus)
-            @inbounds x[I] = f(xplus[I])
-        end
-    end
-    return x
-end
-function fast_broadcast_plus!(f::F, x::StridedArray{<:AbstractFloat}, b) where {F<:Function}
-    # CuArray has its own broadcasting.
-    x .= f.(x .+ b)
-    return x
-end
-function fast_broadcast_plus!(f::F, x::AbstractArray, b) where {F<:Function}
-    # Don't try to write into weird arrays
-    return f.(x .+ b)
-end
-
-function rrule(cfg::RCR, ::typeof(fast_broadcast_plus!), f::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
-    rrule_via_ad(cfg, broadcast, (x,b) -> f.(x .+ b), x, b)
-end
-
-
-# """
-#     add_act(σ, x, y...)
-#     add_act!(σ, x, y, z...)
-
-# Equivalent to `σ.(x .+ y .+ z)`. The mutating method `add_act!`
-# """
-# add_act(σ::Function, x::AbstractArray, yz::AbstractArray...) = σ.(.+(x, yz...))  # fused
-
-# function ChainRulesCore.rrule(::typeof(add_act), σ::F, x::AbstractArray, yz::AbstractArray...) where {F,T,N}
-#     if isconcretetype(Core.Compiler._return_type(
-#         derivatives_given_output, Tuple{T, F, NotaNumber}))
-
-#     end
-
-# bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) =
-#     # b ? (x .= fast_act(σ, x).(x .+ b)) : (x .= fast_act(σ, x).(x))
-#     (@assert !b "bias=true is not accepted"; (x .= fast_act(σ, x).(x)))
-
-# using NNlib, BenchmarkTools
-
-#=
-
-## M1 mac, 1.10
-
-julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);
-
-julia> @btime bias_act!(relu, $w, $b);
-  min 19.500 μs, mean 21.375 μs (0 allocations)
-
-julia> @btime relu.($w .+ $b);
-  min 17.208 μs, mean 62.826 μs (2 allocations, 390.67 KiB)
-
-julia> @btime bias_act!(tanh, $w, $b);
-  min 63.792 μs, mean 65.052 μs (0 allocations)
-
-julia> @btime tanh_fast.($w .+ $b);
-  min 63.583 μs, mean 102.004 μs (2 allocations, 390.67 KiB)
-
-julia> using Zygote
-
-julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
-  min 145.166 μs, mean 150.785 μs (51 allocations, 2.18 KiB)
-
-julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
-  min 165.583 μs, mean 314.267 μs (32 allocations, 1.15 MiB)
-
-julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
-  min 191.917 μs, mean 195.956 μs (51 allocations, 2.18 KiB)
-
-julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
-  min 209.458 μs, mean 338.652 μs (32 allocations, 1.15 MiB)
-
-
-## Cyclops
-
-julia> using CUDA  # 10x bigger
-
-julia> cw, cb = CUDA.rand(Float32, 100, 100_00), CUDA.rand(Float32, 100);
-
-julia> @btime CUDA.@sync bias_act!(relu, $cw, $cb);
-  22.546 μs (27 allocations: 1.45 KiB)
-
-julia> @btime CUDA.@sync relu.($cw .+ $cb);  # faster, that's odd?
-  31.282 μs (38 allocations: 1.81 KiB)
-
-julia> @btime CUDA.@sync bias_act!(tanh, $cw, $cb);
-  27.030 μs (27 allocations: 1.45 KiB)
-
-julia> @btime CUDA.@sync tanh_fast.($cw .+ $cb);
-  36.421 μs (38 allocations: 1.81 KiB)
-
-julia> using Zygote
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(relu, w, b)), $cw, $cb);
-  204.507 μs (382 allocations: 18.15 KiB)
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(relu.(w .+ b)), $cw, $cb);
-  204.458 μs (409 allocations: 19.19 KiB)
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(tanh, w, b)), $cw, $cb);
-  224.545 μs (382 allocations: 18.15 KiB)
-
-julia> @btime CUDA.@sync gradient((w,b) -> sum(tanh_fast.(w .+ b)), $cw, $cb);
-  204.793 μs (411 allocations: 19.30 KiB)
-
-=#
-
-#=
-
-(jl_fuwIi8) pkg> add https://github.com/mcabbott/NNlib.jl/tree/bias_act_23
-
-julia> using NNlib, Zygote, BenchmarkTools
-
-julia> w, b, x = rand(Float32, 50, 50), rand(Float32, 50), randn(Float32, 50, 100);
-
-julia> @btime bias_act!(relu, $w * $x, $b);
-  min 5.243 μs, mean 8.600 μs (2 allocations, 19.61 KiB)
-
-julia> @btime relu.($w * $x .+ $b);
-  min 5.160 μs, mean 10.863 μs (4 allocations, 39.22 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
-  min 21.042 μs, mean 40.476 μs (43 allocations, 89.83 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
-  min 21.542 μs, mean 43.947 μs (41 allocations, 128.91 KiB)
-
-julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
-  min 14.708 μs, mean 26.450 μs (28 allocations, 69.41 KiB)
-
-julia> @btime gradient(x -> sum(abs2, x), $x);
-  min 1.938 μs, mean 4.160 μs (2 allocations, 19.61 KiB)
-
-
-# Cyclops
-
-julia> @btime bias_act!(relu, $w * $x, $b);
-  24.786 μs (2 allocations: 19.61 KiB)
-
-julia> @btime relu.($w * $x .+ $b);
-  25.501 μs (4 allocations: 39.22 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
-  91.847 μs (43 allocations: 89.83 KiB)
-
-julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
-  98.054 μs (41 allocations: 128.91 KiB)
-
-julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
-  80.464 μs (28 allocations: 69.41 KiB)
-
-julia> @btime gradient(x -> sum(abs2, x), $x);
-  4.604 μs (2 allocations: 19.61 KiB)
-
-julia> @time using CUDA; @time cu(ones(3)) .+ 1;
-
-julia> w, b, x = CUDA.rand(Float32, 1000, 1000), CUDA.rand(Float32, 1000), CUDA.rand(Float32, 1000, 1000);
-
-=#
-
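
As a rough usage sketch (not part of the diff; the array sizes and the `output_only_derivative` helper name are illustrative assumptions, and it leans on NNlib internals), this is how the fast-path rule in the new docstring can be checked by hand, mirroring the `derivatives_given_output` probe in the commented-out `add_act` rrule above:

using NNlib, ChainRulesCore

w, x, b = rand(Float32, 64, 64), rand(Float32, 64, 128), rand(Float32, 64)

# Drop-in replacement for the fused broadcast in a Dense-like layer;
# the intermediate `w * x` buffer is overwritten instead of kept alive:
y = bias_act!(relu, w * x, b)
y ≈ relu.(w * x .+ b)   # true: same values as the naive version

# Hypothetical helper (not defined in NNlib): the in-place gradient path applies
# only when the activation's derivative can be computed from the output alone,
# probed here with the input-less `NNlib.NotaNumber` stand-in.
output_only_derivative(σ, T=Float32) = isconcretetype(Core.Compiler._return_type(
    ChainRulesCore.derivatives_given_output, Tuple{T, typeof(σ), NNlib.NotaNumber}))

output_only_derivative(relu)        # true: `@scalar_rule relu(x) Ω > 0` uses only Ω
output_only_derivative(y -> y / 3)  # false: no such rule, so the non-mutating path is used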

src/dropout.jl

Lines changed: 0 additions & 21 deletions
@@ -123,27 +123,6 @@ end
 # and seems about as fast, but needs a method `copy(::CUDA.RNG)` and careful checking.
 # https://github.com/FluxML/NNlib.jl/pull/454#issuecomment-1369357402
 
-"""
-    _fast_broadcast!(f, x, y, z...)
-
-This does `x .= f.(x, y, z...)`, but works around
-an issue with broadcasting that prevents SIMD in such cases.
-Can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.
-
-Not intended for general use. Does not check sizes!
-"""
-function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
-    bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
-    @simd ivdep for I in eachindex(bc)
-        @inbounds x[I] = bc[I]
-    end
-    return x
-end
-function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}
-    # CUDA does not suffer from this bug
-    broadcast!(f, x, x, yz...)
-end
-
 
 """
     _rng_from_array(x)

src/utils.jl

Lines changed: 29 additions & 1 deletion
@@ -53,4 +53,32 @@ function reverse_indices(idx::AbstractArray{<:Any,N}) where N
     return reverse_indices!(rev, idx)
 end
 
-unsqueeze(x) = reshape(x, 1, size(x)...)
+unsqueeze(x) = reshape(x, 1, size(x)...)
+
+
+"""
+    _fast_broadcast!(f, x, y, z...)
+
+This does `x .= f.(x, y, z...)`, but works around
+an issue with broadcasting that prevents SIMD in such cases.
+Can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.
+
+Not intended for general use. Uses `@inbounds` but does not check sizes!
+
+Has an `rrule` to avoid mutation within derivatives. This assumes that `f` has no derivative!
+"""
+function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
+    bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
+    @simd ivdep for I in eachindex(bc)
+        @inbounds x[I] = bc[I]
+    end
+    return x
+end
+function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}
+    # CUDA does not suffer from this bug
+    broadcast!(f, x, x, yz...)
+end
+
+function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!), f::F, x::AbstractArray, ys...) where {F<:Function}
+    rrule_via_ad(cfg, broadcast, f, x, ys...)
+end
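
A minimal sketch of how the relocated helper is meant to be called (the array sizes are arbitrary assumptions; `_fast_broadcast!` is internal, hence the qualified name, as in the tests below):

using NNlib

x = rand(Float32, 100, 32)
b = rand(Float32, 100)
xref = relu.(x .+ b)                        # reference, computed without mutation

y = NNlib._fast_broadcast!(relu∘(+), x, b)  # overwrites x via the @simd ivdep loop
y === x                                     # true: the same array is returned
y ≈ xref                                    # true: same values as `x .= relu.(x .+ b)`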

test/bias_act.jl

Lines changed: 34 additions & 5 deletions
@@ -1,7 +1,7 @@
 using NNlib, Zygote, Test
 using Zygote: ForwardDiff
 
-ACTIVATION_FUNCTIONS =
+ACTIVATION_FUNCTIONS =
    [@eval($a) for a in NNlib.ACTIVATIONS]
 
 @testset "bias_act!" begin
@@ -11,7 +11,7 @@ ACTIVATION_FUNCTIONS =
     @test bias_act!(relu, copy(x), b) ≈ relu.(x .+ b)
     @test bias_act!(tanh, copy(x), b) ≈ tanh.(x .+ b)
 
-    @testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt],
+    @testset "gradient with $fun" for fun in vcat([identity, tanh, cbrt],
                                                   ACTIVATION_FUNCTIONS,
                                                   [x->x, x -> 1/(x^2+2), x -> leakyrelu(x, 0.33)])
         # Only some of these go the fast path, `cbrt` is an example of a function NNlib knows nothing about.
@@ -33,9 +33,38 @@ ACTIVATION_FUNCTIONS =
         @test Zygote.gradient(b -> sum(bias_act!(fun, copy(x), b)), b .> 0) == (nothing,)
     end
 
-    @testset "gradient for fast_broadcast_plus!" begin
+    @testset "gradient for fast_broadcast!" begin
         # Gradient definition is just to disable mutation inside 2nd order AD
-        gx = ForwardDiff.gradient(x -> sum(NNlib.fast_broadcast_plus!(cbrt, copy(x), b)), x)
-        @test gx ≈ Zygote.gradient(x -> sum(NNlib.fast_broadcast_plus!(cbrt, copy(x), b)), x)[1]
+        gx = ForwardDiff.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)
+        @test gx ≈ Zygote.gradient(x -> sum(NNlib._fast_broadcast!(cbrt∘(+), copy(x), b)), x)[1]
+
+        # relu should take the fast path
+        g2 = ForwardDiff.gradient(x) do x
+            sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
+        end
+        @test_broken gx ≈ Zygote.gradient(x) do x
+            sum(abs2, Zygote.gradient(x -> sum(abs2, bias_act!(relu, copy(x), b)), x)[1])
+        end
+        # Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_get), Any, svec(Any, Any, Any), 0, :(:ccall), %5, %3, %4)).
+        # [5] (::typeof(∂(accum_global)))(Δ::Nothing)
+        @test g2 ≈ Zygote.gradient(x, b) do x, b
+            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(relu, copy(x), b)), x, b)[1])
+        end[1]
+
+        g3 = ForwardDiff.gradient(x) do x
+            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])
+        end
+        @test g3 ≈ Zygote.gradient(x, b) do x, b
+            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(swish, copy(x), b)), x, b)[1])
+        end[1]
+
+        # Anon function sure to take the generic path
+        g4 = ForwardDiff.gradient(x) do x
+            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])
+        end
+        @test g4 ≈ Zygote.gradient(x, b) do x, b
+            sum(abs2, Zygote.gradient((x, b) -> sum(abs2, bias_act!(y -> cbrt(y/3), copy(x), b)), x, b)[1])
+        end[1]
     end
 end
+
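
For a quick check outside the test suite, a sketch along the lines of the benchmarks above (package versions and array sizes are assumptions): the in-place path should agree with the naive broadcast both for values and for first-order Zygote gradients.

using NNlib, Zygote

w, b = rand(Float32, 100, 1000), rand(Float32, 100)

bias_act!(relu, copy(w), b) ≈ relu.(w .+ b)   # true: same values

g1 = Zygote.gradient(w -> sum(abs2, bias_act!(relu, copy(w), b)), w)[1]
g2 = Zygote.gradient(w -> sum(abs2, relu.(w .+ b)), w)[1]
g1 ≈ g2                                       # true: same gradient, far less memory allocated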
