@@ -81,7 +81,7 @@ julia> Flux.glorot_uniform(2, 3)
81
81
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
82
82
"""
83
83
# Core method: draw Float32 uniform samples from `rng`, centre them on zero and
# scale them by a factor derived from the sum of the fans reported by `nfan`.
function glorot_uniform(rng::AbstractRNG, dims...)
    samples = rand(rng, Float32, dims...)
    scale = sqrt(24.0f0 / sum(nfan(dims...)))
    return @. (samples - 0.5f0) * scale
end
84
- glorot_uniform (dims... ) = glorot_uniform (Random . GLOBAL_RNG , dims... )
84
# No RNG given: use whatever `rng_from_array()` returns.
function glorot_uniform(dims...)
    return glorot_uniform(rng_from_array(), dims...)
end

# RNG only: return an initializer closure that reuses `rng` on every call.
function glorot_uniform(rng::AbstractRNG)
    return (dims...) -> glorot_uniform(rng, dims...)
end
86
86
87
87
"""
@@ -114,7 +114,7 @@ julia> Flux.glorot_normal(3, 2)
114
114
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
115
115
"""
116
116
# Core method: draw Float32 standard-normal samples from `rng` and scale them by
# a factor derived from the sum of the fans reported by `nfan`.
function glorot_normal(rng::AbstractRNG, dims...)
    draws = randn(rng, Float32, dims...)
    spread = sqrt(2.0f0 / sum(nfan(dims...)))
    return draws .* spread
end
117
- glorot_normal (dims... ) = glorot_normal (Random . GLOBAL_RNG , dims... )
117
# No RNG given: use whatever `rng_from_array()` returns.
function glorot_normal(dims...)
    return glorot_normal(rng_from_array(), dims...)
end

# RNG only: return an initializer closure that reuses `rng` on every call.
function glorot_normal(rng::AbstractRNG)
    return (dims...) -> glorot_normal(rng, dims...)
end
119
119
120
120
"""
@@ -151,7 +151,7 @@ function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2)
151
151
return (rand (rng, Float32, dims... ) .- 0.5f0 ) .* 2 bound
152
152
end
153
153
154
- kaiming_uniform (dims... ; kwargs... ) = kaiming_uniform (Random . GLOBAL_RNG , dims... ; kwargs... )
154
# No RNG given: use whatever `rng_from_array()` returns and forward all keywords.
function kaiming_uniform(dims...; kwargs...)
    return kaiming_uniform(rng_from_array(), dims...; kwargs...)
end

# RNG (and optional keywords) only: return an initializer closure. Keywords given
# when the closure is called are splatted last, so they override `init_kwargs`.
function kaiming_uniform(rng::AbstractRNG; init_kwargs...)
    return (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)
