@@ -20,7 +20,7 @@ use distributions::Distribution;
20
20
/// ```rust
21
21
/// use rand::distributions::{Bernoulli, Distribution};
22
22
///
23
- /// let d = Bernoulli::new(0.3);
23
+ /// let d = Bernoulli::new(0.3).unwrap() ;
24
24
/// let v = d.sample(&mut rand::thread_rng());
25
25
/// println!("{} is from a Bernoulli distribution", v);
26
26
/// ```
@@ -61,13 +61,16 @@ const ALWAYS_TRUE: u64 = ::core::u64::MAX;
61
61
// in `no_std` mode.
62
62
const SCALE : f64 = 2.0 * ( 1u64 << 63 ) as f64 ;
63
63
64
+ /// Error type returned from `Bernoulli::new`.
65
+ #[ derive( Clone , Copy , Debug , PartialEq , Eq ) ]
66
+ pub enum BernoulliError {
67
+ /// `p < 0` or `p > 1`.
68
+ InvalidProbability ,
69
+ }
70
+
64
71
impl Bernoulli {
65
72
/// Construct a new `Bernoulli` with the given probability of success `p`.
66
73
///
67
- /// # Panics
68
- ///
69
- /// If `p < 0` or `p > 1`.
70
- ///
71
74
/// # Precision
72
75
///
73
76
/// For `p = 1.0`, the resulting distribution will always generate true.
@@ -77,12 +80,12 @@ impl Bernoulli {
77
80
/// a multiple of 2<sup>-64</sup>. (Note that not all multiples of
78
81
/// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.)
79
82
#[ inline]
80
- pub fn new ( p : f64 ) -> Bernoulli {
83
+ pub fn new ( p : f64 ) -> Result < Bernoulli , BernoulliError > {
81
84
if p < 0.0 || p >= 1.0 {
82
- if p == 1.0 { return Bernoulli { p_int : ALWAYS_TRUE } }
83
- panic ! ( "Bernoulli::new not called with 0.0 <= p <= 1.0" ) ;
85
+ if p == 1.0 { return Ok ( Bernoulli { p_int : ALWAYS_TRUE } ) }
86
+ return Err ( BernoulliError :: InvalidProbability ) ;
84
87
}
85
- Bernoulli { p_int : ( p * SCALE ) as u64 }
88
+ Ok ( Bernoulli { p_int : ( p * SCALE ) as u64 } )
86
89
}
87
90
88
91
/// Construct a new `Bernoulli` with the probability of success of
@@ -91,19 +94,16 @@ impl Bernoulli {
91
94
///
92
95
/// If `numerator == denominator` then the returned `Bernoulli` will always
93
96
/// return `true`. If `numerator == 0` it will always return `false`.
94
- ///
95
- /// # Panics
96
- ///
97
- /// If `denominator == 0` or `numerator > denominator`.
98
- ///
99
97
#[ inline]
100
- pub fn from_ratio ( numerator : u32 , denominator : u32 ) -> Bernoulli {
101
- assert ! ( numerator <= denominator) ;
98
+ pub fn from_ratio ( numerator : u32 , denominator : u32 ) -> Result < Bernoulli , BernoulliError > {
99
+ if !( numerator <= denominator) {
100
+ return Err ( BernoulliError :: InvalidProbability ) ;
101
+ }
102
102
if numerator == denominator {
103
- return Bernoulli { p_int : :: core :: u64 :: MAX }
103
+ return Ok ( Bernoulli { p_int : ALWAYS_TRUE } )
104
104
}
105
105
let p_int = ( ( numerator as f64 / denominator as f64 ) * SCALE ) as u64 ;
106
- Bernoulli { p_int }
106
+ Ok ( Bernoulli { p_int } )
107
107
}
108
108
}
109
109
@@ -126,8 +126,8 @@ mod test {
126
126
#[ test]
127
127
fn test_trivial ( ) {
128
128
let mut r = :: test:: rng ( 1 ) ;
129
- let always_false = Bernoulli :: new ( 0.0 ) ;
130
- let always_true = Bernoulli :: new ( 1.0 ) ;
129
+ let always_false = Bernoulli :: new ( 0.0 ) . unwrap ( ) ;
130
+ let always_true = Bernoulli :: new ( 1.0 ) . unwrap ( ) ;
131
131
for _ in 0 ..5 {
132
132
assert_eq ! ( r. sample:: <bool , _>( & always_false) , false ) ;
133
133
assert_eq ! ( r. sample:: <bool , _>( & always_true) , true ) ;
@@ -142,8 +142,8 @@ mod test {
142
142
const P : f64 = 0.3 ;
143
143
const NUM : u32 = 3 ;
144
144
const DENOM : u32 = 10 ;
145
- let d1 = Bernoulli :: new ( P ) ;
146
- let d2 = Bernoulli :: from_ratio ( NUM , DENOM ) ;
145
+ let d1 = Bernoulli :: new ( P ) . unwrap ( ) ;
146
+ let d2 = Bernoulli :: from_ratio ( NUM , DENOM ) . unwrap ( ) ;
147
147
const N : u32 = 100_000 ;
148
148
149
149
let mut sum1: u32 = 0 ;
0 commit comments