Skip to content

Commit 5f41307

Browse files
committed
Optimize uniform_int_distribution
1 parent 2864e25 commit 5f41307

File tree

1 file changed

+52
-37
lines changed

1 file changed

+52
-37
lines changed

libcxx/include/__random/uniform_int_distribution.h

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef _LIBCPP___RANDOM_UNIFORM_INT_DISTRIBUTION_H
1010
#define _LIBCPP___RANDOM_UNIFORM_INT_DISTRIBUTION_H
1111

12+
#include <__assert>
1213
#include <__bit/countl.h>
1314
#include <__config>
1415
#include <__cstddef/size_t.h>
@@ -64,7 +65,7 @@ class __independent_bits_engine {
6465
_LIBCPP_HIDE_FROM_ABI __independent_bits_engine(_Engine& __e, size_t __w);
6566

6667
// generating functions
67-
_LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant<bool, _Rp != 0>()); }
68+
_LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant<bool, (_Rp & (_Rp - 1)) != 0>()); }
6869

6970
private:
7071
_LIBCPP_HIDE_FROM_ABI result_type __eval(false_type);
@@ -74,60 +75,74 @@ class __independent_bits_engine {
7475
template <class _Engine, class _UIntType>
7576
__independent_bits_engine<_Engine, _UIntType>::__independent_bits_engine(_Engine& __e, size_t __w)
7677
: __e_(__e), __w_(__w) {
77-
__n_ = __w_ / __m + (__w_ % __m != 0);
78-
__w0_ = __w_ / __n_;
79-
if (_Rp == 0)
80-
__y0_ = _Rp;
81-
else if (__w0_ < _WDt)
82-
__y0_ = (_Rp >> __w0_) << __w0_;
83-
else
84-
__y0_ = 0;
85-
if (_Rp - __y0_ > __y0_ / __n_) {
86-
++__n_;
78+
_LIBCPP_ASSERT_INTERNAL(
79+
__w_ <= numeric_limits<result_type>::digits, "cannot sample more bits than result_type can hold");
80+
_LIBCPP_ASSERT_INTERNAL(__w_ > 0, "must sample a positive number of bits");
81+
if (__w_ <= __m) {
82+
__n_ = __n0_ = 1;
83+
__w0_ = __w_;
84+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
85+
__y0_ = __y1_ = _Rp & ~__mask0_;
86+
} else {
87+
__n_ = (__w_ + __m - 1) / __m;
8788
__w0_ = __w_ / __n_;
88-
if (__w0_ < _WDt)
89-
__y0_ = (_Rp >> __w0_) << __w0_;
90-
else
91-
__y0_ = 0;
89+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
90+
__y0_ = __y1_ = _Rp & ~__mask0_;
91+
if _LIBCPP_CONSTEXPR_SINCE_CXX17 ((_Rp & (_Rp - 1)) != 0) {
92+
if (_Rp - __y0_ > __y0_ / __n_) {
93+
++__n_;
94+
__w0_ = __w_ / __n_;
95+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
96+
__y0_ = __y1_ = _Rp & ~__mask0_;
97+
}
98+
}
99+
size_t __n1 = __w_ % __n_;
100+
__n0_ = __n_ - __n1;
101+
if (__n1 > 0) {
102+
__mask1_ = ~_Engine_result_type(0) >> (_EDt - (__w0_ + 1));
103+
__y1_ = _Rp & ~__mask1_;
104+
}
92105
}
93-
__n0_ = __n_ - __w_ % __n_;
94-
if (__w0_ < _WDt - 1)
95-
__y1_ = (_Rp >> (__w0_ + 1)) << (__w0_ + 1);
96-
else
97-
__y1_ = 0;
98-
__mask0_ = __w0_ > 0 ? _Engine_result_type(~0) >> (_EDt - __w0_) : _Engine_result_type(0);
99-
__mask1_ = __w0_ < _EDt - 1 ? _Engine_result_type(~0) >> (_EDt - (__w0_ + 1)) : _Engine_result_type(~0);
100106
}
101107

102108
template <class _Engine, class _UIntType>
103109
inline _UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(false_type) {
104-
return static_cast<result_type>(__e_() & __mask0_);
110+
result_type __sp = (__e_() - _Engine::min()) & __mask0_;
111+
for (size_t __k = 1; __k < __n0_; ++__k) {
112+
__sp <<= __w0_;
113+
__sp += (__e_() - _Engine::min()) & __mask0_;
114+
}
115+
for (size_t __k = __n0_; __k < __n_; ++__k) {
116+
__sp <<= __w0_ + 1;
117+
__sp += (__e_() - _Engine::min()) & __mask1_;
118+
}
119+
return __sp;
105120
}
106121

107122
template <class _Engine, class _UIntType>
108123
_UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(true_type) {
109-
const size_t __w_rt = numeric_limits<result_type>::digits;
110-
result_type __sp = 0;
111-
for (size_t __k = 0; __k < __n0_; ++__k) {
124+
result_type __sp;
125+
{
126+
_Engine_result_type __u;
127+
do {
128+
__u = __e_() - _Engine::min();
129+
} while (__u >= __y0_);
130+
__sp = __u & __mask0_;
131+
}
132+
for (size_t __k = 1; __k < __n0_; ++__k) {
112133
_Engine_result_type __u;
113134
do {
114135
__u = __e_() - _Engine::min();
115136
} while (__u >= __y0_);
116-
if (__w0_ < __w_rt)
117-
__sp <<= __w0_;
118-
else
119-
__sp = 0;
137+
__sp <<= __w0_;
120138
__sp += __u & __mask0_;
121139
}
122140
for (size_t __k = __n0_; __k < __n_; ++__k) {
123141
_Engine_result_type __u;
124142
do {
125143
__u = __e_() - _Engine::min();
126144
} while (__u >= __y1_);
127-
if (__w0_ < __w_rt - 1)
128-
__sp <<= __w0_ + 1;
129-
else
130-
__sp = 0;
145+
__sp <<= __w0_ + 1;
131146
__sp += __u & __mask1_;
132147
}
133148
return __sp;
@@ -218,9 +233,9 @@ typename uniform_int_distribution<_IntType>::result_type uniform_int_distributio
218233
typedef __independent_bits_engine<_URNG, _UIntType> _Eng;
219234
if (__rp == 0)
220235
return static_cast<result_type>(_Eng(__g, __dt)());
221-
size_t __w = __dt - std::__countl_zero(__rp) - 1;
222-
if ((__rp & (numeric_limits<_UIntType>::max() >> (__dt - __w))) != 0)
223-
++__w;
236+
size_t __w = __dt - std::__countl_zero(__rp);
237+
if ((__rp & (__rp - 1)) == 0)
238+
return static_cast<result_type>(_Eng(__g, __w - 1)() + __p.a());
224239
_Eng __e(__g, __w);
225240
_UIntType __u;
226241
do {

0 commit comments

Comments
 (0)