
Commit c1d834f

tidy & un-comment
1 parent eee9a37

3 files changed: +175 -27 lines

docs/src/reference.md

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ pad_zeros
 
 ## Convolution
 
-`Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally.
+`Flux`'s `Conv` and `CrossCor` layers use `NNlib.DenseConvDims` and `NNlib.conv` internally.
 
 ```@docs
 conv
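
Not part of this change, but for context on the sentence edited above: a minimal sketch of calling `NNlib.conv` directly, with a default `DenseConvDims`, which is what Flux's `Conv` layer builds internally. The array sizes here are made up for illustration.

using NNlib

# Layout is (spatial..., channels, batch) for x and
# (spatial..., in_channels, out_channels) for the kernel w.
x = rand(Float32, 28, 28, 3, 16)
w = rand(Float32, 5, 5, 3, 8)

y = conv(x, w)               # 24×24×8×16, default stride/pad/dilation
cdims = DenseConvDims(x, w)  # the same geometry, spelled out explicitly
y ≈ conv(x, w, cdims)        # true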

src/bias_act.jl

Lines changed: 172 additions & 24 deletions
@@ -39,9 +39,12 @@ bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) =
 
 
 function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
-    if eltype(B) !== Bool
-        b_dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
-        size_b = size(b)
+    biasgrad = if eltype(B) !== Bool
+        # Summing over ndims(x)+1 is a trick to make b_dims type-stable
+        dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
+        _biasgrad(dx) = reshape(sum(dx; dims), size(b))
+    else
+        Returns(NoTangent())
     end
 
     # Fast path: it is now safe to overwrite x, since this is not needed for gradient of σ
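
An aside, not part of the file: the `ntuple` line in the hunk above relies on `sum` ignoring reduction dims beyond `ndims(dx)`. Mapping the non-singleton dims of `b` to `N+1` keeps the tuple length fixed at `N`, which is what makes it type-stable. A standalone sketch with made-up sizes:

dx = ones(Float32, 3, 4, 5)    # stand-in for the gradient of x, so N = 3
b  = zeros(Float32, 1, 4, 1)   # bias with singleton dims 1 and 3
N  = ndims(dx)

# Singleton dims of b must be summed over; the rest map to N+1, which `sum`
# simply ignores, so the tuple always has length N.
dims = ntuple(d -> size(b, d) == 1 ? d : N + 1, N)    # (1, 4, 3)

db = reshape(sum(dx; dims), size(b))   # size (1, 4, 1), each entry 15.0f0
@assert size(db) == size(b)
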
@@ -52,50 +55,195 @@ function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractA
             # TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340
             # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592
             dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ)
-            db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
-            return (NoTangent(), NoTangent(), dx, db)
+            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
         end
         return Ω, bias_act!_fastback
 
-    # # Slower path: can't overwrite x, but can use derivatives_given_output
-    # # This case is WRONG and tests fail, but not sure why
-    # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
-    #     Ω2 = fast_act(σ, x).(x) .+ b
-    #     @show σ b
-    #     function bias_act!_back2(Δ)
-    #         dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
-    #         db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
-    #         return (NoTangent(), NoTangent(), dx, db)
-    #     end
-    #     return Ω2, bias_act!_back2
+    # Slower path: can't overwrite x, but can use derivatives_given_output
+    # This case is WRONG and tests fail, but not sure why
+    elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
+        Ω2 = fast_act(σ, x).(x) .+ b
+        @show σ b
+        function bias_act!_back2(Δ)
+            dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
+            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
+        end
+        return Ω2, bias_act!_back2
 
     # Fallback path: let AD handle the broadcast
     else
        Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b))
         @inline function bias_act!_slowback(Δ)
             _, _, dx = back(Δ)
-            db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
-            return (NoTangent(), NoTangent(), dx, db)
+            return (NoTangent(), NoTangent(), dx, biasgrad(dx))
         end
         return Ω3, bias_act!_slowback
     end
 end
 
-# Two easy cases
+# Two easy cases with identity
 function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B}
-    b_dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
-    size_b = size(b)
+    dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
+    biasgrad(dx) = reshape(sum(dx; dims), size(b))
     function bias_act!_idback(Δ)
         dx = unthunk(Δ)
-        db = reshape(sum(dx; dims = b_dims), size_b)
-        return (NoTangent(), NoTangent(), dx, db)
+        return (NoTangent(), NoTangent(), dx, biasgrad(dx))
     end
     return bias_act!(identity, x, b), bias_act!_idback
 end
-
 function rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}
     bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())
     return x, bias_act!_trivial
 end
 
 
+
+# """
+#     add_act(σ, x, y...)
+#     add_act!(σ, x, y, z...)
+
+# Equivalent to `σ.(x .+ y .+ z)`. The mutating method `add_act!`
+# """
+# add_act(σ::Function, x::AbstractArray, yz::AbstractArray...) = σ.(.+(x, yz...)) # fused
+
+
+# function ChainRulesCore.rrule(::typeof(add_act), σ::F, x::AbstractArray, yz::AbstractArray...) where {F,T,N}
+#     if isconcretetype(Core.Compiler._return_type(
+#         derivatives_given_output, Tuple{T, F, NotaNumber}))
+
+#     end
+
+
+# bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) =
+#     # b ? (x .= fast_act(σ, x).(x .+ b)) : (x .= fast_act(σ, x).(x))
+#     (@assert !b "bias=true is not accepted"; (x .= fast_act(σ, x).(x)))
+
+
+# using NNlib, BenchmarkTools
+
+#=
+
+## M1 mac, 1.10
+
+julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);
+
+julia> @btime bias_act!(relu, $w, $b);
+  min 19.500 μs, mean 21.375 μs (0 allocations)
+
+julia> @btime relu.($w .+ $b);
+  min 17.208 μs, mean 62.826 μs (2 allocations, 390.67 KiB)
+
+julia> @btime bias_act!(tanh, $w, $b);
+  min 63.792 μs, mean 65.052 μs (0 allocations)
+
+julia> @btime tanh_fast.($w .+ $b);
+  min 63.583 μs, mean 102.004 μs (2 allocations, 390.67 KiB)
+
+julia> using Zygote
+
+julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
+  min 145.166 μs, mean 150.785 μs (51 allocations, 2.18 KiB)
+
+julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
+  min 165.583 μs, mean 314.267 μs (32 allocations, 1.15 MiB)
+
+julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
+  min 191.917 μs, mean 195.956 μs (51 allocations, 2.18 KiB)
+
+julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
+  min 209.458 μs, mean 338.652 μs (32 allocations, 1.15 MiB)
+
+
+
+## Cyclops
+
+julia> using CUDA  # 10x bigger
+
+julia> cw, cb = CUDA.rand(Float32, 100, 100_00), CUDA.rand(Float32, 100);
+
+julia> @btime CUDA.@sync bias_act!(relu, $cw, $cb);
+  22.546 μs (27 allocations: 1.45 KiB)
+
+julia> @btime CUDA.@sync relu.($cw .+ $cb);  # faster, that's odd?
+  31.282 μs (38 allocations: 1.81 KiB)
+
+julia> @btime CUDA.@sync bias_act!(tanh, $cw, $cb);
+  27.030 μs (27 allocations: 1.45 KiB)
+
+julia> @btime CUDA.@sync tanh_fast.($cw .+ $cb);
+  36.421 μs (38 allocations: 1.81 KiB)
+
+julia> using Zygote
+
+julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(relu, w, b)), $cw, $cb);
+  204.507 μs (382 allocations: 18.15 KiB)
+
+julia> @btime CUDA.@sync gradient((w,b) -> sum(relu.(w .+ b)), $cw, $cb);
+  204.458 μs (409 allocations: 19.19 KiB)
+
+julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(tanh, w, b)), $cw, $cb);
+  224.545 μs (382 allocations: 18.15 KiB)
+
+julia> @btime CUDA.@sync gradient((w,b) -> sum(tanh_fast.(w .+ b)), $cw, $cb);
+  204.793 μs (411 allocations: 19.30 KiB)
+
+
+=#
+
+#=
+
+(jl_fuwIi8) pkg> add https://github.com/mcabbott/NNlib.jl/tree/bias_act_23
+
+julia> using NNlib, Zygote, BenchmarkTools
+
+julia> w, b, x = rand(Float32, 50, 50), rand(Float32, 50), randn(Float32, 50, 100);
+
+julia> @btime bias_act!(relu, $w * $x, $b);
+  min 5.243 μs, mean 8.600 μs (2 allocations, 19.61 KiB)
+
+julia> @btime relu.($w * $x .+ $b);
+  min 5.160 μs, mean 10.863 μs (4 allocations, 39.22 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
+  min 21.042 μs, mean 40.476 μs (43 allocations, 89.83 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
+  min 21.542 μs, mean 43.947 μs (41 allocations, 128.91 KiB)
+
+julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
+  min 14.708 μs, mean 26.450 μs (28 allocations, 69.41 KiB)
+
+julia> @btime gradient(x -> sum(abs2, x), $x);
+  min 1.938 μs, mean 4.160 μs (2 allocations, 19.61 KiB)
+
+
+# Cyclops
+
+julia> @btime bias_act!(relu, $w * $x, $b);
+  24.786 μs (2 allocations: 19.61 KiB)
+
+julia> @btime relu.($w * $x .+ $b);
+  25.501 μs (4 allocations: 39.22 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
+  91.847 μs (43 allocations: 89.83 KiB)
+
+julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
+  98.054 μs (41 allocations: 128.91 KiB)
+
+julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
+  80.464 μs (28 allocations: 69.41 KiB)
+
+julia> @btime gradient(x -> sum(abs2, x), $x);
+  4.604 μs (2 allocations: 19.61 KiB)
+
+julia> @time using CUDA; @time cu(ones(3)) .+ 1;
+
+julia> w, b, x = CUDA.rand(Float32, 1000, 1000), CUDA.rand(Float32, 1000), CUDA.rand(Float32, 1000, 1000);
+
+
+=#
+
+
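
Not part of the diff itself: a short usage sketch, assuming this branch of NNlib plus Zygote, mirroring the benchmark calls kept in the comment blocks above.

using NNlib, Zygote

w, x, b = rand(Float32, 50, 50), randn(Float32, 50, 100), rand(Float32, 50)

# Same values as relu.(w*x .+ b), but the bias add and activation are fused
# and overwrite the w*x temporary, saving one array-sized allocation.
y = bias_act!(relu, w * x, b)

# The rrule above supplies the gradient; biasgrad sums dx over the dims
# in which b has size 1.
gw, gx, gb = gradient((w, x, b) -> sum(abs2, bias_act!(relu, w * x, b)), w, x, b)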

src/utils.jl

Lines changed: 2 additions & 2 deletions
@@ -62,8 +62,8 @@ safe_div(x, y) = ifelse(iszero(y), x, x/y)
     maximum_dims(dims)
 
 Given an array of `CartesianIndex{N}` or `NTuple{N,Int}`,
-returns a tuple containing the maximum of the 1st entries,
-the 2nd, and so on up to `N`.
+returns a tuple containing the maximum of all the 1st entries,
+all the 2nd entries, and so on up to `N`.
 
 Given an array of integers, returns `(maximum(dims),)`.
 
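
Not part of this commit: a small illustration of the behaviour described in the docstring above, calling the (non-exported) helper directly.

using NNlib

NNlib.maximum_dims([CartesianIndex(1, 5), CartesianIndex(3, 2)])  # (3, 5)
NNlib.maximum_dims([(1, 5), (3, 2)])                              # (3, 5)
NNlib.maximum_dims([4, 7, 2])                                     # (7,)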
