@@ -43,16 +43,24 @@ The current defaults are:
43
43
- Julia version is < 1.7: `Random.GLOBAL_RNG`
44
44
- Julia version is >= 1.7: `Random.default_rng()`
45
45
"""
46
- _rng_from_array (:: AbstractArray ) = _rng_from_array ()
47
- _rng_from_array (:: CuArray ) = CUDA. default_rng ()
46
+ rng_from_array (:: AbstractArray ) = default_rng_value ()
47
+ rng_from_array (:: CuArray ) = CUDA. default_rng ()
48
+
48
49
if VERSION >= v " 1.7"
49
- _rng_from_array () = Random. default_rng ()
50
+ @doc """
51
+ default_rng_value()
52
+
53
+ Create an instance of the default RNG depending on Julia's version.
54
+ - Julia version is < 1.7: `Random.GLOBAL_RNG`
55
+ - Julia version is >= 1.7: `Random.default_rng()`
56
+ """
57
+ default_rng_value () = Random. default_rng ()
50
58
else
51
- _rng_from_array () = Random. GLOBAL_RNG
59
+ default_rng_value () = Random. GLOBAL_RNG
52
60
end
53
61
54
62
"""
55
- glorot_uniform([rng=GLOBAL_RNG ], size...; gain = 1) -> Array
63
+ glorot_uniform([rng = default_rng_value() ], size...; gain = 1) -> Array
56
64
glorot_uniform([rng]; kw...) -> Function
57
65
58
66
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform
@@ -91,13 +99,13 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
91
99
scale = Float32 (gain) * sqrt (24.0f0 / sum (nfan (dims... )))
92
100
(rand (rng, Float32, dims... ) .- 0.5f0 ) .* scale
93
101
end
94
- glorot_uniform (dims:: Integer... ; kw... ) = glorot_uniform (_rng_from_array (), dims... ; kw... )
95
- glorot_uniform (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_uniform (rng, dims... ; init_kwargs... , kwargs... )
102
+ glorot_uniform (dims:: Integer... ; kw... ) = glorot_uniform (default_rng_value (), dims... ; kw... )
103
+ glorot_uniform (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_uniform (rng, dims... ; init_kwargs... , kwargs... )
96
104
97
105
ChainRulesCore. @non_differentiable glorot_uniform (:: Any... )
98
106
99
107
"""
100
- glorot_normal([rng=GLOBAL_RNG] , size...; gain = 1) -> Array
108
+ glorot_normal([rng = default_rng_value() , size...; gain = 1) -> Array
101
109
glorot_normal([rng]; kw...) -> Function
102
110
103
111
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal
@@ -134,13 +142,13 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
134
142
std = Float32 (gain) * sqrt (2.0f0 / sum (nfan (dims... )))
135
143
randn (rng, Float32, dims... ) .* std
136
144
end
137
- glorot_normal (dims:: Integer... ; kwargs... ) = glorot_normal (_rng_from_array (), dims... ; kwargs... )
138
- glorot_normal (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_normal (rng, dims... ; init_kwargs... , kwargs... )
145
+ glorot_normal (dims:: Integer... ; kwargs... ) = glorot_normal (default_rng_value (), dims... ; kwargs... )
146
+ glorot_normal (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_normal (rng, dims... ; init_kwargs... , kwargs... )
139
147
140
148
ChainRulesCore. @non_differentiable glorot_normal (:: Any... )
141
149
142
150
"""
143
- kaiming_uniform([rng=GLOBAL_RNG ], size...; gain = √2) -> Array
151
+ kaiming_uniform([rng = default_rng_value() ], size...; gain = √2) -> Array
144
152
kaiming_uniform([rng]; kw...) -> Function
145
153
146
154
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform distribution
@@ -169,13 +177,13 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real = √2)
169
177
return (rand (rng, Float32, dims... ) .- 0.5f0 ) .* 2 bound
170
178
end
171
179
172
- kaiming_uniform (dims:: Integer... ; kwargs... ) = kaiming_uniform (_rng_from_array (), dims... ; kwargs... )
173
- kaiming_uniform (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> kaiming_uniform (rng, dims... ; init_kwargs... , kwargs... )
180
+ kaiming_uniform (dims:: Integer... ; kwargs... ) = kaiming_uniform (default_rng_value (), dims... ; kwargs... )
181
+ kaiming_uniform (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims... ; kwargs... ) -> kaiming_uniform (rng, dims... ; init_kwargs... , kwargs... )
174
182
175
183
ChainRulesCore. @non_differentiable kaiming_uniform (:: Any... )
176
184
177
185
"""
178
- kaiming_normal([rng=GLOBAL_RNG ], size...; gain = √2) -> Array
186
+ kaiming_normal([rng = default_rng_value() ], size...; gain = √2) -> Array
179
187
kaiming_normal([rng]; kw...) -> Function
180
188
181
189
Return an `Array{Float32}` of the given `size` containing random numbers taken from a normal
@@ -206,13 +214,13 @@ function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real = √2f0)
206
214
return randn (rng, Float32, dims... ) .* std
207
215
end
208
216
209
- kaiming_normal (dims:: Integer... ; kwargs... ) = kaiming_normal (_rng_from_array (), dims... ; kwargs... )
217
+ kaiming_normal (dims:: Integer... ; kwargs... ) = kaiming_normal (default_rng_value (), dims... ; kwargs... )
210
218
kaiming_normal (rng:: AbstractRNG ; init_kwargs... ) = (dims... ; kwargs... ) -> kaiming_normal (rng, dims... ; init_kwargs... , kwargs... )
211
219
212
220
ChainRulesCore. @non_differentiable kaiming_normal (:: Any... )
213
221
214
222
"""
215
- truncated_normal([rng=GLOBAL_RNG ], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
223
+ truncated_normal([rng = default_rng_value() ], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
216
224
truncated_normal([rng]; kw...) -> Function
217
225
218
226
Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution.
@@ -252,13 +260,13 @@ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean = 0, std = 1,
252
260
return xs
253
261
end
254
262
255
- truncated_normal (dims:: Integer... ; kwargs... ) = truncated_normal (_rng_from_array (), dims... ; kwargs... )
256
- truncated_normal (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> truncated_normal (rng, dims... ; init_kwargs... , kwargs... )
263
+ truncated_normal (dims:: Integer... ; kwargs... ) = truncated_normal (default_rng_value (), dims... ; kwargs... )
264
+ truncated_normal (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims... ; kwargs... ) -> truncated_normal (rng, dims... ; init_kwargs... , kwargs... )
257
265
258
266
ChainRulesCore. @non_differentiable truncated_normal (:: Any... )
259
267
260
268
"""
261
- orthogonal([rng=GLOBAL_RNG ], size...; gain = 1) -> Array
269
+ orthogonal([rng = default_rng_value() ], size...; gain = 1) -> Array
262
270
orthogonal([rng]; kw...) -> Function
263
271
264
272
Return an `Array{Float32}` of the given `size` which is a (semi) orthogonal matrix, as described in [1].
@@ -313,13 +321,13 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
313
321
return reshape (orthogonal (rng, rows, cols; kwargs... ), dims)
314
322
end
315
323
316
- orthogonal (dims:: Integer... ; kwargs... ) = orthogonal (_rng_from_array (), dims... ; kwargs... )
317
- orthogonal (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims:: Integer... ; kwargs... ) -> orthogonal (rng, dims... ; init_kwargs... , kwargs... )
324
+ orthogonal (dims:: Integer... ; kwargs... ) = orthogonal (default_rng_value (), dims... ; kwargs... )
325
+ orthogonal (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims:: Integer... ; kwargs... ) -> orthogonal (rng, dims... ; init_kwargs... , kwargs... )
318
326
319
327
ChainRulesCore. @non_differentiable orthogonal (:: Any... )
320
328
321
329
"""
322
- sparse_init([rng=GLOBAL_RNG ], rows, cols; sparsity, std = 0.01) -> Array
330
+ sparse_init([rng = default_rng_value() ], rows, cols; sparsity, std = 0.01) -> Array
323
331
sparse_init([rng]; kw...) -> Function
324
332
325
333
Return a `Matrix{Float32}` of size `rows, cols` where each column contains a fixed fraction of
@@ -361,8 +369,8 @@ function sparse_init(rng::AbstractRNG, dims::Integer...; sparsity, std = 0.01)
361
369
return mapslices (shuffle, sparse_array, dims= 1 )
362
370
end
363
371
364
- sparse_init (dims:: Integer... ; kwargs... ) = sparse_init (_rng_from_array (), dims... ; kwargs... )
365
- sparse_init (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> sparse_init (rng, dims... ; init_kwargs... , kwargs... )
372
+ sparse_init (dims:: Integer... ; kwargs... ) = sparse_init (default_rng_value (), dims... ; kwargs... )
373
+ sparse_init (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (dims... ; kwargs... ) -> sparse_init (rng, dims... ; init_kwargs... , kwargs... )
366
374
367
375
ChainRulesCore. @non_differentiable sparse_init (:: Any... )
368
376
452
460
453
461
# For consistency, it accepts an RNG, but ignores it:
454
462
identity_init (:: AbstractRNG , dims:: Integer... ; kwargs... ) = identity_init (dims... ; kwargs... )
455
- identity_init (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (args... ;kwargs... ) -> identity_init (rng, args... ; init_kwargs... , kwargs... )
463
+ identity_init (rng:: AbstractRNG = default_rng_value (); init_kwargs... ) = (args... ;kwargs... ) -> identity_init (rng, args... ; init_kwargs... , kwargs... )
456
464
457
465
ChainRulesCore. @non_differentiable identity_init (:: Any... )
458
466
0 commit comments