@@ -34,7 +34,7 @@ ofeltype(x, y) = convert(float(eltype(x)), y)
34
34
epseltype (x) = eps (float (eltype (x)))
35
35
36
36
"""
37
- rng_from_array ([x])
37
+ _rng_from_array ([x])
38
38
39
39
Create an instance of the RNG most appropriate for `x`.
40
40
The current defaults are:
@@ -43,12 +43,12 @@ The current defaults are:
43
43
- Julia version is < 1.7: `Random.GLOBAL_RNG`
44
44
- Julia version is >= 1.7: `Random.default_rng()`
45
45
"""
46
- rng_from_array (:: AbstractArray ) = rng_from_array ()
47
- rng_from_array (:: CuArray ) = CUDA. default_rng ()
46
+ _rng_from_array (:: AbstractArray ) = _rng_from_array ()
47
+ _rng_from_array (:: CuArray ) = CUDA. default_rng ()
48
48
if VERSION >= v " 1.7"
49
- rng_from_array () = Random. default_rng ()
49
+ _rng_from_array () = Random. default_rng ()
50
50
else
51
- rng_from_array () = Random. GLOBAL_RNG
51
+ _rng_from_array () = Random. GLOBAL_RNG
52
52
end
53
53
54
54
"""
@@ -91,8 +91,8 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
91
91
scale = Float32 (gain) * sqrt (24.0f0 / sum (nfan (dims... )))
92
92
(rand (rng, Float32, dims... ) .- 0.5f0 ) .* scale
93
93
end
94
- glorot_uniform (dims:: Integer... ; kw... ) = glorot_uniform (rng_from_array (), dims... ; kw... )
95
- glorot_uniform (rng:: AbstractRNG = rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_uniform (rng, dims... ; init_kwargs... , kwargs... )
94
+ glorot_uniform (dims:: Integer... ; kw... ) = glorot_uniform (_rng_from_array (), dims... ; kw... )
95
+ glorot_uniform (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_uniform (rng, dims... ; init_kwargs... , kwargs... )
96
96
97
97
ChainRulesCore. @non_differentiable glorot_uniform (:: Any... )
98
98
@@ -134,8 +134,8 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
134
134
std = Float32 (gain) * sqrt (2.0f0 / sum (nfan (dims... )))
135
135
randn (rng, Float32, dims... ) .* std
136
136
end
137
- glorot_normal (dims:: Integer... ; kwargs... ) = glorot_normal (rng_from_array (), dims... ; kwargs... )
138
- glorot_normal (rng:: AbstractRNG = rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_normal (rng, dims... ; init_kwargs... , kwargs... )
137
+ glorot_normal (dims:: Integer... ; kwargs... ) = glorot_normal (_rng_from_array (), dims... ; kwargs... )
138
+ glorot_normal (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> glorot_normal (rng, dims... ; init_kwargs... , kwargs... )
139
139
140
140
ChainRulesCore. @non_differentiable glorot_normal (:: Any... )
141
141
@@ -169,8 +169,8 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real = √2)
169
169
return (rand (rng, Float32, dims... ) .- 0.5f0 ) .* 2 bound
170
170
end
171
171
172
- kaiming_uniform (dims:: Integer... ; kwargs... ) = kaiming_uniform (rng_from_array (), dims... ; kwargs... )
173
- kaiming_uniform (rng:: AbstractRNG = rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> kaiming_uniform (rng, dims... ; init_kwargs... , kwargs... )
172
+ kaiming_uniform (dims:: Integer... ; kwargs... ) = kaiming_uniform (_rng_from_array (), dims... ; kwargs... )
173
+ kaiming_uniform (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> kaiming_uniform (rng, dims... ; init_kwargs... , kwargs... )
174
174
175
175
ChainRulesCore. @non_differentiable kaiming_uniform (:: Any... )
176
176
@@ -206,7 +206,7 @@ function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real = √2f0)
206
206
return randn (rng, Float32, dims... ) .* std
207
207
end
208
208
209
- kaiming_normal (dims:: Integer... ; kwargs... ) = kaiming_normal (rng_from_array (), dims... ; kwargs... )
209
+ kaiming_normal (dims:: Integer... ; kwargs... ) = kaiming_normal (_rng_from_array (), dims... ; kwargs... )
210
210
kaiming_normal (rng:: AbstractRNG ; init_kwargs... ) = (dims... ; kwargs... ) -> kaiming_normal (rng, dims... ; init_kwargs... , kwargs... )
211
211
212
212
ChainRulesCore. @non_differentiable kaiming_normal (:: Any... )
@@ -252,8 +252,8 @@ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean = 0, std = 1,
252
252
return xs
253
253
end
254
254
255
- truncated_normal (dims:: Integer... ; kwargs... ) = truncated_normal (rng_from_array (), dims... ; kwargs... )
256
- truncated_normal (rng:: AbstractRNG = rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> truncated_normal (rng, dims... ; init_kwargs... , kwargs... )
255
+ truncated_normal (dims:: Integer... ; kwargs... ) = truncated_normal (_rng_from_array (), dims... ; kwargs... )
256
+ truncated_normal (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> truncated_normal (rng, dims... ; init_kwargs... , kwargs... )
257
257
258
258
ChainRulesCore. @non_differentiable truncated_normal (:: Any... )
259
259
@@ -313,8 +313,8 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
313
313
return reshape (orthogonal (rng, rows, cols; kwargs... ), dims)
314
314
end
315
315
316
- orthogonal (dims:: Integer... ; kwargs... ) = orthogonal (rng_from_array (), dims... ; kwargs... )
317
- orthogonal (rng:: AbstractRNG = rng_from_array (); init_kwargs... ) = (dims:: Integer... ; kwargs... ) -> orthogonal (rng, dims... ; init_kwargs... , kwargs... )
316
+ orthogonal (dims:: Integer... ; kwargs... ) = orthogonal (_rng_from_array (), dims... ; kwargs... )
317
+ orthogonal (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims:: Integer... ; kwargs... ) -> orthogonal (rng, dims... ; init_kwargs... , kwargs... )
318
318
319
319
ChainRulesCore. @non_differentiable orthogonal (:: Any... )
320
320
@@ -361,8 +361,8 @@ function sparse_init(rng::AbstractRNG, dims::Integer...; sparsity, std = 0.01)
361
361
return mapslices (shuffle, sparse_array, dims= 1 )
362
362
end
363
363
364
- sparse_init (dims:: Integer... ; kwargs... ) = sparse_init (rng_from_array (), dims... ; kwargs... )
365
- sparse_init (rng:: AbstractRNG = rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> sparse_init (rng, dims... ; init_kwargs... , kwargs... )
364
+ sparse_init (dims:: Integer... ; kwargs... ) = sparse_init (_rng_from_array (), dims... ; kwargs... )
365
+ sparse_init (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (dims... ; kwargs... ) -> sparse_init (rng, dims... ; init_kwargs... , kwargs... )
366
366
367
367
ChainRulesCore. @non_differentiable sparse_init (:: Any... )
368
368
452
452
453
453
# For consistency, it accepts an RNG, but ignores it:
454
454
identity_init (:: AbstractRNG , dims:: Integer... ; kwargs... ) = identity_init (dims... ; kwargs... )
455
- identity_init (rng:: AbstractRNG = rng_from_array (); init_kwargs... ) = (args... ;kwargs... ) -> identity_init (rng, args... ; init_kwargs... , kwargs... )
455
+ identity_init (rng:: AbstractRNG = _rng_from_array (); init_kwargs... ) = (args... ;kwargs... ) -> identity_init (rng, args... ; init_kwargs... , kwargs... )
456
456
457
457
ChainRulesCore. @non_differentiable identity_init (:: Any... )
458
458
@@ -461,33 +461,40 @@ zeros32(dims::Integer...) = Base.zeros(Float32, dims...)
461
461
462
462
"""
463
463
ones32(size...) = ones(Float32, size...)
464
- zeros32(size...) = zeros(Float32, size...)
465
464
466
- Return an `Array{Float32}` of the given `size`.
465
+ Return an `Array{Float32}` of the given `size` filled with 1s .
467
466
"""
468
467
ones32 (dims... ) = Base. ones (Float32, dims... )
469
468
470
- @doc @doc (ones32)
469
+ """
470
+ zeros32(size...) = zeros(Float32, size...)
471
+
472
+ Return an `Array{Float32}` of the given `size` filled with 0s.
473
+ """
471
474
zeros32 (dims... ) = Base. zeros (Float32, dims... )
472
475
473
476
"""
474
477
rand32([rng], size...)
475
- randn32([rng], size...)
476
478
477
- Return an `Array{Float32}` of the given `size`, filled like `rand` or `randn` .
479
+ Return an `Array{Float32}` of the given `size`, filled like `rand`.
478
480
When the size is not provided, `rand32(rng::AbstractRNG)` returns a function.
479
481
"""
480
482
rand32 (dims:: Integer... ) = Base. rand (Float32, dims... )
481
483
rand32 (rng:: AbstractRNG , dims:: Integer... ) = Base. rand (rng, Float32, dims... )
482
484
rand32 (rng:: AbstractRNG ) = (dims... ,) -> Base. rand (rng, Float32, dims... )
483
485
484
- @doc @doc (rand32)
486
+ """
487
+ randn32([rng], size...)
488
+
489
+ Return an `Array{Float32}` of the given `size`, filled like `randn`.
490
+ When the size is not provided, `randn32(rng::AbstractRNG)` returns a function.
491
+ """
485
492
randn32 (dims:: Integer... ) = Base. randn (Float32, dims... )
486
493
randn32 (rng:: AbstractRNG , dims:: Integer... ) = Base. randn (rng, Float32, dims... )
487
494
randn32 (rng:: AbstractRNG ) = (dims... ,) -> Base. randn (rng, Float32, dims... )
488
495
489
496
"""
490
- create_bias (weights, bias, size...)
497
+ _create_bias (weights, bias, size...)
491
498
492
499
Return a bias parameter for a layer, based on the value given
493
500
to the constructor's keyword `bias=bias`.
@@ -497,10 +504,10 @@ to the constructor's keyword `bias=bias`.
497
504
* `bias::AbstractArray` uses the array provided, provided it has the correct size.
498
505
It does not at present correct the `eltype` to match that of `weights`.
499
506
"""
500
- function create_bias (weights:: AbstractArray , bias:: Bool , dims:: Integer... )
507
+ function _create_bias (weights:: AbstractArray , bias:: Bool , dims:: Integer... )
501
508
bias ? fill! (similar (weights, dims... ), 0 ) : false
502
509
end
503
- function create_bias (weights:: AbstractArray , bias:: AbstractArray , dims:: Integer... )
510
+ function _create_bias (weights:: AbstractArray , bias:: AbstractArray , dims:: Integer... )
504
511
size (bias) == dims || throw (DimensionMismatch (" expected bias of size $(dims) , got size $(size (bias)) " ))
505
512
bias
506
513
end
@@ -518,6 +525,34 @@ Normally, the throttled function will run as much as it can, without ever
518
525
going more than once per `wait` duration; but if you'd like to disable the
519
526
execution on the leading edge, pass `leading=false`. To enable execution on
520
527
the trailing edge, pass `trailing=true`.
528
+
529
+ # Examples
530
+ ```jldoctest
531
+ julia> a = Flux.throttle(() -> println("Flux"), 2);
532
+
533
+ julia> a()
534
+ Flux
535
+
536
+ julia> a()
537
+ Flux
538
+
539
+ julia> for i = 1:4 # sleeps for 1 second -> the function can be called in alternate iterations
540
+ a()
541
+ sleep(1)
542
+ end
543
+ Flux
544
+ Flux
545
+
546
+ julia> for i = 1:4 # sleeps for 2 second -> the function can be called in the next iteration
547
+ a()
548
+ sleep(2)
549
+ end
550
+ Flux
551
+ Flux
552
+ Flux
553
+ Flux
554
+
555
+ ```
521
556
"""
522
557
function throttle (f, timeout; leading= true , trailing= false )
523
558
cooldown = true
0 commit comments