Commit af1e5fc
Use NNlib.bias_act! (#2327)
* use NNlib.bias_act!, rm comments
* mend
* add to news
* Update src/layers/basic.jl

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
Parent commit: c86580b
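For context: `NNlib.bias_act!(σ, x, b)` computes `σ.(x .+ b)`, overwriting `x` where that is safe, and applies NNlib's fast activation substitutions (e.g. `tanh` to `tanh_fast`) internally, which appears to be why the explicit `fast_act` lines are deleted in the diffs below. A minimal sketch of the semantics, with made-up sizes:

```julia
using NNlib

W = randn(Float32, 4, 3)
x = randn(Float32, 3, 2)
b = randn(Float32, 4)

y = W * x                          # freshly allocated, so safe to overwrite
out = NNlib.bias_act!(relu, y, b)  # writes relu.(y .+ b) into y and returns it

out ≈ relu.(W * x .+ b)            # true: same values, one fewer temporary
```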

File tree: 4 files changed, +7 -10 lines

NEWS.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl
  * The `Flux.Optimise` module has been deprecated in favor of the Optimisers.jl package.
    Now Flux re-exports the optimisers from Optimisers.jl. Most users will be unaffected by this change.
    The module is still available for now, but will be removed in a future release.
+ * Most Flux layers will [re-use memory via `NNlib.bias_act!`](https://github.com/FluxML/Flux.jl/pull/2327), when possible.

  ## v0.14.22

  * Data movement between devices is now provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl).

src/layers/basic.jl

Lines changed: 2 additions & 3 deletions
@@ -186,9 +186,8 @@ end
 function (a::Dense)(x::AbstractVecOrMat)
   _size_check(a, x, 1 => size(a.weight, 2))
-  σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
   xT = _match_eltype(a, x)  # fixes Float64 input, etc.
-  return σ.(a.weight * xT .+ a.bias)
+  return NNlib.bias_act!(a.σ, a.weight * xT, a.bias) # does σ.(W*x .+ b), with fast paths
 end

 function (a::Dense)(x::AbstractArray)
@@ -466,7 +465,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
   Z = reshape(Wyx, (d_z, :))

   # @einsum out[o,s] := σ(Z[o,i] + b[o])
-  σ.(Z .+ b)
+  NNlib.bias_act!(σ, Z, b)  # σ.(Z .+ b)
 end

 (a::Bilinear)(x::AbstractVecOrMat) = a(x, x)
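Numerically the `Dense` and `Bilinear` forward passes are unchanged; only the temporary from the unfused broadcast is saved. A quick sanity check one might run, with hypothetical layer sizes:

```julia
using Flux

d = Dense(3 => 4, tanh)
x = randn(Float32, 3, 8)

# bias_act! fuses the bias add and activation into one in-place broadcast;
# tanh is swapped for tanh_fast internally, so agreement is approximate, not bitwise:
d(x) ≈ tanh.(d.weight * x .+ d.bias)   # true
```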

src/layers/conv.jl

Lines changed: 3 additions & 6 deletions
@@ -196,10 +196,9 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
 function (c::Conv)(x::AbstractArray)
   _conv_size_check(c, x)
-  σ = NNlib.fast_act(c.σ, x)
   cdims = conv_dims(c, x)
   xT = _match_eltype(c, x)
-  σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+  NNlib.bias_act!(c.σ, conv(xT, c.weight, cdims), conv_reshape_bias(c))
 end

 _channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -350,10 +349,9 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
 function (c::ConvTranspose)(x::AbstractArray)
   _conv_size_check(c, x)
-  σ = NNlib.fast_act(c.σ, x)
   cdims = conv_transpose_dims(c, x)
   xT = _match_eltype(c, x)
-  σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+  NNlib.bias_act!(c.σ, ∇conv_data(xT, c.weight, cdims), conv_reshape_bias(c))
 end

 function Base.show(io::IO, l::ConvTranspose)
@@ -493,10 +491,9 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
 function (c::CrossCor)(x::AbstractArray)
   _conv_size_check(c, x)
-  σ = NNlib.fast_act(c.σ, x)
   cdims = crosscor_dims(c, x)
   xT = _match_eltype(c, x)
-  σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+  NNlib.bias_act!(c.σ, crosscor(xT, c.weight, cdims), conv_reshape_bias(c))
 end

 function Base.show(io::IO, l::CrossCor)
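The same pattern repeats for all three layers: the array handed to `bias_act!` is the freshly allocated output of `conv` / `∇conv_data` / `crosscor`, so mutating it cannot touch the caller's data. A hedged equivalence check for `Conv`, assuming the layer's default stride, padding, and groups:

```julia
using Flux, NNlib

c = Conv((3, 3), 2 => 5, relu)
x = randn(Float32, 10, 10, 2, 1)

# Old formulation written out by hand, with the bias reshaped
# to broadcast over the channel dimension:
y_old = relu.(NNlib.conv(x, c.weight) .+ reshape(c.bias, 1, 1, :, 1))

c(x) ≈ y_old   # true: the fused in-place version matches
```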

src/layers/normalise.jl

Lines changed: 1 addition & 1 deletion
@@ -245,7 +245,7 @@ function _norm_layer_forward(
   β = reshape(l.β, affine_shape)

   scale = γ ./ sqrt.(σ² .+ eps)
-  bias = -scale .* μ .+ β
+  bias = .-scale .* μ .+ β
   l.λ.(scale .* x .+ bias)
 end
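The `normalise.jl` tweak is a smaller change in the same allocation-saving spirit: `-scale` materialises a negated copy of the array before the broadcast runs, while `.-scale` folds the unary minus into the single fused broadcast. The values agree either way:

```julia
scale, μ, β = rand(Float32, 5), rand(Float32, 5), rand(Float32, 5)

bias_unfused = -scale .* μ .+ β    # `-scale` allocates an intermediate array
bias_fused   = .-scale .* μ .+ β   # unary minus joins the fused broadcast

bias_unfused ≈ bias_fused          # true; only the allocation count differs
```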
