Skip to content

Commit 4773a69

Browse files
committed
rng_from_array() -> default_rng_value()
1 parent 9dec787 commit 4773a69

File tree

3 files changed

+41
-31
lines changed

3 files changed

+41
-31
lines changed

docs/src/utilities.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ Flux.ones32
4545
Flux.zeros32
4646
Flux.rand32
4747
Flux.randn32
48+
Flux.rng_from_array
49+
Flux.default_rng_value
4850
```
4951

5052
## Changing the type of model parameters

src/layers/normalise.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ end
5151
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
5252

5353
"""
54-
Dropout(p; dims=:, rng = _rng_from_array())
54+
Dropout(p; dims=:, rng = default_rng_value())
5555
5656
Dropout layer.
5757
@@ -96,9 +96,9 @@ mutable struct Dropout{F,D,R<:AbstractRNG}
9696
active::Union{Bool, Nothing}
9797
rng::R
9898
end
99-
Dropout(p, dims, active) = Dropout(p, dims, active, _rng_from_array())
99+
Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value())
100100

101-
function Dropout(p; dims=:, rng = _rng_from_array())
101+
function Dropout(p; dims=:, rng = default_rng_value())
102102
@assert 0 ≤ p ≤ 1
103103
Dropout(p, dims, nothing, rng)
104104
end
@@ -121,7 +121,7 @@ function Base.show(io::IO, d::Dropout)
121121
end
122122

123123
"""
124-
AlphaDropout(p; rng = _rng_from_array())
124+
AlphaDropout(p; rng = default_rng_value())
125125
126126
A dropout layer. Used in
127127
[Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
@@ -155,8 +155,8 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
155155
new{typeof(p), typeof(rng)}(p, active, rng)
156156
end
157157
end
158-
AlphaDropout(p, active) = AlphaDropout(p, active, _rng_from_array())
159-
AlphaDropout(p; rng = _rng_from_array()) = AlphaDropout(p, nothing, rng)
158+
AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
159+
AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)
160160

161161
@functor AlphaDropout
162162
trainable(a::AlphaDropout) = (;)

src/utils.jl

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,24 @@ The current defaults are:
4343
- Julia version is < 1.7: `Random.GLOBAL_RNG`
4444
- Julia version is >= 1.7: `Random.default_rng()`
4545
"""
46-
_rng_from_array(::AbstractArray) = _rng_from_array()
47-
_rng_from_array(::CuArray) = CUDA.default_rng()
46+
rng_from_array(::AbstractArray) = default_rng_value()
47+
rng_from_array(::CuArray) = CUDA.default_rng()
48+
4849
if VERSION >= v"1.7"
49-
_rng_from_array() = Random.default_rng()
50+
@doc """
51+
default_rng_value()
52+
53+
Create an instance of the default RNG depending on Julia's version.
54+
- Julia version is < 1.7: `Random.GLOBAL_RNG`
55+
- Julia version is >= 1.7: `Random.default_rng()`
56+
"""
57+
default_rng_value() = Random.default_rng()
5058
else
51-
_rng_from_array() = Random.GLOBAL_RNG
59+
default_rng_value() = Random.GLOBAL_RNG
5260
end
5361

5462
"""
55-
glorot_uniform([rng=GLOBAL_RNG], size...; gain = 1) -> Array
63+
glorot_uniform([rng = default_rng_value()], size...; gain = 1) -> Array
5664
glorot_uniform([rng]; kw...) -> Function
5765
5866
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform
@@ -91,13 +99,13 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
9199
scale = Float32(gain) * sqrt(24.0f0 / sum(nfan(dims...)))
92100
(rand(rng, Float32, dims...) .- 0.5f0) .* scale
93101
end
94-
glorot_uniform(dims::Integer...; kw...) = glorot_uniform(_rng_from_array(), dims...; kw...)
95-
glorot_uniform(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...)
102+
glorot_uniform(dims::Integer...; kw...) = glorot_uniform(default_rng_value(), dims...; kw...)
103+
glorot_uniform(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...)
96104

97105
ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
98106

99107
"""
100-
glorot_normal([rng=GLOBAL_RNG], size...; gain = 1) -> Array
108+
glorot_normal([rng = default_rng_value()], size...; gain = 1) -> Array
101109
glorot_normal([rng]; kw...) -> Function
102110
103111
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal
@@ -134,13 +142,13 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
134142
std = Float32(gain) * sqrt(2.0f0 / sum(nfan(dims...)))
135143
randn(rng, Float32, dims...) .* std
136144
end
137-
glorot_normal(dims::Integer...; kwargs...) = glorot_normal(_rng_from_array(), dims...; kwargs...)
138-
glorot_normal(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...)
145+
glorot_normal(dims::Integer...; kwargs...) = glorot_normal(default_rng_value(), dims...; kwargs...)
146+
glorot_normal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...)
139147

140148
ChainRulesCore.@non_differentiable glorot_normal(::Any...)
141149

142150
"""
143-
kaiming_uniform([rng=GLOBAL_RNG], size...; gain = √2) -> Array
151+
kaiming_uniform([rng = default_rng_value()], size...; gain = √2) -> Array
144152
kaiming_uniform([rng]; kw...) -> Function
145153
146154
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform distribution
@@ -169,13 +177,13 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real = √2)
169177
return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound
170178
end
171179

172-
kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(_rng_from_array(), dims...; kwargs...)
173-
kaiming_uniform(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)
180+
kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(default_rng_value(), dims...; kwargs...)
181+
kaiming_uniform(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)
174182

175183
ChainRulesCore.@non_differentiable kaiming_uniform(::Any...)
176184

177185
"""
178-
kaiming_normal([rng=GLOBAL_RNG], size...; gain = √2) -> Array
186+
kaiming_normal([rng = default_rng_value()], size...; gain = √2) -> Array
179187
kaiming_normal([rng]; kw...) -> Function
180188
181189
Return an `Array{Float32}` of the given `size` containing random numbers taken from a normal
@@ -206,13 +214,13 @@ function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real = √2f0)
206214
return randn(rng, Float32, dims...) .* std
207215
end
208216

209-
kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(_rng_from_array(), dims...; kwargs...)
217+
kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(default_rng_value(), dims...; kwargs...)
210218
kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)
211219

212220
ChainRulesCore.@non_differentiable kaiming_normal(::Any...)
213221

214222
"""
215-
truncated_normal([rng=GLOBAL_RNG], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
223+
truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
216224
truncated_normal([rng]; kw...) -> Function
217225
218226
Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution.
@@ -252,13 +260,13 @@ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean = 0, std = 1,
252260
return xs
253261
end
254262

255-
truncated_normal(dims::Integer...; kwargs...) = truncated_normal(_rng_from_array(), dims...; kwargs...)
256-
truncated_normal(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)
263+
truncated_normal(dims::Integer...; kwargs...) = truncated_normal(default_rng_value(), dims...; kwargs...)
264+
truncated_normal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)
257265

258266
ChainRulesCore.@non_differentiable truncated_normal(::Any...)
259267

260268
"""
261-
orthogonal([rng=GLOBAL_RNG], size...; gain = 1) -> Array
269+
orthogonal([rng = default_rng_value()], size...; gain = 1) -> Array
262270
orthogonal([rng]; kw...) -> Function
263271
264272
Return an `Array{Float32}` of the given `size` which is a (semi) orthogonal matrix, as described in [1].
@@ -313,13 +321,13 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
313321
return reshape(orthogonal(rng, rows, cols; kwargs...), dims)
314322
end
315323

316-
orthogonal(dims::Integer...; kwargs...) = orthogonal(_rng_from_array(), dims...; kwargs...)
317-
orthogonal(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)
324+
orthogonal(dims::Integer...; kwargs...) = orthogonal(default_rng_value(), dims...; kwargs...)
325+
orthogonal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)
318326

319327
ChainRulesCore.@non_differentiable orthogonal(::Any...)
320328

321329
"""
322-
sparse_init([rng=GLOBAL_RNG], rows, cols; sparsity, std = 0.01) -> Array
330+
sparse_init([rng = default_rng_value()], rows, cols; sparsity, std = 0.01) -> Array
323331
sparse_init([rng]; kw...) -> Function
324332
325333
Return a `Matrix{Float32}` of size `rows, cols` where each column contains a fixed fraction of
@@ -361,8 +369,8 @@ function sparse_init(rng::AbstractRNG, dims::Integer...; sparsity, std = 0.01)
361369
return mapslices(shuffle, sparse_array, dims=1)
362370
end
363371

364-
sparse_init(dims::Integer...; kwargs...) = sparse_init(_rng_from_array(), dims...; kwargs...)
365-
sparse_init(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)
372+
sparse_init(dims::Integer...; kwargs...) = sparse_init(default_rng_value(), dims...; kwargs...)
373+
sparse_init(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)
366374

367375
ChainRulesCore.@non_differentiable sparse_init(::Any...)
368376

@@ -452,7 +460,7 @@ end
452460

453461
# For consistency, it accepts an RNG, but ignores it:
454462
identity_init(::AbstractRNG, dims::Integer...; kwargs...) = identity_init(dims...; kwargs...)
455-
identity_init(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)
463+
identity_init(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)
456464

457465
ChainRulesCore.@non_differentiable identity_init(::Any...)
458466

0 commit comments

Comments
 (0)