Skip to content

Commit 4e418c8

Browse files
committed
Optimize uniform_int_distribution
1 parent 2864e25 commit 4e418c8

File tree

1 file changed

+50
-37
lines changed

1 file changed

+50
-37
lines changed

libcxx/include/__random/uniform_int_distribution.h

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class __independent_bits_engine {
6464
_LIBCPP_HIDE_FROM_ABI __independent_bits_engine(_Engine& __e, size_t __w);
6565

6666
// generating functions
67-
_LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant<bool, _Rp != 0>()); }
67+
_LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant<bool, (_Rp & (_Rp - 1)) != 0>()); }
6868

6969
private:
7070
_LIBCPP_HIDE_FROM_ABI result_type __eval(false_type);
@@ -74,60 +74,73 @@ class __independent_bits_engine {
7474
template <class _Engine, class _UIntType>
7575
__independent_bits_engine<_Engine, _UIntType>::__independent_bits_engine(_Engine& __e, size_t __w)
7676
: __e_(__e), __w_(__w) {
77-
__n_ = __w_ / __m + (__w_ % __m != 0);
78-
__w0_ = __w_ / __n_;
79-
if (_Rp == 0)
80-
__y0_ = _Rp;
81-
else if (__w0_ < _WDt)
82-
__y0_ = (_Rp >> __w0_) << __w0_;
83-
else
84-
__y0_ = 0;
85-
if (_Rp - __y0_ > __y0_ / __n_) {
86-
++__n_;
77+
_LIBCPP_ASSERT_INTERNAL(w <= numeric_limits<result_type>::digits, "cannot sample more bits than result_type can hold");
78+
_LIBCPP_ASSERT_INTERNAL(w > 0, "must sample a positive number of bits");
79+
if (__w_ <= __m) {
80+
__n_ = __n0_ = 1;
81+
__w0_ = __w_;
82+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
83+
__y0_ = __y1_ = _Rp & ~__mask0_;
84+
} else {
85+
__n_ = (__w_ + __m - 1) / __m;
8786
__w0_ = __w_ / __n_;
88-
if (__w0_ < _WDt)
89-
__y0_ = (_Rp >> __w0_) << __w0_;
90-
else
91-
__y0_ = 0;
87+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
88+
__y0_ = __y1_ = _Rp & ~__mask0_;
89+
if _LIBCPP_CONSTEXPR_SINCE_CXX17 ((_Rp & (_Rp - 1)) != 0) {
90+
if (_Rp - __y0_ > __y0_ / __n_) {
91+
++__n_;
92+
__w0_ = __w_ / __n_;
93+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
94+
__y0_ = __y1_ = _Rp & ~__mask0_;
95+
}
96+
}
97+
size_t __n1 = __w_ % __n_;
98+
__n0_ = __n_ - __n1;
99+
if (__n1 > 0) {
100+
__mask1_ = ~_Engine_result_type(0) >> (_EDt - (__w0_ + 1));
101+
__y1_ = _Rp & ~__mask1_;
102+
}
92103
}
93-
__n0_ = __n_ - __w_ % __n_;
94-
if (__w0_ < _WDt - 1)
95-
__y1_ = (_Rp >> (__w0_ + 1)) << (__w0_ + 1);
96-
else
97-
__y1_ = 0;
98-
__mask0_ = __w0_ > 0 ? _Engine_result_type(~0) >> (_EDt - __w0_) : _Engine_result_type(0);
99-
__mask1_ = __w0_ < _EDt - 1 ? _Engine_result_type(~0) >> (_EDt - (__w0_ + 1)) : _Engine_result_type(~0);
100104
}
101105

102106
template <class _Engine, class _UIntType>
103107
inline _UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(false_type) {
104-
return static_cast<result_type>(__e_() & __mask0_);
108+
result_type __sp = (__e_() - _Engine::min()) & __mask0_;
109+
for (size_t __k = 1; __k < __n0_; ++__k) {
110+
__sp <<= __w0_;
111+
__sp += (__e_() - _Engine::min()) & __mask0_;
112+
}
113+
for (size_t __k = __n0_; __k < __n_; ++__k) {
114+
__sp <<= __w0_ + 1;
115+
__sp += (__e_() - _Engine::min()) & __mask1_;
116+
}
117+
return __sp;
105118
}
106119

107120
template <class _Engine, class _UIntType>
108121
_UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(true_type) {
109-
const size_t __w_rt = numeric_limits<result_type>::digits;
110-
result_type __sp = 0;
111-
for (size_t __k = 0; __k < __n0_; ++__k) {
122+
result_type __sp;
123+
{
124+
_Engine_result_type __u;
125+
do {
126+
__u = __e_() - _Engine::min();
127+
} while (__u >= __y0_);
128+
__sp = __u & __mask0_;
129+
}
130+
for (size_t __k = 1; __k < __n0_; ++__k) {
112131
_Engine_result_type __u;
113132
do {
114133
__u = __e_() - _Engine::min();
115134
} while (__u >= __y0_);
116-
if (__w0_ < __w_rt)
117-
__sp <<= __w0_;
118-
else
119-
__sp = 0;
135+
__sp <<= __w0_;
120136
__sp += __u & __mask0_;
121137
}
122138
for (size_t __k = __n0_; __k < __n_; ++__k) {
123139
_Engine_result_type __u;
124140
do {
125141
__u = __e_() - _Engine::min();
126142
} while (__u >= __y1_);
127-
if (__w0_ < __w_rt - 1)
128-
__sp <<= __w0_ + 1;
129-
else
130-
__sp = 0;
143+
__sp <<= __w0_ + 1;
131144
__sp += __u & __mask1_;
132145
}
133146
return __sp;
@@ -218,9 +231,9 @@ typename uniform_int_distribution<_IntType>::result_type uniform_int_distributio
218231
typedef __independent_bits_engine<_URNG, _UIntType> _Eng;
219232
if (__rp == 0)
220233
return static_cast<result_type>(_Eng(__g, __dt)());
221-
size_t __w = __dt - std::__countl_zero(__rp) - 1;
222-
if ((__rp & (numeric_limits<_UIntType>::max() >> (__dt - __w))) != 0)
223-
++__w;
234+
size_t __w = __dt - std::__countl_zero(__rp);
235+
if ((__rp & (__rp - 1)) == 0)
236+
return static_cast<result_type>(_Eng(__g, __w - 1)() + __p.a());
224237
_Eng __e(__g, __w);
225238
_UIntType __u;
226239
do {

0 commit comments

Comments
 (0)