
Commit ebda582

Merge pull request #1925 from theabhirath/diag-act
Allow activation function for Diagonal
2 parents a0b804a + 9ab71f7 · commit ebda582

File tree: 3 files changed, +21 -24 lines

src/layers/basic.jl

Lines changed: 14 additions & 13 deletions
@@ -156,9 +156,8 @@ end
 @functor Dense

 function (a::Dense)(x::AbstractVecOrMat)
-  W, b = a.weight, a.bias
   σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
-  return σ.(W*x .+ b)
+  return σ.(a.weight * x .+ a.bias)
 end

 (a::Dense)(x::AbstractArray) =

@@ -172,35 +171,37 @@ function Base.show(io::IO, l::Dense)
 end

 """
-    Diagonal(size::Integer...; bias=true, init=ones32)
-    Diagonal(scale::AbstractArray, [bias])
+    Diagonal(size::Integer...; σ = identity, bias=true, init=ones32)
+    Diagonal(scale::AbstractArray, [bias, activation])

 Create an element-wise linear layer, which performs

-    y = scale .* x .+ bias
+    y = σ.(scale .* x .+ bias)

-with no activation function.
-
 The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
 with `init=ones32` by default. You may specify the function `init`,
 turn off trainable bias with `bias=false`, or provide the array(s) explicitly.

 Used by [`LayerNorm`](@ref).
 """
-struct Diagonal{A<:AbstractArray, B}
+struct Diagonal{A<:AbstractArray, B, F}
   scale::A
   bias::B
-  function Diagonal(W::M, bias = true) where M<:AbstractArray
+  σ::F
+  function Diagonal(W::M, bias = true, σ::F = identity) where {M<:AbstractArray, F}
     b = create_bias(W, bias, size(W)...)
-    new{M, typeof(b)}(W, b)
+    new{M, typeof(b), F}(W, b, σ)
   end
 end

-Diagonal(sz::Integer...; bias = true, init = ones32) = Diagonal(init(sz...), bias)
+Diagonal(sz::Integer...; σ = identity, bias = true, init = ones32) = Diagonal(init(sz...), bias, σ)

 @functor Diagonal

-(a::Diagonal)(x) = a.scale .* x .+ a.bias
+function (a::Diagonal)(x::AbstractArray)
+  σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
+  return σ === typeof(identity) ? a.scale .* x .+ a.bias : σ.(a.scale .* x .+ a.bias)
+end

 function Base.show(io::IO, l::Diagonal)
   print(io, "Diagonal(", join(size(l.scale), ", "))

@@ -212,7 +213,7 @@ end
     Maxout(layers...)
     Maxout(f, n_alts)

-This contains a number of internal layes, each of which receives the same input.
+This contains a number of internal layers, each of which receives the same input.
 Its output is the elementwise maximum of the the internal layers' outputs.

 Instead of defining layers individually, you can provide a zero-argument function
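
For reference, a minimal sketch of how the new activation argument can be used after this change. The sizes, inputs, and the `abs` activation below are illustrative, not taken from this commit:

    using Flux

    # keyword form: scale and bias are initialised internally, abs is the activation
    d1 = Flux.Diagonal(2; σ = abs)
    d1([1, -2])                     # == [1.0, 2.0], i.e. abs.(scale .* x .+ bias)

    # positional form: explicit scale array, trainable bias off, explicit activation
    d2 = Flux.Diagonal([2.0, 2.0], false, abs)
    d2([-3.0, 4.0])                 # == [6.0, 8.0]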

src/layers/normalise.jl

Lines changed: 2 additions & 5 deletions
@@ -165,16 +165,13 @@ struct LayerNorm{F,D,T,N}
 end

 function LayerNorm(sz, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
-  diag = affine ? Diagonal(sz...) : identity
+  diag = affine ? Diagonal(sz...; σ = λ) : Base.Fix1(broadcast, λ)
   return LayerNorm(λ, diag, ϵ, Tuple(sz), affine)
 end

 @functor LayerNorm

-function (a::LayerNorm)(x)
-  x = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
-  return a.λ === identity ? x : a.λ.(x)
-end
+(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))

 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm($(l.size)")

test/layers/basic.jl

Lines changed: 5 additions & 6 deletions
@@ -89,19 +89,18 @@ import Flux: activations

   @testset "Diagonal" begin
     @test length(Flux.Diagonal(10)(randn(10))) == 10
-    @test length(Flux.Diagonal(10)(1)) == 10
     @test length(Flux.Diagonal(10)(randn(1))) == 10
     @test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
     @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))

     @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
-    @test Flux.Diagonal(2)([1,2]) == [1,2]
+    @test Flux.Diagonal(2)([1, 2]) == [1, 2]
     @test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]

-    @test Flux.Diagonal(2)(rand(2,3,4)) |> size == (2, 3, 4)
-    @test Flux.Diagonal(2,3)(rand(2,3,4)) |> size == (2, 3, 4)
-    @test Flux.Diagonal(2, 3, 4; bias = false)(rand(2,3,4)) |> size == (2, 3, 4)
-    @test Flux.Diagonal(2, 3; bias = false)(rand(2,1,4)) |> size == (2, 3, 4)
+    @test Flux.Diagonal(2)(rand(2, 3, 4)) |> size == (2, 3, 4)
+    @test Flux.Diagonal(2, 3;)(rand(2, 3, 4)) |> size == (2, 3, 4)
+    @test Flux.Diagonal(2, 3, 4; bias = false)(rand(2, 3, 4)) |> size == (2, 3, 4)
+    @test Flux.Diagonal(2, 3; bias = false)(rand(2, 1, 4)) |> size == (2, 3, 4)
   end

   @testset "Maxout" begin
