Skip to content

Commit bc7b64d

Browse files
committed
use NNlib.bias_act
rm comments
1 parent 0a36651 commit bc7b64d

File tree

4 files changed

+7
-10
lines changed

4 files changed

+7
-10
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ MPI = "0.20.19"
5454
MacroTools = "0.5"
5555
NCCL = "0.1.1"
5656
NNlib = "0.9.22"
57+
Metal = "0.5, 1"
5758
OneHotArrays = "0.2.4"
5859
Optimisers = "0.3.3"
5960
Preferences = "1"

src/layers/basic.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,8 @@ end
170170

171171
function (a::Dense)(x::AbstractVecOrMat)
172172
_size_check(a, x, 1 => size(a.weight, 2))
173-
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
174173
xT = _match_eltype(a, x) # fixes Float64 input, etc.
175-
return σ.(a.weight * xT .+ a.bias)
174+
NNlib.bias_act!(a.σ, a.weight * xT, a.bias) # does σ.(W*x .+ b), with fast paths
176175
end
177176

178177
function (a::Dense)(x::AbstractArray)
@@ -450,7 +449,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
450449
Z = reshape(Wyx, (d_z, :))
451450

452451
# @einsum out[o,s] := σ(Z[o,i] + b[o])
453-
σ.(Z .+ b)
452+
NNlib.bias_act!(σ, Z, b) # σ.(Z .+ b)
454453
end
455454

456455
(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)

src/layers/conv.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,9 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
196196

197197
function (c::Conv)(x::AbstractArray)
198198
_conv_size_check(c, x)
199-
σ = NNlib.fast_act(c.σ, x)
200199
cdims = conv_dims(c, x)
201200
xT = _match_eltype(c, x)
202-
σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
201+
NNlib.bias_act!(c.σ, conv(xT, c.weight, cdims), conv_reshape_bias(c))
203202
end
204203

205204
_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -350,10 +349,9 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
350349

351350
function (c::ConvTranspose)(x::AbstractArray)
352351
_conv_size_check(c, x)
353-
σ = NNlib.fast_act(c.σ, x)
354352
cdims = conv_transpose_dims(c, x)
355353
xT = _match_eltype(c, x)
356-
σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
354+
NNlib.bias_act!(c.σ, ∇conv_data(xT, c.weight, cdims), conv_reshape_bias(c))
357355
end
358356

359357
function Base.show(io::IO, l::ConvTranspose)
@@ -493,10 +491,9 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
493491

494492
function (c::CrossCor)(x::AbstractArray)
495493
_conv_size_check(c, x)
496-
σ = NNlib.fast_act(c.σ, x)
497494
cdims = crosscor_dims(c, x)
498495
xT = _match_eltype(c, x)
499-
σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
496+
NNlib.bias_act!(c.σ, crosscor(xT, c.weight, cdims), conv_reshape_bias(c))
500497
end
501498

502499
function Base.show(io::IO, l::CrossCor)

src/layers/normalise.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ function _norm_layer_forward(
245245
β = reshape(l.β, affine_shape)
246246

247247
scale = γ ./ sqrt.(σ² .+ eps)
248-
bias = -scale .* μ .+ β
248+
bias = .-scale .* μ .+ β
249249
l.λ.(scale .* x .+ bias)
250250
end
251251

0 commit comments

Comments
 (0)