end
156
156
157
157
"""
@@ -188,9 +188,58 @@ function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0)
188
188
return randn (rng, Float32, dims... ) .* std
189
189
end
190
190
191
- kaiming_normal (dims... ; kwargs... ) = kaiming_normal (Random . GLOBAL_RNG , dims... ; kwargs... )
191
# No RNG given: use whatever `rng_from_array()` returns and forward all keywords.
function kaiming_normal(dims...; kwargs...)
    return kaiming_normal(rng_from_array(), dims...; kwargs...)
end

# RNG (and optional keywords) only: return an initializer closure. Keywords given
# when the closure is called are splatted last, so they override `init_kwargs`.
function kaiming_normal(rng::AbstractRNG; init_kwargs...)
    return (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)
end
193
193
194
+ """
195
+ truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2, hi = 2)
196
+
197
+ Return an `Array{Float32}` of size `dims` where each element is drawn from a truncated normal distribution.
198
+ The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(dims...))`.
199
+
200
+ The values are generated by sampling a Uniform(0, 1) (`rand()`) and then
201
+ applying the inverse CDF of the truncated normal distribution
202
+ (see the references for more info).
203
+ This method works best when `lo ≤ mean ≤ hi`.
204
+
205
+ # Examples
206
+ ```jldoctest
207
+ julia> using Statistics
208
+
209
+ julia> Flux.truncated_normal(3, 4) |> summary
210
+ "3×4 Matrix{Float32}"
211
+
212
+ julia> round.(extrema(Flux.truncated_normal(10^6)); digits=3)
213
+ (-2.0f0, 2.0f0)
214
+
215
+ julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
216
+ 1.0f0
217
+ ```
218
+
219
+ # References
220
+ [1] Burkardt, John. "The Truncated Normal Distribution"
221
+ [PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf).
222
+ Department of Scientific Computing website.
223
+ """
224
function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2)
    # Inverted bounds cannot describe a truncation interval; fail loudly instead
    # of returning values that were clamped against a nonsensical range.
    lo <= hi || throw(ArgumentError("truncated_normal requires lo <= hi, got lo = $lo and hi = $hi"))

    # Only warn (once, via maxlog) when the mean lies far outside the interval.
    if (mean < lo - 2 * std) || (mean > hi + 2 * std)
        @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1
    end

    # Standard normal CDF evaluated at the standardised lower and upper bounds.
    norm_cdf(x) = 0.5 * (1 + erf(x / √2))
    l = norm_cdf((lo - mean) / std)
    u = norm_cdf((hi - mean) / std)

    # Coefficients of the affine map applied to every uniform sample; they do
    # not depend on the sample, so compute them once outside the broadcast.
    scale = 2 * (u - l)
    shift = 2 * l - 1

    xs = rand(rng, Float32, dims...)
    # Transform in place: uniform sample -> erfinv argument -> scaled and shifted
    # normal value, finally clamped so every element stays within [lo, hi].
    broadcast!(xs, xs) do x
        z = erfinv(x * scale + shift)
        return clamp(z * std * √2 + mean, lo, hi)
    end
    return xs
end
239
+
240
# No RNG given: use whatever `rng_from_array()` returns and forward all keywords.
function truncated_normal(dims...; kwargs...)
    return truncated_normal(rng_from_array(), dims...; kwargs...)
end

# RNG (and optional keywords) only: return an initializer closure. Keywords given
# when the closure is called are splatted last, so they override `init_kwargs`.
function truncated_normal(rng::AbstractRNG; init_kwargs...)
    return (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)
end
242
+
194
243
"""
195
244
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)
196
245
232
281
* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)
233
282
234
283
# References
284
+
235
285
[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
236
286
237
287
"""
@@ -254,7 +304,7 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
254
304
return reshape (orthogonal (rng, rows, cols; kwargs... ), dims)
255
305
end
256
306
257
- orthogonal (dims:: Integer... ; kwargs... ) = orthogonal (Random . GLOBAL_RNG , dims... ; kwargs... )
307
# No RNG given: use whatever `rng_from_array()` returns and forward all keywords.
function orthogonal(dims::Integer...; kwargs...)
    return orthogonal(rng_from_array(), dims...; kwargs...)
end

# RNG (and optional keywords) only: return an initializer closure. Keywords given
# when the closure is called are splatted last, so they override `init_kwargs`.
function orthogonal(rng::AbstractRNG; init_kwargs...)
    return (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)
end
259
309
260
310
"""
@@ -298,7 +348,7 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
298
348
return mapslices (shuffle, sparse_array, dims= 1 )
299
349
end
300
350
301
- sparse_init (dims... ; kwargs... ) = sparse_init (Random . GLOBAL_RNG , dims... ; kwargs... )
351
# No RNG given: use whatever `rng_from_array()` returns and forward all keywords.
function sparse_init(dims...; kwargs...)
    return sparse_init(rng_from_array(), dims...; kwargs...)
end

# RNG (and optional keywords) only: return an initializer closure. Keywords given
# when the closure is called are splatted last, so they override `init_kwargs`.
function sparse_init(rng::AbstractRNG; init_kwargs...)
    return (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)
end
303
353
304
354
"""
@@ -382,7 +432,7 @@ function identity_init(dims...; gain=1, shift=0)
382
432
end
383
433
384
434
# The RNG argument is accepted but unused; the call is forwarded without it.
function identity_init(::AbstractRNG, dims...; kwargs...)
    return identity_init(dims...; kwargs...)
end
385
- identity_init (; init_kwargs... ) = identity_init (Random . GLOBAL_RNG ; init_kwargs... )
435
# Keywords only: build the initializer closure using whatever `rng_from_array()` returns.
function identity_init(; init_kwargs...)
    return identity_init(rng_from_array(); init_kwargs...)
end

# RNG (and optional keywords) only: return an initializer closure. Keywords given
# when the closure is called are splatted last, so they override `init_kwargs`.
function identity_init(rng::AbstractRNG; init_kwargs...)
    return (args...; kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)
end
387
437
388
438
# Float32 counterpart of `Base.ones`: an array of ones with the given dimensions.
function ones32(dims...)
    return Base.ones(Float32, dims...)
end
437
487
438
488
Flatten a model's parameters into a single weight vector.
439
489
440
- julia> m = Chain(Dense(10, 5, σ ), Dense(5, 2), softmax)
441
- Chain(Dense(10, 5, σ ), Dense(5, 2), softmax)
490
+ julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
491
+ Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
442
492
443
493
julia> θ, re = destructure(m);
444
494
0 commit comments