using NNlib: fast_act, tanh_fast
using ChainRulesCore

const RCR = RuleConfig{>:HasReverseMode}

# This just saves typing `only.(only.(` many times:
@inline only_derivative(y, f::F, x) where F = only(only(ChainRulesCore.derivatives_given_output(y, f, x)))

# This has no methods; it is used for testing whether `derivatives_given_output(Ω, f, x)`
# is independent of `x`, since `_return_type` returns `Union{}` when the call would be an error.
struct NotaNumber <: Real end
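
# A sketch of how this check is used below (illustrative only; it assumes, as the
# docstring of `bias_act!` states, that `tanh`'s rule needs only the output Ω):
#
#   isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{Float32, typeof(tanh), NotaNumber}))
#   # should be `true`: the rule never looks at `x`, so passing `NotaNumber()` is inferrable.
#
# For an activation whose rule needs the input (or which has no rule at all), that call
# would be a MethodError, so `_return_type` infers `Union{}`, which is not concrete.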

"""
    bias_act!(σ, x, b)

This is equivalent to `σ.(x .+ b)`, but faster because
it will overwrite `x` to save memory (when possible) and
replace `sigmoid` & `tanh` with `sigmoid_fast` & `tanh_fast`.

The best case requires `x isa StridedArray{<:AbstractFloat}`,
and that the activation has a method of `derivatives_given_output`
which does not need the input at all (such as `relu`, `tanh`).

!!! warning
    This is not safe to use if `x` is still needed for the gradient
    of some other function. Incorrect use will give silently wrong answers.
"""
bias_act!(σ::Function, x::AbstractArray, b) = fast_act(σ, x).(x .+ b)  # fallback

bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::AbstractArray{<:Union{Bool, AbstractFloat}}) =
    fast_broadcast_plus!(fast_act(σ, x), x, b)  # hand-written version below

bias_act!(::typeof(identity), x::StridedArray{<:AbstractFloat}, b::Bool) =
    (@assert !b "bias=true is not accepted; layer constructors should guarantee this"; x)
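
# A minimal usage sketch (hypothetical names, mirroring the benchmarks at the bottom of
# this file): the typical call site passes a freshly allocated array, e.g.
#
#   w, x, b = rand(Float32, 5, 4), rand(Float32, 4, 3), rand(Float32, 5)
#   y = bias_act!(relu, w * x, b)   # like relu.(w*x .+ b), but overwrites the temporary w*x
#
# Overwriting is safe there because nothing else holds the `w * x` buffer; passing an
# array whose values are still needed elsewhere is exactly what the docstring warns against.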


function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), σ::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
    if eltype(B) !== Bool
        b_dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
        size_b = size(b)
    end
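    # For example (sizes are illustrative only): with x of size (100, 10_000) and b a
    # 100-element Vector, N == 2 and size(b) == (100,), so b_dims == (3, 2); the pullbacks
    # below then sum dx over the broadcast dimension 2 (dimension 3 is trivial) and
    # reshape the (100, 1) result back to size_b == (100,).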

    # Fast path: it is now safe to overwrite x, since x is not needed for the gradient of σ
    if isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, NotaNumber}))
        Ω = bias_act!(σ, x, b)  # now x === Ω, when x isa StridedArray{<:AbstractFloat}
        function bias_act!_fastback(Δ)
            # Tempting to overwrite x again, but that is only safe if the pullback is called at most once;
            # TODO with e.g. https://github.com/FluxML/Zygote.jl/pull/1340
            # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/592
            dx = only_derivative.(Ω, σ, NotaNumber()) .* unthunk(Δ)
            db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
            return (NoTangent(), NoTangent(), dx, db)
        end
        return Ω, bias_act!_fastback

    # # Slower path: can't overwrite x, but can use derivatives_given_output
    # # This case is WRONG and tests fail, but not sure why
    # # (possibly because Ω2 below is σ.(x) .+ b rather than σ.(x .+ b)?)
    # elseif isconcretetype(Core.Compiler._return_type(only_derivative, Tuple{T, F, T}))
    #     Ω2 = fast_act(σ, x).(x) .+ b
    #     @show σ b
    #     function bias_act!_back2(Δ)
    #         dx = only_derivative.(Ω2, σ, x .+ b) .* unthunk(Δ)
    #         db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
    #         return (NoTangent(), NoTangent(), dx, db)
    #     end
    #     return Ω2, bias_act!_back2

    # Fallback path: let AD handle the broadcast
    else
        Ω3, back = rrule_via_ad(cfg, broadcast, fast_act(σ, x), bias_act!(identity, x, b))
        @inline function bias_act!_slowback(Δ)
            _, _, dx = back(Δ)
            db = eltype(B) === Bool ? NoTangent() : reshape(sum(dx; dims = b_dims), size_b)
            return (NoTangent(), NoTangent(), dx, db)
        end
        return Ω3, bias_act!_slowback
    end
end
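
# For example, a gradient call like those benchmarked at the bottom of this file
# (sketch only; any ChainRules-aware reverse-mode AD such as Zygote will do):
#
#   using NNlib, Zygote
#   w, b = rand(Float32, 100, 10_000), rand(Float32, 100)
#   gradient((w, b) -> sum(bias_act!(relu, w, b)), w, b)   # relu takes the fast path above
#
# Note that `w` itself is overwritten by the forward pass, so its original values must
# not be needed anywhere else in the loss.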

# Two easy cases
function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::B) where {T,N,B}
    b_dims = ntuple(d -> size(b,d)==1 ? d : N+1, N)
    size_b = size(b)
    function bias_act!_idback(Δ)
        dx = unthunk(Δ)
        db = reshape(sum(dx; dims = b_dims), size_b)
        return (NoTangent(), NoTangent(), dx, db)
    end
    return bias_act!(identity, x, b), bias_act!_idback
end
function ChainRulesCore.rrule(cfg::RCR, ::typeof(bias_act!), ::typeof(identity), x::AbstractArray{T,N}, b::Bool) where {T,N}
    bias_act!_trivial(Δ) = (NoTangent(), NoTangent(), Δ, NoTangent())
    return x, bias_act!_trivial
end


"""
    NNlib.fast_broadcast_plus!(f, x, b)

This is equivalent to `x .= f.(x .+ b)`, but works around
an issue with broadcasting that prevents SIMD in such cases.

This workaround can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.

Also has an `rrule`, to avoid mutation when used inside second-order AD.

!!! warning
    Does not allow for derivatives with respect to `f`.
"""
function fast_broadcast_plus!(f::F, x::Array{<:AbstractFloat}, b) where {F<:Function}
    if b === false
        @simd ivdep for I in eachindex(x)
            @inbounds x[I] = f(x[I])
        end
    else
        xplus = Broadcast.instantiate(Broadcast.broadcasted(+, x, b))
        @simd ivdep for I in eachindex(xplus)
            @inbounds x[I] = f(xplus[I])
        end
    end
    return x
end
function fast_broadcast_plus!(f::F, x::StridedArray{<:AbstractFloat}, b) where {F<:Function}
    # Strided arrays which are not `Array` (e.g. `CuArray`) have their own fast broadcasting.
    x .= f.(x .+ b)
    return x
end
function fast_broadcast_plus!(f::F, x::AbstractArray, b) where {F<:Function}
    # Don't try to write into unusual array types; fall back to an out-of-place broadcast.
    return f.(x .+ b)
end
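
# A quick sanity check of the claimed equivalence (illustrative values only):
#
#   x, b = rand(Float32, 4, 3), rand(Float32, 4)
#   fast_broadcast_plus!(relu, copy(x), b) ≈ relu.(x .+ b)   # expected to hold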

function ChainRulesCore.rrule(cfg::RCR, ::typeof(fast_broadcast_plus!), f::F, x::AbstractArray{T,N}, b::B) where {F,T,N,B}
    rrule_via_ad(cfg, broadcast, (x, b) -> f.(x .+ b), x, b)
end


# """
#     add_act(σ, x, y...)
#     add_act!(σ, x, y, z...)

# Equivalent to `σ.(x .+ y .+ z)`. The mutating method `add_act!`
# """
# add_act(σ::Function, x::AbstractArray, yz::AbstractArray...) = σ.(.+(x, yz...))  # fused


# function ChainRulesCore.rrule(::typeof(add_act), σ::F, x::AbstractArray, yz::AbstractArray...) where {F,T,N}
#     if isconcretetype(Core.Compiler._return_type(
#         derivatives_given_output, Tuple{T, F, NotaNumber}))

# end


# bias_act!(σ::Function, x::StridedArray{<:AbstractFloat}, b::Bool) =
#     # b ? (x .= fast_act(σ, x).(x .+ b)) : (x .= fast_act(σ, x).(x))
#     (@assert !b "bias=true is not accepted"; (x .= fast_act(σ, x).(x)))


# using NNlib, BenchmarkTools

#=

## M1 mac, 1.10

julia> w, b = rand(Float32, 100, 10000), rand(Float32, 100);

julia> @btime bias_act!(relu, $w, $b);
 min 19.500 μs, mean 21.375 μs (0 allocations)

julia> @btime relu.($w .+ $b);
 min 17.208 μs, mean 62.826 μs (2 allocations, 390.67 KiB)

julia> @btime bias_act!(tanh, $w, $b);
 min 63.792 μs, mean 65.052 μs (0 allocations)

julia> @btime tanh_fast.($w .+ $b);
 min 63.583 μs, mean 102.004 μs (2 allocations, 390.67 KiB)

julia> using Zygote

julia> @btime gradient((w,b) -> sum(bias_act!(relu, w, b)), $w, $b);
 min 145.166 μs, mean 150.785 μs (51 allocations, 2.18 KiB)

julia> @btime gradient((w,b) -> sum(relu.(w .+ b)), $w, $b);
 min 165.583 μs, mean 314.267 μs (32 allocations, 1.15 MiB)

julia> @btime gradient((w,b) -> sum(bias_act!(tanh, w, b)), $w, $b);
 min 191.917 μs, mean 195.956 μs (51 allocations, 2.18 KiB)

julia> @btime gradient((w,b) -> sum(tanh_fast.(w .+ b)), $w, $b);
 min 209.458 μs, mean 338.652 μs (32 allocations, 1.15 MiB)


## Cyclops

julia> using CUDA  # 10x bigger

julia> cw, cb = CUDA.rand(Float32, 100, 100_00), CUDA.rand(Float32, 100);

julia> @btime CUDA.@sync bias_act!(relu, $cw, $cb);
 22.546 μs (27 allocations: 1.45 KiB)

julia> @btime CUDA.@sync relu.($cw .+ $cb);  # faster, that's odd?
 31.282 μs (38 allocations: 1.81 KiB)

julia> @btime CUDA.@sync bias_act!(tanh, $cw, $cb);
 27.030 μs (27 allocations: 1.45 KiB)

julia> @btime CUDA.@sync tanh_fast.($cw .+ $cb);
 36.421 μs (38 allocations: 1.81 KiB)

julia> using Zygote

julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(relu, w, b)), $cw, $cb);
 204.507 μs (382 allocations: 18.15 KiB)

julia> @btime CUDA.@sync gradient((w,b) -> sum(relu.(w .+ b)), $cw, $cb);
 204.458 μs (409 allocations: 19.19 KiB)

julia> @btime CUDA.@sync gradient((w,b) -> sum(bias_act!(tanh, w, b)), $cw, $cb);
 224.545 μs (382 allocations: 18.15 KiB)

julia> @btime CUDA.@sync gradient((w,b) -> sum(tanh_fast.(w .+ b)), $cw, $cb);
 204.793 μs (411 allocations: 19.30 KiB)

=#

#=

(jl_fuwIi8) pkg> add https://github.com/mcabbott/NNlib.jl/tree/bias_act_23

julia> using NNlib, Zygote, BenchmarkTools

julia> w, b, x = rand(Float32, 50, 50), rand(Float32, 50), randn(Float32, 50, 100);

julia> @btime bias_act!(relu, $w * $x, $b);
 min 5.243 μs, mean 8.600 μs (2 allocations, 19.61 KiB)

julia> @btime relu.($w * $x .+ $b);
 min 5.160 μs, mean 10.863 μs (4 allocations, 39.22 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
 min 21.042 μs, mean 40.476 μs (43 allocations, 89.83 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
 min 21.542 μs, mean 43.947 μs (41 allocations, 128.91 KiB)

julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
 min 14.708 μs, mean 26.450 μs (28 allocations, 69.41 KiB)

julia> @btime gradient(x -> sum(abs2, x), $x);
 min 1.938 μs, mean 4.160 μs (2 allocations, 19.61 KiB)


# Cyclops

julia> @btime bias_act!(relu, $w * $x, $b);
 24.786 μs (2 allocations: 19.61 KiB)

julia> @btime relu.($w * $x .+ $b);
 25.501 μs (4 allocations: 39.22 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, bias_act!(relu, w*x, b)), $w, $x, $b);
 91.847 μs (43 allocations: 89.83 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, relu.(w*x .+ b)), $w, $x, $b);
 98.054 μs (41 allocations: 128.91 KiB)

julia> @btime gradient((w,x) -> sum(abs2, w*x), $w, $x);
 80.464 μs (28 allocations: 69.41 KiB)

julia> @btime gradient(x -> sum(abs2, x), $x);
 4.604 μs (2 allocations: 19.61 KiB)

julia> @time using CUDA; @time cu(ones(3)) .+ 1;

julia> w, b, x = CUDA.rand(Float32, 1000, 1000), CUDA.rand(Float32, 1000), CUDA.rand(Float32, 1000, 1000);

=#