Skip to content

Commit e3b7b18

Browse files
authored
Merge pull request #25319 from JuliaLang/rf/rand/unleash-remains
two small changes from #24912
2 parents 6a46b10 + 3788b51 commit e3b7b18

File tree

4 files changed

+59
-48
lines changed

4 files changed

+59
-48
lines changed

base/random/RNGs.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,12 @@ rng_native_52(::MersenneTwister) = Float64
284284
#### helper functions
285285

286286
# precondition: !mt_empty(r)
287-
rand_inbounds(r::MersenneTwister, ::Close1Open2_64) = mt_pop!(r)
288-
rand_inbounds(r::MersenneTwister, ::CloseOpen_64=CloseOpen()) =
289-
rand_inbounds(r, Close1Open2()) - 1.0
287+
rand_inbounds(r::MersenneTwister, ::CloseOpen12_64) = mt_pop!(r)
288+
rand_inbounds(r::MersenneTwister, ::CloseOpen01_64=CloseOpen01()) =
289+
rand_inbounds(r, CloseOpen12()) - 1.0
290290

291291
rand_inbounds(r::MersenneTwister, ::UInt52Raw{T}) where {T<:BitInteger} =
292-
reinterpret(UInt64, rand_inbounds(r, Close1Open2())) % T
292+
reinterpret(UInt64, rand_inbounds(r, CloseOpen12())) % T
293293

294294
function rand(r::MersenneTwister, x::SamplerTrivial{UInt52Raw{UInt64}})
295295
reserve_1(r)
@@ -308,7 +308,7 @@ end
308308

309309
#### floats
310310

311-
rand(r::MersenneTwister, sp::SamplerTrivial{Close1Open2_64}) =
311+
rand(r::MersenneTwister, sp::SamplerTrivial{CloseOpen12_64}) =
312312
(reserve_1(r); rand_inbounds(r, sp[]))
313313

314314
#### integers
@@ -380,7 +380,7 @@ function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInter
380380
@gc_preserve r unsafe_copyto!(A.ptr+m*sizeof(Float64), pointer(r.vals), n-m)
381381
r.idxF = n-m
382382
end
383-
if I isa CloseOpen
383+
if I isa CloseOpen01
384384
for i=1:n
385385
A[i] -= 1.0
386386
end
@@ -389,10 +389,10 @@ function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInter
389389
end
390390

391391

392-
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen_64) =
392+
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen01_64) =
393393
dsfmt_fill_array_close_open!(s, A, n)
394394

395-
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::Close1Open2_64) =
395+
fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen12_64) =
396396
dsfmt_fill_array_close1_open2!(s, A, n)
397397

398398

@@ -443,10 +443,10 @@ mask128(u::UInt128, ::Type{Float32}) =
443443
(u & 0x007fffff007fffff007fffff007fffff) | 0x3f8000003f8000003f8000003f800000
444444

445445
for T in (Float16, Float32)
446-
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerTrivial{Close1Open2{$T}})
446+
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerTrivial{CloseOpen12{$T}})
447447
n = length(A)
448448
n128 = n * sizeof($T) ÷ 16
449-
_rand!(r, A, 2*n128, Close1Open2())
449+
_rand!(r, A, 2*n128, CloseOpen12())
450450
@gc_preserve A begin
451451
A128 = UnsafeView{UInt128}(pointer(A), n128)
452452
for i in 1:n128
@@ -471,8 +471,8 @@ for T in (Float16, Float32)
471471
A
472472
end
473473

