Skip to content

Commit 368cec2

Browse files
committed
use NNlib.bias_act
1 parent 2ac01e0 commit 368cec2

File tree

4 files changed

+14
-12
lines changed

4 files changed

+14
-12
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Flux"
22
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
3-
version = "0.14.4"
3+
version = "0.14.5"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -42,7 +42,7 @@ Functors = "0.4"
4242
MLUtils = "0.4"
4343
MacroTools = "0.5"
4444
Metal = "0.5"
45-
NNlib = "0.9.1"
45+
NNlib = "0.9.5"
4646
OneHotArrays = "0.2.4"
4747
Optimisers = "0.2.12, 0.3.0"
4848
Preferences = "1"

src/layers/basic.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,8 @@ end
169169

170170
function (a::Dense)(x::AbstractVecOrMat)
171171
_size_check(a, x, 1 => size(a.weight, 2))
172-
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
173172
xT = _match_eltype(a, x) # fixes Float64 input, etc.
174-
return σ.(a.weight * xT .+ a.bias)
173+
NNlib.bias_act!(a.σ, a.weight * xT, a.bias) # does σ.(W*x .+ b), with fast paths
175174
end
176175

177176
function (a::Dense)(x::AbstractArray)
@@ -446,7 +445,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix)
446445
Z = reshape(Wyx, (d_z, :))
447446

448447
# @einsum out[o,s] := σ(Z[o,i] + b[o])
449-
σ.(Z .+ b)
448+
NNlib.bias_act!(σ, Z, b) # σ.(Z .+ b)
450449
end
451450

452451
(a::Bilinear)(x::AbstractVecOrMat) = a(x, x)

src/layers/conv.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,11 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any)
196196

197197
function (c::Conv)(x::AbstractArray)
198198
_size_check(c, x, ndims(x)-1 => _channels_in(c))
199-
σ = NNlib.fast_act(c.σ, x)
199+
# σ = NNlib.fast_act(c.σ, x)
200200
cdims = conv_dims(c, x)
201201
xT = _match_eltype(c, x)
202-
σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
202+
# σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
203+
NNlib.bias_act!(c.σ, conv(xT, c.weight, cdims), conv_reshape_bias(c))
203204
end
204205

205206
_channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups
@@ -331,10 +332,11 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any)
331332

332333
function (c::ConvTranspose)(x::AbstractArray)
333334
_size_check(c, x, ndims(x)-1 => _channels_in(c))
334-
σ = NNlib.fast_act(c.σ, x)
335+
# σ = NNlib.fast_act(c.σ, x)
335336
cdims = conv_transpose_dims(c, x)
336337
xT = _match_eltype(c, x)
337-
σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
338+
# σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c))
339+
NNlib.bias_act!(c.σ, ∇conv_data(xT, c.weight, cdims), conv_reshape_bias(c))
338340
end
339341

340342
function Base.show(io::IO, l::ConvTranspose)
@@ -473,10 +475,11 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any)
473475

474476
function (c::CrossCor)(x::AbstractArray)
475477
_size_check(c, x, ndims(x)-1 => _channels_in(c))
476-
σ = NNlib.fast_act(c.σ, x)
478+
# σ = NNlib.fast_act(c.σ, x)
477479
cdims = crosscor_dims(c, x)
478480
xT = _match_eltype(c, x)
479-
σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
481+
# σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c))
482+
NNlib.bias_act!(c.σ, crosscor(xT, c.weight, cdims), conv_reshape_bias(c))
480483
end
481484

482485
function Base.show(io::IO, l::CrossCor)

src/layers/normalise.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ function _norm_layer_forward(
244244
β = reshape(l.β, affine_shape)
245245

246246
scale = γ ./ sqrt.(σ² .+ eps)
247-
bias = -scale .* μ .+ β
247+
bias = .-scale .* μ .+ β
248248
l.λ.(scale .* x .+ bias)
249249
end
250250

0 commit comments

Comments
 (0)