Skip to content

Commit c2c6ab7

Browse files
committed
Allow activation function for Diagonal
1 parent 57beb23 commit c2c6ab7

File tree

2 files changed

+18
-19
lines changed

2 files changed

+18
-19
lines changed

src/layers/basic.jl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ julia> Flux.params(d1) # no trainable bias
138138
Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]])
139139
```
140140
"""
141-
struct Dense{F, M<:AbstractMatrix, B}
141+
struct Dense{M<:AbstractMatrix, B, F}
142142
weight::M
143143
bias::B
144144
σ::F
145145
function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
146146
b = create_bias(W, bias, size(W,1))
147-
new{F,M,typeof(b)}(W, b, σ)
147+
new{M, typeof(b), F}(W, b, σ)
148148
end
149149
end
150150

@@ -158,7 +158,7 @@ end
158158
function (a::Dense)(x::AbstractVecOrMat)
159159
W, b = a.weight, a.bias
160160
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
161-
return σ.(W*x .+ b)
161+
return σ.(W * x .+ b)
162162
end
163163

164164
(a::Dense)(x::AbstractArray) =
@@ -172,35 +172,37 @@ function Base.show(io::IO, l::Dense)
172172
end
173173

174174
"""
175-
Diagonal(size::Integer...; bias=true, init=ones32)
176-
Diagonal(scale::AbstractArray, [bias])
175+
Diagonal(size::Integer...; σ = identity, bias=true, init=ones32)
176+
Diagonal(scale::AbstractArray, [bias, activation])
177177
178178
Create an element-wise linear layer, which performs
179179
180-
y = scale .* x .+ bias
180+
y = σ.(scale .* x .+ bias)
181181
182-
with no activation function.
183-
184182
The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
185183
with `init=ones32` by default. You may specify the function `init`,
186184
turn off trainable bias with `bias=false`, or provide the array(s) explicitly.
187185
188186
Used by [`LayerNorm`](@ref).
189187
"""
190-
struct Diagonal{A<:AbstractArray, B}
188+
struct Diagonal{A<:AbstractArray, B, F}
191189
scale::A
192190
bias::B
193-
function Diagonal(W::M, bias = true) where M<:AbstractArray
191+
σ::F
192+
function Diagonal(W::M, bias = true, σ::F = identity) where {M<:AbstractArray, F}
194193
b = create_bias(W, bias, size(W)...)
195-
new{M, typeof(b)}(W, b)
194+
new{M, typeof(b), F}(W, b, σ)
196195
end
197196
end
198197

199-
Diagonal(sz::Integer...; bias = true, init = ones32) = Diagonal(init(sz...), bias)
198+
Diagonal(sz::Integer...; σ = identity, bias = true, init = ones32) = Diagonal(init(sz...), bias, σ)
200199

201200
@functor Diagonal
202201

203-
(a::Diagonal)(x) = a.scale .* x .+ a.bias
202+
function (a::Diagonal)(x::AbstractArray)
203+
σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc
204+
return σ === identity ? a.scale .* x .+ a.bias : σ.(a.scale .* x .+ a.bias)
205+
end
204206

205207
function Base.show(io::IO, l::Diagonal)
206208
print(io, "Diagonal(", join(size(l.scale), ", "))
@@ -212,7 +214,7 @@ end
212214
Maxout(layers...)
213215
Maxout(f, n_alts)
214216
215-
This contains a number of internal layes, each of which receives the same input.
217+
This contains a number of internal layers, each of which receives the same input.
216218
Its output is the elementwise maximum of the internal layers' outputs.
217219
218220
Instead of defining layers individually, you can provide a zero-argument function

src/layers/normalise.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,16 +165,13 @@ struct LayerNorm{F,D,T,N}
165165
end
166166

167167
function LayerNorm(sz, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
168-
diag = affine ? Diagonal(sz...) : identity
168+
diag = affine ? Diagonal(sz...; σ = λ) : Base.Fix1(broadcast, λ)
169169
return LayerNorm(λ, diag, ϵ, Tuple(sz), affine)
170170
end
171171

172172
@functor LayerNorm
173173

174-
function (a::LayerNorm)(x)
175-
x = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
176-
return a.λ === identity ? x : a.λ.(x)
177-
end
174+
(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
178175

179176
function Base.show(io::IO, l::LayerNorm)
180177
print(io, "LayerNorm($(l.size)")

0 commit comments

Comments
 (0)