474-
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerTrivial{CloseOpen{$T}})
475-
rand!(r, A, Close1Open2($T))
474+
@eval function rand!(r::MersenneTwister, A::Array{$T}, ::SamplerTrivial{CloseOpen01{$T}})
475+
rand!(r, A, CloseOpen12($T))
476476
I32 = one(Float32)
477477
for i in eachindex(A)
478478
@inbounds A[i] = Float32(A[i])-I32 # faster than "A[i] -= one(T)" for T==Float16
@@ -487,7 +487,7 @@ function rand!(r::MersenneTwister, A::UnsafeView{UInt128}, ::SamplerType{UInt128
487487
n::Int=length(A)
488488
i = n
489489
while true
490-
rand!(r, UnsafeView{Float64}(A.ptr, 2i), Close1Open2())
490+
rand!(r, UnsafeView{Float64}(A.ptr, 2i), CloseOpen12())
491491
n < 5 && break
492492
i = 0
493493
while n-i >= 5

base/random/generation.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,29 @@
1717
### random floats
1818

1919
Sampler(rng::AbstractRNG, ::Type{T}, n::Repetition) where {T<:AbstractFloat} =
20-
Sampler(rng, CloseOpen(T), n)
20+
Sampler(rng, CloseOpen01(T), n)
2121

2222
# generic random generation function which can be used by RNG implementors
2323
# it is not defined as a fallback rand method as this could create ambiguities
2424

25-
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen{Float16}}) =
25+
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen01{Float16}}) =
2626
Float16(reinterpret(Float32,
2727
(rand(r, UInt10(UInt32)) << 13) | 0x3f800000) - 1)
2828

29-
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen{Float32}}) =
29+
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen01{Float32}}) =
3030
reinterpret(Float32, rand(r, UInt23()) | 0x3f800000) - 1
3131

32-
rand(r::AbstractRNG, ::SamplerTrivial{Close1Open2_64}) =
32+
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen12_64}) =
3333
reinterpret(Float64, 0x3ff0000000000000 | rand(r, UInt52()))
3434

35-
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen_64}) = rand(r, Close1Open2()) - 1.0
35+
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen01_64}) = rand(r, CloseOpen12()) - 1.0
3636

3737
#### BigFloat
3838

3939
const bits_in_Limb = sizeof(Limb) << 3
4040
const Limb_high_bit = one(Limb) << (bits_in_Limb-1)
4141

42-
struct SamplerBigFloat{I<:FloatInterval{BigFloat}} <: Sampler
42+
struct SamplerBigFloat{I<:FloatInterval{BigFloat}} <: Sampler{BigFloat}
4343
prec::Int
4444
nlimbs::Int
4545
limbs::Vector{Limb}
@@ -70,13 +70,13 @@ function _rand(rng::AbstractRNG, sp::SamplerBigFloat)
7070
(z, randbool)
7171
end
7272

73-
function _rand(rng::AbstractRNG, sp::SamplerBigFloat, ::Close1Open2{BigFloat})
73+
function _rand(rng::AbstractRNG, sp::SamplerBigFloat, ::CloseOpen12{BigFloat})
7474
z = _rand(rng, sp)[1]
7575
z.exp = 1
7676
z
7777
end
7878

79-
function _rand(rng::AbstractRNG, sp::SamplerBigFloat, ::CloseOpen{BigFloat})
79+
function _rand(rng::AbstractRNG, sp::SamplerBigFloat, ::CloseOpen01{BigFloat})
8080
z, randbool = _rand(rng, sp)
8181
z.exp = 0
8282
randbool &&
@@ -88,8 +88,8 @@ end
8888

8989
# alternative, with 1 bit less of precision
9090
# TODO: make an API for requesting full or not-full precision
91-
function _rand(rng::AbstractRNG, sp::SamplerBigFloat, ::CloseOpen{BigFloat}, ::Nothing)
92-
z = _rand(rng, sp, Close1Open2(BigFloat))
91+
function _rand(rng::AbstractRNG, sp::SamplerBigFloat, ::CloseOpen01{BigFloat}, ::Nothing)
92+
z = _rand(rng, sp, CloseOpen12(BigFloat))
9393
ccall((:mpfr_sub_ui, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Culong, Int32),
9494
z, z, 1, Base.MPFR.ROUNDING_MODE[])
9595
z
@@ -108,7 +108,7 @@ rand(r::AbstractRNG, ::SamplerTrivial{UInt23Raw{UInt32}}) = rand(r, UInt32)
108108
rand(r::AbstractRNG, ::SamplerTrivial{UInt52Raw{UInt64}}) =
109109
_rand52(r, rng_native_52(r))
110110

111-
_rand52(r::AbstractRNG, ::Type{Float64}) = reinterpret(UInt64, rand(r, Close1Open2()))
111+
_rand52(r::AbstractRNG, ::Type{Float64}) = reinterpret(UInt64, rand(r, CloseOpen12()))
112112
_rand52(r::AbstractRNG, ::Type{UInt64}) = rand(r, UInt64)
113113

