Skip to content

Commit cbd33fa

Browse files
committed
use NNlib.bias_act
1 parent eb6492c commit cbd33fa

File tree

4 files changed

+13
-11
lines changed

4 files changed

+13
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Functors = "0.4"
4444
MLUtils = "0.4"
4545
MacroTools = "0.5"
4646
Metal = "0.5, 1"
47-
NNlib = "0.9.1"
47+
NNlib = "0.9.5"
4848
OneHotArrays = "0.2.4"
4949
Optimisers = "0.3.2"
5050
Preferences = "1"

src/layers/basic.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,8 @@ end
169169

170170
function (a::Dense)(x::AbstractVecOrMat)
171171
_size_check(a, x, 1 => size(a.weight, 2))
172-
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
173172
xT = _match_eltype(a, x) # fixes Float64 input, etc.
174-
return σ.(a.weight * xT .+ a.bias)
173+
NNlib.bias_act!(a.σ, a.weight * xT, a.bias) # does σ.(W*x .+ b), with fast paths
175174
end
176175

177176
function (a::Dense)(x::AbstractArray)
@@ -446,7 +445,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
446445
Z = reshape(Wyx, (d_z, :))
447446

448447
# @einsum out[o,s] := σ(Z[o,i] + b[o])
449-
σ.(Z .+ b)
448+
NNlib.bias_act!(σ, Z, b) # σ.(Z .+ b)
450449
end
451450

452451
(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)

src/layers/conv.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,11 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
196196

197197
function (c::Conv)(x::AbstractArray)
198198
_size_check(c, x, ndims(x)-1 => _channels_in(c))
199-
σ = NNlib.fast_act(c.σ, x)
199+
# σ = NNlib.fast_act(c.σ, x)
200200
cdims = conv_dims(c, x)
201201
xT = _match_eltype(c, x)
202-
σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
202+
# σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
203+
NNlib.bias_act!(c.σ, conv(xT, c.weight, cdims), conv_reshape_bias(c))
203204
end
204205

205206
_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -332,10 +333,11 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
332333

333334
function (c::ConvTranspose)(x::AbstractArray)
334335
_size_check(c, x, ndims(x)-1 => _channels_in(c))
335-
σ = NNlib.fast_act(c.σ, x)
336+
# σ = NNlib.fast_act(c.σ, x)
336337
cdims = conv_transpose_dims(c, x)
337338
xT = _match_eltype(c, x)
338-
σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
339+
# σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
340+
NNlib.bias_act!(c.σ, ∇conv_data(xT, c.weight, cdims), conv_reshape_bias(c))
339341
end
340342

341343
function Base.show(io::IO, l::ConvTranspose)
@@ -474,10 +476,11 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
474476

475477
function (c::CrossCor)(x::AbstractArray)
476478
_size_check(c, x, ndims(x)-1 => _channels_in(c))
477-
σ = NNlib.fast_act(c.σ, x)
479+
# σ = NNlib.fast_act(c.σ, x)
478480
cdims = crosscor_dims(c, x)
479481
xT = _match_eltype(c, x)
480-
σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
482+
# σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
483+
NNlib.bias_act!(c.σ, crosscor(xT, c.weight, cdims), conv_reshape_bias(c))
481484
end
482485

483486
function Base.show(io::IO, l::CrossCor)

src/layers/normalise.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ function _norm_layer_forward(
246246
β = reshape(l.β, affine_shape)
247247

248248
scale = γ ./ sqrt.(σ² .+ eps)
249-
bias = -scale .* μ .+ β
249+
bias = .-scale .* μ .+ β
250250
l.λ.(scale .* x .+ bias)
251251
end
252252

0 commit comments

Comments
 (0)