Skip to content

Commit 76d24f5

Browse files
committed
Optimize uniform_int_distribution
1 parent 2864e25 commit 76d24f5

File tree

1 file changed

+51
-37
lines changed

1 file changed

+51
-37
lines changed

libcxx/include/__random/uniform_int_distribution.h

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class __independent_bits_engine {
6464
_LIBCPP_HIDE_FROM_ABI __independent_bits_engine(_Engine& __e, size_t __w);
6565

6666
// generating functions
67-
_LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant<bool, _Rp != 0>()); }
67+
_LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant<bool, (_Rp & (_Rp - 1)) != 0>()); }
6868

6969
private:
7070
_LIBCPP_HIDE_FROM_ABI result_type __eval(false_type);
@@ -74,60 +74,74 @@ class __independent_bits_engine {
7474
template <class _Engine, class _UIntType>
7575
__independent_bits_engine<_Engine, _UIntType>::__independent_bits_engine(_Engine& __e, size_t __w)
7676
: __e_(__e), __w_(__w) {
77-
__n_ = __w_ / __m + (__w_ % __m != 0);
78-
__w0_ = __w_ / __n_;
79-
if (_Rp == 0)
80-
__y0_ = _Rp;
81-
else if (__w0_ < _WDt)
82-
__y0_ = (_Rp >> __w0_) << __w0_;
83-
else
84-
__y0_ = 0;
85-
if (_Rp - __y0_ > __y0_ / __n_) {
86-
++__n_;
77+
_LIBCPP_ASSERT_INTERNAL(
78+
w <= numeric_limits<result_type>::digits, "cannot sample more bits than result_type can hold");
79+
_LIBCPP_ASSERT_INTERNAL(w > 0, "must sample a positive number of bits");
80+
if (__w_ <= __m) {
81+
__n_ = __n0_ = 1;
82+
__w0_ = __w_;
83+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
84+
__y0_ = __y1_ = _Rp & ~__mask0_;
85+
} else {
86+
__n_ = (__w_ + __m - 1) / __m;
8787
__w0_ = __w_ / __n_;
88-
if (__w0_ < _WDt)
89-
__y0_ = (_Rp >> __w0_) << __w0_;
90-
else
91-
__y0_ = 0;
88+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
89+
__y0_ = __y1_ = _Rp & ~__mask0_;
90+
if _LIBCPP_CONSTEXPR_SINCE_CXX17 ((_Rp & (_Rp - 1)) != 0) {
91+
if (_Rp - __y0_ > __y0_ / __n_) {
92+
++__n_;
93+
__w0_ = __w_ / __n_;
94+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
95+
__y0_ = __y1_ = _Rp & ~__mask0_;
96+
}
97+
}
98+
size_t __n1 = __w_ % __n_;
99+
__n0_ = __n_ - __n1;
100+
if (__n1 > 0) {
101+
__mask1_ = ~_Engine_result_type(0) >> (_EDt - (__w0_ + 1));
102+
__y1_ = _Rp & ~__mask1_;
103+
}
92104
}
93-
__n0_ = __n_ - __w_ % __n_;
94-
if (__w0_ < _WDt - 1)
95-
__y1_ = (_Rp >> (__w0_ + 1)) << (__w0_ + 1);
96-
else
97-
__y1_ = 0;
98-
__mask0_ = __w0_ > 0 ? _Engine_result_type(~0) >> (_EDt - __w0_) : _Engine_result_type(0);
99-
__mask1_ = __w0_ < _EDt - 1 ? _Engine_result_type(~0) >> (_EDt - (__w0_ + 1)) : _Engine_result_type(~0);
100105
}
101106

102107
template <class _Engine, class _UIntType>
103108
inline _UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(false_type) {
104-
return static_cast<result_type>(__e_() & __mask0_);
109+
result_type __sp = (__e_() - _Engine::min()) & __mask0_;
110+
for (size_t __k = 1; __k < __n0_; ++__k) {
111+
__sp <<= __w0_;
112+
__sp += (__e_() - _Engine::min()) & __mask0_;
113+
}
114+
for (size_t __k = __n0_; __k < __n_; ++__k) {
115+
__sp <<= __w0_ + 1;
116+
__sp += (__e_() - _Engine::min()) & __mask1_;
117+
}
118+
return __sp;
105119
}
106120

107121
template <class _Engine, class _UIntType>
108122
_UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(true_type) {
109-
const size_t __w_rt = numeric_limits<result_type>::digits;
110-
result_type __sp = 0;
111-
for (size_t __k = 0; __k < __n0_; ++__k) {
123+
result_type __sp;
124+
{
125+
_Engine_result_type __u;
126+
do {
127+
__u = __e_() - _Engine::min();
128+
} while (__u >= __y0_);
129+
__sp = __u & __mask0_;
130+
}
131+
for (size_t __k = 1; __k < __n0_; ++__k) {
112132
_Engine_result_type __u;
113133
do {
114134
__u = __e_() - _Engine::min();
115135
} while (__u >= __y0_);
116-
if (__w0_ < __w_rt)
117-
__sp <<= __w0_;
118-
else
119-
__sp = 0;
136+
__sp <<= __w0_;
120137
__sp += __u & __mask0_;
121138
}
122139
for (size_t __k = __n0_; __k < __n_; ++__k) {
123140
_Engine_result_type __u;
124141
do {
125142
__u = __e_() - _Engine::min();
126143
} while (__u >= __y1_);
127-
if (__w0_ < __w_rt - 1)
128-
__sp <<= __w0_ + 1;
129-
else
130-
__sp = 0;
144+
__sp <<= __w0_ + 1;
131145
__sp += __u & __mask1_;
132146
}
133147
return __sp;
@@ -218,9 +232,9 @@ typename uniform_int_distribution<_IntType>::result_type uniform_int_distributio
218232
typedef __independent_bits_engine<_URNG, _UIntType> _Eng;
219233
if (__rp == 0)
220234
return static_cast<result_type>(_Eng(__g, __dt)());
221-
size_t __w = __dt - std::__countl_zero(__rp) - 1;
222-
if ((__rp & (numeric_limits<_UIntType>::max() >> (__dt - __w))) != 0)
223-
++__w;
235+
size_t __w = __dt - std::__countl_zero(__rp);
236+
if ((__rp & (__rp - 1)) == 0)
237+
return static_cast<result_type>(_Eng(__g, __w - 1)() + __p.a());
224238
_Eng __e(__g, __w);
225239
_UIntType __u;
226240
do {

0 commit comments

Comments
 (0)