114114
rand(r::AbstractRNG, ::SamplerTrivial{UInt104Raw{UInt128}}) =
@@ -182,7 +182,7 @@ uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128
182182

183183
#### Fast
184184

185-
struct SamplerRangeFast{U<:BitUnsigned,T<:BitInteger} <: Sampler
185+
struct SamplerRangeFast{U<:BitUnsigned,T<:BitInteger} <: Sampler{T}
186186
a::T # first element of the range
187187
bw::UInt # bit width
188188
m::U # range length - 1
@@ -243,7 +243,7 @@ maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} =
243243
unsafe_maxmultiple(k::T, sup::T) where {T<:Unsigned} =
244244
div(sup, k + (k == 0))*k - one(k)
245245

246-
struct SamplerRangeInt{T<:Integer,U<:Unsigned} <: Sampler
246+
struct SamplerRangeInt{T<:Integer,U<:Unsigned} <: Sampler{T}
247247
a::T # first element of the range
248248
bw::Int # bit width
249249
k::U # range length or zero for full range
@@ -298,7 +298,7 @@ end
298298

299299
### BigInt
300300

301-
struct SamplerBigInt <: Sampler
301+
struct SamplerBigInt <: Sampler{BigInt}
302302
a::BigInt # first
303303
m::BigInt # range length - 1
304304
nlimbs::Int # number of limbs in generated BigInt's (z ∈ [0, m])
@@ -364,9 +364,10 @@ end
364364

365365
## random values from Set
366366

367-
Sampler(rng::AbstractRNG, t::Set, n::Repetition) = SamplerTag{Set}(Sampler(rng, t.dict, n))
367+
Sampler(rng::AbstractRNG, t::Set{T}, n::Repetition) where {T} =
368+
SamplerTag{Set{T}}(Sampler(rng, t.dict, n))
368369

369-
rand(rng::AbstractRNG, sp::SamplerTag{Set,<:Sampler}) = rand(rng, sp.data).first
370+
rand(rng::AbstractRNG, sp::SamplerTag{<:Set,<:Sampler}) = rand(rng, sp.data).first
370371

371372
## random values from BitSet
372373

base/random/random.jl

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ export srand,
2727

2828
abstract type AbstractRNG end
2929

30+
3031
### integers
3132

3233
# we define types which encode the generation of a specific number of bits
@@ -68,23 +69,25 @@ end
6869

6970
abstract type FloatInterval{T<:AbstractFloat} end
7071

71-
struct CloseOpen{ T<:AbstractFloat} <: FloatInterval{T} end # interval [0,1)
72-
struct Close1Open2{T<:AbstractFloat} <: FloatInterval{T} end # interval [1,2)
72+
struct CloseOpen01{T<:AbstractFloat} <: FloatInterval{T} end # interval [0,1)
73+
struct CloseOpen12{T<:AbstractFloat} <: FloatInterval{T} end # interval [1,2)
7374

7475
const FloatInterval_64 = FloatInterval{Float64}
75-
const CloseOpen_64 = CloseOpen{Float64}
76-
const Close1Open2_64 = Close1Open2{Float64}
76+
const CloseOpen01_64 = CloseOpen01{Float64}
77+
const CloseOpen12_64 = CloseOpen12{Float64}
7778

78-
CloseOpen( ::Type{T}=Float64) where {T<:AbstractFloat} = CloseOpen{T}()
79-
Close1Open2(::Type{T}=Float64) where {T<:AbstractFloat} = Close1Open2{T}()
79+
CloseOpen01(::Type{T}=Float64) where {T<:AbstractFloat} = CloseOpen01{T}()
80+
CloseOpen12(::Type{T}=Float64) where {T<:AbstractFloat} = CloseOpen12{T}()
8081

8182
Base.eltype(::Type{<:FloatInterval{T}}) where {T<:AbstractFloat} = T
8283

8384
const BitFloatType = Union{Type{Float16},Type{Float32},Type{Float64}}
8485

8586
### Sampler
8687

87-
abstract type Sampler end
88+
abstract type Sampler{E} end
89+
90+
Base.eltype(::Sampler{E}) where {E} = E
8891

