Skip to content

Commit 3788b51

Browse files
committed
make Sampler{E} encode the type E of elements which are generated
Before, a call like `rand(rng, Sampler(rng, 1:10), 3)` generated an `Array{Any,1}`, so a way to get the `eltype` of a Sampler is necessary. Instead of changing Sampler -> Sampler{E}, implementing appropriate eltype methods would have been possible, to keep the helper Sampler subtypes more flexible, but it seemed to be simpler this way.
1 parent da29da1 commit 3788b51

File tree

2 files changed

+28
-17
lines changed

2 files changed

+28
-17
lines changed

base/random/generation.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen01_64}) = rand(r, CloseOpen12())
3939
const bits_in_Limb = sizeof(Limb) << 3
4040
const Limb_high_bit = one(Limb) << (bits_in_Limb-1)
4141

42-
struct SamplerBigFloat{I<:FloatInterval{BigFloat}} <: Sampler
42+
struct SamplerBigFloat{I<:FloatInterval{BigFloat}} <: Sampler{BigFloat}
4343
prec::Int
4444
nlimbs::Int
4545
limbs::Vector{Limb}
@@ -182,7 +182,7 @@ uint_sup(::Type{<:Union{Int128,UInt128}}) = UInt128
182182

183183
#### Fast
184184

185-
struct SamplerRangeFast{U<:BitUnsigned,T<:BitInteger} <: Sampler
185+
struct SamplerRangeFast{U<:BitUnsigned,T<:BitInteger} <: Sampler{T}
186186
a::T # first element of the range
187187
bw::UInt # bit width
188188
m::U # range length - 1
@@ -243,7 +243,7 @@ maxmultiple(k::T, sup::T=zero(T)) where {T<:Unsigned} =
243243
unsafe_maxmultiple(k::T, sup::T) where {T<:Unsigned} =
244244
div(sup, k + (k == 0))*k - one(k)
245245

246-
struct SamplerRangeInt{T<:Integer,U<:Unsigned} <: Sampler
246+
struct SamplerRangeInt{T<:Integer,U<:Unsigned} <: Sampler{T}
247247
a::T # first element of the range
248248
bw::Int # bit width
249249
k::U # range length or zero for full range
@@ -298,7 +298,7 @@ end
298298

299299
### BigInt
300300

301-
struct SamplerBigInt <: Sampler
301+
struct SamplerBigInt <: Sampler{BigInt}
302302
a::BigInt # first
303303
m::BigInt # range length - 1
304304
nlimbs::Int # number of limbs in generated BigInt's (z ∈ [0, m])
@@ -364,9 +364,10 @@ end
364364

365365
## random values from Set
366366

367-
Sampler(rng::AbstractRNG, t::Set, n::Repetition) = SamplerTag{Set}(Sampler(rng, t.dict, n))
367+
Sampler(rng::AbstractRNG, t::Set{T}, n::Repetition) where {T} =
368+
SamplerTag{Set{T}}(Sampler(rng, t.dict, n))
368369

369-
rand(rng::AbstractRNG, sp::SamplerTag{Set,<:Sampler}) = rand(rng, sp.data).first
370+
rand(rng::AbstractRNG, sp::SamplerTag{<:Set,<:Sampler}) = rand(rng, sp.data).first
370371

371372
## random values from BitSet
372373

base/random/random.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ export srand,
2626

2727
abstract type AbstractRNG end
2828

29+
2930
### integers
3031

3132
# we define types which encode the generation of a specific number of bits
@@ -83,7 +84,9 @@ const BitFloatType = Union{Type{Float16},Type{Float32},Type{Float64}}
8384

8485
### Sampler
8586

86-
abstract type Sampler end
87+
abstract type Sampler{E} end
88+
89+
Base.eltype(::Sampler{E}) where {E} = E
8790

8891
# temporarily for BaseBenchmarks
8992
RangeGenerator(x) = Sampler(GLOBAL_RNG, x)
@@ -109,41 +112,48 @@ Sampler(rng::AbstractRNG, ::Type{X}) where {X} = Sampler(rng, X, Val(Inf))
109112
#### pre-defined useful Sampler types
110113

111114
# default fall-back for types
112-
struct SamplerType{T} <: Sampler end
115+
struct SamplerType{T} <: Sampler{T} end
113116

114117
Sampler(::AbstractRNG, ::Type{T}, ::Repetition) where {T} = SamplerType{T}()
115118

116-
Base.getindex(sp::SamplerType{T}) where {T} = T
119+
Base.getindex(::SamplerType{T}) where {T} = T
117120

118121
# default fall-back for values
119-
struct SamplerTrivial{T} <: Sampler
122+
struct SamplerTrivial{T,E} <: Sampler{E}
120123
self::T
121124
end
122125

123-
Sampler(::AbstractRNG, X, ::Repetition) = SamplerTrivial(X)
126+
SamplerTrivial(x::T) where {T} = SamplerTrivial{T,eltype(T)}(x)
127+
128+
Sampler(::AbstractRNG, x, ::Repetition) = SamplerTrivial(x)
124129

125130
Base.getindex(sp::SamplerTrivial) = sp.self
126131

127132
# simple sampler carrying data (which can be anything)
128-
struct SamplerSimple{T,S} <: Sampler
133+
struct SamplerSimple{T,S,E} <: Sampler{E}
129134
self::T
130135
data::S
131136
end
132137

138+
SamplerSimple(x::T, data::S) where {T,S} = SamplerSimple{T,S,eltype(T)}(x, data)
139+
133140
Base.getindex(sp::SamplerSimple) = sp.self
134141

135142
# simple sampler carrying a (type) tag T and data
136-
struct SamplerTag{T,S} <: Sampler
143+
struct SamplerTag{T,S,E} <: Sampler{E}
137144
data::S
138-
SamplerTag{T}(s::S) where {T,S} = new{T,S}(s)
145+
SamplerTag{T}(s::S) where {T,S} = new{T,S,eltype(T)}(s)
139146
end
140147

141148

142149
#### helper samplers
143150

151+
# TODO: make constraining constructors to enforce that those
152+
# types are <: Sampler{T}
153+
144154
##### Adapter to generate a randome value in [0, n]
145155

146-
struct LessThan{T<:Integer,S} <: Sampler
156+
struct LessThan{T<:Integer,S} <: Sampler{T}
147157
sup::T
148158
s::S # the scalar specification/sampler to feed to rand
149159
end
@@ -155,7 +165,7 @@ function rand(rng::AbstractRNG, sp::LessThan)
155165
end
156166
end
157167

158-
struct Masked{T<:Integer,S} <: Sampler
168+
struct Masked{T<:Integer,S} <: Sampler{T}
159169
mask::T
160170
s::S
161171
end
@@ -164,7 +174,7 @@ rand(rng::AbstractRNG, sp::Masked) = rand(rng, sp.s) & sp.mask
164174

165175
##### Uniform
166176

167-
struct UniformT{T} <: Sampler end
177+
struct UniformT{T} <: Sampler{T} end
168178

169179
uniform(::Type{T}) where {T} = UniformT{T}()
170180

0 commit comments

Comments
 (0)