Skip to content

Commit 9f56f51

Browse files
Support for lecun normal weight initialization (#2311)
* Support for lecun normal weight initialization
* add test
* add to docs
* Update utils.jl
* fixup
* fix rtol
* doctests

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
1 parent 32db5d4 commit 9f56f51

File tree

4 files changed

+51
-2
lines changed

4 files changed

+51
-2
lines changed

docs/src/reference/utilities.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Flux.glorot_normal
3535
Flux.kaiming_uniform
3636
Flux.kaiming_normal
3737
Flux.truncated_normal
38+
Flux.lecun_normal
3839
Flux.orthogonal
3940
Flux.sparse_init
4041
Flux.identity_init

src/Flux.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ export MPIBackend, NCCLBackend, DistributedUtils
119119
kaiming_uniform,
120120
kaiming_normal,
121121
truncated_normal,
122+
lecun_normal,
122123
orthogonal,
123124
sparse_init,
124125
identity_init,

src/utils.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,48 @@ truncated_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwa
248248

249249
ChainRulesCore.@non_differentiable truncated_normal(::Any...)
250250

251+
"""
    lecun_normal([rng], size...) -> Array
    lecun_normal([rng]; kw...) -> Function

Return an `Array{Float32}` of the given `size` containing random numbers drawn from a truncated normal
distribution centered on 0 with stddev `sqrt(1 / fan_in)`, where `fan_in` is the number of input units
in the weight tensor.

# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
julia> using Statistics

julia> round(std(Flux.lecun_normal(10, 1000)), digits=3)
0.032f0

julia> round(std(Flux.lecun_normal(1000, 10)), digits=3)
0.32f0

julia> round(std(Flux.lecun_normal(1000, 1000)), digits=3)
0.032f0

julia> Dense(10 => 1000, selu; init = Flux.lecun_normal())
Dense(10 => 1000, selu)  # 11_000 parameters

julia> round(std(ans.weight), digits=3)
0.313f0
```

# References

[1] Lecun, Yann, et al. "Efficient backprop." Neural networks: Tricks of the trade. Springer, Berlin, Heidelberg, 2012. 9-48.
"""
function lecun_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
    # stddev derived from the fan_in value, as in Lecun et al. [1]
    std = Float32(gain) * sqrt(1.0f0 / first(nfan(dims...)))
    return truncated_normal(rng, dims...; mean=0, std=std)
end
287+
288+
# Convenience method: draw from the default RNG when none is supplied.
lecun_normal(dims::Integer...; kwargs...) = lecun_normal(default_rng(), dims...; kwargs...)

# Partial application: `lecun_normal()` returns an initializer closure,
# suitable for use as e.g. `Dense(10 => 1000; init = Flux.lecun_normal())`.
lecun_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> lecun_normal(rng, dims...; init_kwargs..., kwargs...)

# Weight initialization is not part of the differentiable computation.
ChainRulesCore.@non_differentiable lecun_normal(::Any...)
292+
251293
"""
252294
orthogonal([rng], size...; gain = 1) -> Array
253295
orthogonal([rng]; kw...) -> Function

test/utils.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Flux
22
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
3-
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
3+
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal, lecun_normal,
44
sparse_init, identity_init, unstack, batch, unbatch,
55
unsqueeze, params, loadmodel!
66
using MLUtils
@@ -75,7 +75,7 @@ end
7575
kaiming_uniform, kaiming_normal,
7676
orthogonal,
7777
sparse_init,
78-
truncated_normal,
78+
truncated_normal, lecun_normal,
7979
identity_init,
8080
Flux.rand32,
8181
Flux.randn32,
@@ -192,6 +192,11 @@ end
192192
end
193193
end
194194

195+
@testset "lecun_normal" begin
  # Expected stddev is sqrt(1/fan_in); rtol is loose because std is
  # estimated from a finite random sample.
  @test std(Flux.lecun_normal(10, 1000)) ≈ 0.032f0 rtol=0.1
  @test std(Flux.lecun_normal(1000, 10)) ≈ 0.317f0 rtol=0.1
end
199+
195200
@testset "Partial application" begin
196201
partial_ku = kaiming_uniform(gain=1e9)
197202
@test maximum(partial_ku(8, 8)) > 1e9 / 2

0 commit comments

Comments
 (0)