Skip to content

Commit 4ab8343

Browse files
committed
use NNlib.bias_act
rm comments
1 parent 8654721 commit 4ab8343

File tree

4 files changed

+7
-11
lines changed

4 files changed

+7
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Functors = "0.4"
4646
MLUtils = "0.4"
4747
MacroTools = "0.5"
4848
Metal = "0.5, 1"
49-
NNlib = "0.9.1"
49+
NNlib = "0.9.5"
5050
OneHotArrays = "0.2.4"
5151
Optimisers = "0.3.2"
5252
Preferences = "1"

src/layers/basic.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,8 @@ end
169169

170170
function (a::Dense)(x::AbstractVecOrMat)
171171
_size_check(a, x, 1 => size(a.weight, 2))
172-
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
173172
xT = _match_eltype(a, x) # fixes Float64 input, etc.
174-
return σ.(a.weight * xT .+ a.bias)
173+
NNlib.bias_act!(a.σ, a.weight * xT, a.bias) # does σ.(W*x .+ b), with fast paths
175174
end
176175

177176
function (a::Dense)(x::AbstractArray)
@@ -446,7 +445,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
446445
Z = reshape(Wyx, (d_z, :))
447446

448447
# @einsum out[o,s] := σ(Z[o,i] + b[o])
449-
σ.(Z .+ b)
448+
NNlib.bias_act!(σ, Z, b) # σ.(Z .+ b)
450449
end
451450

452451
(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)

src/layers/conv.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,9 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
196196

197197
function (c::Conv)(x::AbstractArray)
198198
_conv_size_check(c, x)
199-
σ = NNlib.fast_act(c.σ, x)
200199
cdims = conv_dims(c, x)
201200
xT = _match_eltype(c, x)
202-
σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
201+
NNlib.bias_act!(c.σ, conv(xT, c.weight, cdims), conv_reshape_bias(c))
203202
end
204203

205204
_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -332,10 +331,9 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
332331

333332
function (c::ConvTranspose)(x::AbstractArray)
334333
_conv_size_check(c, x)
335-
σ = NNlib.fast_act(c.σ, x)
336334
cdims = conv_transpose_dims(c, x)
337335
xT = _match_eltype(c, x)
338-
σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
336+
NNlib.bias_act!(c.σ, ∇conv_data(xT, c.weight, cdims), conv_reshape_bias(c))
339337
end
340338

341339
function Base.show(io::IO, l::ConvTranspose)
@@ -474,10 +472,9 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
474472

475473
function (c::CrossCor)(x::AbstractArray)
476474
_conv_size_check(c, x)
477-
σ = NNlib.fast_act(c.σ, x)
478475
cdims = crosscor_dims(c, x)
479476
xT = _match_eltype(c, x)
480-
σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
477+
NNlib.bias_act!(c.σ, crosscor(xT, c.weight, cdims), conv_reshape_bias(c))
481478
end
482479

483480
function Base.show(io::IO, l::CrossCor)

src/layers/normalise.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ function _norm_layer_forward(
246246
β = reshape(l.β, affine_shape)
247247

248248
scale = γ ./ sqrt.(σ² .+ eps)
249-
bias = -scale .* μ .+ β
249+
bias = .-scale .* μ .+ β
250250
l.λ.(scale .* x .+ bias)
251251
end
252252

0 commit comments

Comments
 (0)