8992
# temporarily for BaseBenchmarks
9093
RangeGenerator(x) = Sampler(GLOBAL_RNG, x)
@@ -110,41 +113,48 @@ Sampler(rng::AbstractRNG, ::Type{X}) where {X} = Sampler(rng, X, Val(Inf))
110113
#### pre-defined useful Sampler types
111114

112115
# default fall-back for types
113-
struct SamplerType{T} <: Sampler end
116+
struct SamplerType{T} <: Sampler{T} end
114117

115118
Sampler(::AbstractRNG, ::Type{T}, ::Repetition) where {T} = SamplerType{T}()
116119

117-
Base.getindex(sp::SamplerType{T}) where {T} = T
120+
Base.getindex(::SamplerType{T}) where {T} = T
118121

119122
# default fall-back for values
120-
struct SamplerTrivial{T} <: Sampler
123+
struct SamplerTrivial{T,E} <: Sampler{E}
121124
self::T
122125
end
123126

124-
Sampler(::AbstractRNG, X, ::Repetition) = SamplerTrivial(X)
127+
SamplerTrivial(x::T) where {T} = SamplerTrivial{T,eltype(T)}(x)
128+
129+
Sampler(::AbstractRNG, x, ::Repetition) = SamplerTrivial(x)
125130

126131
Base.getindex(sp::SamplerTrivial) = sp.self
127132

128133
# simple sampler carrying data (which can be anything)
129-
struct SamplerSimple{T,S} <: Sampler
134+
struct SamplerSimple{T,S,E} <: Sampler{E}
130135
self::T
131136
data::S
132137
end
133138

139+
SamplerSimple(x::T, data::S) where {T,S} = SamplerSimple{T,S,eltype(T)}(x, data)
140+
134141
Base.getindex(sp::SamplerSimple) = sp.self
135142

136143
# simple sampler carrying a (type) tag T and data
137-
struct SamplerTag{T,S} <: Sampler
144+
struct SamplerTag{T,S,E} <: Sampler{E}
138145
data::S
139-
SamplerTag{T}(s::S) where {T,S} = new{T,S}(s)
146+
SamplerTag{T}(s::S) where {T,S} = new{T,S,eltype(T)}(s)
140147
end
141148

142149

143150
#### helper samplers
144151

152+
# TODO: make constraining constructors to enforce that those
153+
# types are <: Sampler{T}
154+
145155
##### Adapter to generate a randome value in [0, n]
146156

147-
struct LessThan{T<:Integer,S} <: Sampler
157+
struct LessThan{T<:Integer,S} <: Sampler{T}
148158
sup::T
149159
s::S # the scalar specification/sampler to feed to rand
150160
end
@@ -156,7 +166,7 @@ function rand(rng::AbstractRNG, sp::LessThan)
156166
end
157167
end
158168

159-
struct Masked{T<:Integer,S} <: Sampler
169+
struct Masked{T<:Integer,S} <: Sampler{T}
160170
mask::T
161171
s::S
162172
end
@@ -165,7 +175,7 @@ rand(rng::AbstractRNG, sp::Masked) = rand(rng, sp.s) & sp.mask
165175

166176
##### Uniform
167177

168-
struct UniformT{T} <: Sampler end
178+
struct UniformT{T} <: Sampler{T} end
169179

170180
uniform(::Type{T}) where {T} = UniformT{T}()
171181

test/random.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,8 @@ let mt = MersenneTwister(0)
290290
@test rand!(mt, AF64)[end] == 0.957735065345398
291291
@test rand!(mt, AF64)[end] == 0.6492481059865669
292292
resize!(AF64, 2*length(mt.vals))
293-
@test invoke(rand!, Tuple{MersenneTwister,AbstractArray{Float64},Base.Random.SamplerTrivial{Base.Random.CloseOpen_64}},
294-
mt, AF64, Base.Random.SamplerTrivial(Base.Random.CloseOpen()))[end] == 0.1142787906708973
293+
@test invoke(rand!, Tuple{MersenneTwister,AbstractArray{Float64},Base.Random.SamplerTrivial{Base.Random.CloseOpen01_64}},
294+
mt, AF64, Base.Random.SamplerTrivial(Base.Random.CloseOpen01()))[end] == 0.1142787906708973
295295
end
296296

297297
# Issue #9037

0 commit comments

Comments
 (0)