
Commit 5f17f1c

Rename Diagonal to Scale (#1927)

* rename Diagonal to Scale
* fix a test
* types etc
* spaces

1 parent 6405ab3 · commit 5f17f1c

File tree: 8 files changed (+88, -38 lines)

NEWS.md

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ been removed in favour of MLDatasets.jl.
 * Many utility functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874).
 * The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
 * Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights.
+* The `Flux.Diagonal` layer is now called `Scale`, and accepts an activation function.

 ## v0.12.10
 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
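To make the renamed layer concrete, here is a minimal sketch of the behaviour described in the NEWS entry above (the variable names and values are illustrative, not part of the commit):

```julia
using Flux

# Element-wise scale and bias, followed by an activation (the new optional argument):
layer = Flux.Scale(3, tanh)   # scale initialised by ones32(3), trainable bias by zeros32(3)
x = Float32[1, 2, 3]

layer(x) ≈ tanh.(ones(Float32, 3) .* x .+ zeros(Float32, 3))  # true
```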

docs/src/models/layers.md

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ Maxout
 SkipConnection
 Parallel
 Flux.Bilinear
-Flux.Diagonal
+Flux.Scale
 Flux.Embedding
 ```

src/deprecations.jl

Lines changed: 9 additions & 0 deletions

@@ -39,6 +39,15 @@ function Optimise.update!(x::AbstractArray, x̄)
   x .-= x̄
 end

+function Diagonal(size::Integer...; kw...)
+  Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal)
+  Scale(size...; kw...)
+end
+function Diagonal(size::Tuple; kw...)
+  Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal)
+  Scale(size...; kw...)
+end
+
 # Channel notation: Changed to match Conv, but very softly deprecated!
 # Perhaps change to @deprecate for v0.14, but there is no plan to remove these.
 Dense(in::Integer, out::Integer, σ = identity; kw...) =
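As a rough sketch of what the shim above does (a REPL-style example of mine, not part of the diff): the old constructor still works, but warns and builds a `Scale`.

```julia
using Flux

d = Flux.Diagonal(2)   # depwarn: "Flux.Diagonal is now Flux.Scale, ..." (shown with --depwarn=yes)

d isa Flux.Scale       # true — the old name now constructs the new layer
d([1, 2]) == [1, 2]    # true — identity activation, unit scale, zero bias, as before
```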

src/layers/basic.jl

Lines changed: 43 additions & 14 deletions

@@ -171,40 +171,69 @@ function Base.show(io::IO, l::Dense)
 end

 """
-    Diagonal(size::Integer...; σ = identity, bias=true, init=ones32)
-    Diagonal(scale::AbstractArray, [bias, activation])
+    Scale(size::Integer..., σ=identity; bias=true, init=ones32)
+    Scale(scale::AbstractArray, [bias, σ])

-Create an element-wise linear layer, which performs
+Create an element-wise layer, whose forward pass is given by:

     y = σ.(scale .* x .+ bias)

+This uses `.*` instead of matrix multiplication `*` of [`Dense`](@ref).
+
 The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
 with `init=ones32` by default. You may specify the function `init`,
 turn off trainable bias with `bias=false`, or provide the array(s) explicitly.

-Used by [`LayerNorm`](@ref).
+Used by [`LayerNorm`](@ref) with `affine=true`.
+
+# Examples
+```jldoctest
+julia> a = Flux.Scale(2)
+Scale(2)  # 4 parameters
+
+julia> Flux.params(a)
+Params([Float32[1.0, 1.0], Float32[0.0, 0.0]])
+
+julia> a([1 2 3])
+2×3 Matrix{Float32}:
+ 1.0  2.0  3.0
+ 1.0  2.0  3.0
+
+julia> b = Flux.Scale([1 2 3 4], false, abs2)
+Scale(1, 4, abs2; bias=false)  # 4 parameters
+
+julia> b([1, 10])
+2×4 Matrix{Int64}:
+   1    4    9    16
+ 100  400  900  1600
+
+julia> Flux.params(b)
+Params([[1 2 3 4]])
+```
 """
-struct Diagonal{A<:AbstractArray, B, F}
+struct Scale{F, A<:AbstractArray, B}
   scale::A
   bias::B
   σ::F
-  function Diagonal(W::M, bias = true, σ::F = identity) where {M<:AbstractArray, F}
-    b = create_bias(W, bias, size(W)...)
-    new{M, typeof(b), F}(W, b, σ)
+  function Scale(scale::A, bias::B = true, σ::F = identity) where {A<:AbstractArray, B<:Union{Bool, AbstractArray}, F}
+    b = create_bias(scale, bias, size(scale)...)
+    new{F, A, typeof(b)}(scale, b, σ)
   end
 end

-Diagonal(sz::Integer...; σ = identity, bias = true, init = ones32) = Diagonal(init(sz...), bias, σ)
+Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act)
+Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])

-@functor Diagonal
+@functor Scale

-function (a::Diagonal)(x::AbstractArray)
+function (a::Scale)(x::AbstractArray)
   σ = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
-  return σ === typeof(identity) ? a.scale .* x .+ a.bias : σ.(a.scale .* x .+ a.bias)
+  σ.(a.scale .* x .+ a.bias)
 end

-function Base.show(io::IO, l::Diagonal)
-  print(io, "Diagonal(", join(size(l.scale), ", "))
+function Base.show(io::IO, l::Scale)
+  print(io, "Scale(", join(size(l.scale), ", "))
+  l.σ == identity || print(io, ", ", l.σ)
   l.bias == false && print(io, "; bias=false")
   print(io, ")")
 end
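A small sketch of how the two new `Scale` constructors interact (the sizes, arrays, and activation below are my own choices): all-integer arguments hit the first method directly, while a trailing non-integer argument is peeled off by the second method and forwarded as the activation.

```julia
using Flux

a = Flux.Scale(2, 3, relu; bias = false)        # trailing positional activation
b = Flux.Scale(Flux.ones32(2, 3), false, relu)  # explicit scale array, no bias

size(a.scale) == size(b.scale) == (2, 3)  # true
a.σ === b.σ === relu                      # true
```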

src/layers/normalise.jl

Lines changed: 9 additions & 7 deletions

@@ -139,7 +139,7 @@ testmode!(m::AlphaDropout, mode=true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

 """
-    LayerNorm(sz, λ=identity; affine=true, ϵ=1fe-5)
+    LayerNorm(size..., λ=identity; affine=true, ϵ=1f-5)

 A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
 used with recurrent hidden states.

@@ -151,10 +151,10 @@ for tuple `sz`, along the first dimension for integer `sz`.
 The input is expected to have first dimensions' size equal to `sz`.

 If `affine=true` also applies a learnable shift and rescaling
-as in the [`Diagonal`](@ref) layer.
+using the [`Scale`](@ref) layer.

-Se also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`normalise`](@ref).
+See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`normalise`](@ref).
 """
 struct LayerNorm{F,D,T,N}
   λ::F

@@ -164,17 +164,19 @@ struct LayerNorm{F,D,T,N}
   affine::Bool
 end

-function LayerNorm(sz, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
-  diag = affine ? Diagonal(sz...; σ = λ) : Base.Fix1(broadcast, λ)
-  return LayerNorm(λ, diag, ϵ, Tuple(sz), affine)
+function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, ϵ::Real=1f-5)
+  diag = affine ? Scale(size..., λ) : λ!=identity ? Base.Fix1(broadcast, λ) : identity
+  return LayerNorm(λ, diag, ϵ, size, affine)
 end
+LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...)
+LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...)

 @functor LayerNorm

 (a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))

 function Base.show(io::IO, l::LayerNorm)
-  print(io, "LayerNorm($(l.size)")
+  print(io, "LayerNorm(", join(l.size, ", "))
   l.λ === identity || print(io, ", ", l.λ)
   hasaffine(l) || print(io, ", affine=false")
   print(io, ")")
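To make the new `LayerNorm` constructor methods concrete, a minimal sketch (my own example; the sizes are arbitrary):

```julia
using Flux

ln1 = LayerNorm(4)                      # integer size(s); affine shift/rescale via Flux.Scale
ln2 = LayerNorm(4, 3, relu)             # trailing activation, mirroring Scale's constructor
ln3 = LayerNorm((4, 3); affine = false) # tuple size; no trainable parameters

ln1.diag isa Flux.Scale  # true
ln3.diag === identity    # true, since λ == identity and affine == false
```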

src/layers/show.jl

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ _show_children(m::Maxout) = m.layers
 _show_children(p::Parallel) = (p.connection, p.layers...)

 for T in [
-  :Conv, :ConvTranspose, :CrossCor, :Dense, :Bilinear, :Embedding,
+  :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
   :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
 ]
   @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
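Adding `:Scale` to that list gives it the same expanded `MIME"text/plain"` printing as `Dense` and the other listed layers. Roughly, at the REPL (exact spacing and wording may differ between Flux versions):

```julia
julia> using Flux

julia> Flux.Scale(5, tanh)
Scale(5, tanh)  # 10 parameters
```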

test/layers/basic.jl

Lines changed: 23 additions & 14 deletions

@@ -87,20 +87,29 @@ import Flux: activations
   end
 end

-@testset "Diagonal" begin
-  @test length(Flux.Diagonal(10)(randn(10))) == 10
-  @test length(Flux.Diagonal(10)(randn(1))) == 10
-  @test length(Flux.Diagonal(10; bias = false)(randn(10))) == 10
-  @test_throws DimensionMismatch Flux.Diagonal(10)(randn(2))
-
-  @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2]
-  @test Flux.Diagonal(2)([1, 2]) == [1, 2]
-  @test Flux.Diagonal(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
-
-  @test Flux.Diagonal(2)(rand(2, 3, 4)) |> size == (2, 3, 4)
-  @test Flux.Diagonal(2, 3;)(rand(2, 3, 4)) |> size == (2, 3, 4)
-  @test Flux.Diagonal(2, 3, 4; bias = false)(rand(2, 3, 4)) |> size == (2, 3, 4)
-  @test Flux.Diagonal(2, 3; bias = false)(rand(2, 1, 4)) |> size == (2, 3, 4)
+@testset "Scale" begin
+  @test length(Flux.Scale(10)(randn(10))) == 10
+  @test length(Flux.Scale(10)(randn(1))) == 10
+  @test length(Flux.Scale(10; bias = false)(randn(10))) == 10
+  @test length(Flux.Scale(10, tanh)(randn(10))) == 10
+  @test_throws DimensionMismatch Flux.Scale(10)(randn(2))
+
+  @test Flux.Scale(2)([1 2]) == [1 2; 1 2]
+  @test Flux.Scale(2)([1, 2]) == [1, 2]
+  @test Flux.Scale(2; init = randn)([1, 2]) != [1, 2]
+  @test Flux.Scale(2; bias = false)([1 2; 3 4]) == [1 2; 3 4]
+  @test Flux.Scale(2, abs2; bias = false, init = ones)([1 2; 3 4]) == [1 4; 9 16]
+
+  @test Flux.Scale(2)(rand(2, 3, 4)) |> size == (2, 3, 4)
+  @test Flux.Scale(2, 3;)(rand(2, 3, 4)) |> size == (2, 3, 4)
+  @test Flux.Scale(2, 3, 4; bias = false)(rand(2, 3, 4)) |> size == (2, 3, 4)
+  @test Flux.Scale(2, 3; bias = false)(rand(2, 1, 4)) |> size == (2, 3, 4)
+  @test Flux.Scale(2, 3, tanh; bias = false, init = zeros)(rand(2, 1, 4)) == zeros(2, 3, 4)
+
+  @test_throws MethodError Flux.Scale(1.)
+  @test_throws MethodError Flux.Scale(1., 2.)
+  @test_throws Exception Flux.Scale()
+  @test_throws MethodError Flux.Scale(sin)
 end

 @testset "Maxout" begin

test/utils.jl

Lines changed: 1 addition & 1 deletion

@@ -457,7 +457,7 @@ end
             +),
           LayerNorm(8)))
   @test length(mod_skip) == 6
-  @test mod_skip[end] isa Flux.Diagonal
+  @test mod_skip[end] isa Flux.Scale
 end

 @testset "Patience triggers" begin
