Skip to content

Commit a16c72f

Browse files
committed
Optimize uniform_int_distribution
1 parent 198f080 commit a16c72f

File tree

2 files changed: +52 -39 lines changed

2 files changed: +52 -39 lines changed

libcxx/include/__random/uniform_int_distribution.h

Lines changed: 48 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ class __independent_bits_engine {
6464
_LIBCPP_HIDE_FROM_ABI __independent_bits_engine(_Engine& __e, size_t __w);
6565

6666
// generating functions
67-
_LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant<bool, _Rp != 0>()); }
67+
_LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant<bool, (_Rp & (_Rp - 1)) != 0>()); }
6868

6969
private:
7070
_LIBCPP_HIDE_FROM_ABI result_type __eval(false_type);
@@ -74,60 +74,71 @@ class __independent_bits_engine {
7474
template <class _Engine, class _UIntType>
7575
__independent_bits_engine<_Engine, _UIntType>::__independent_bits_engine(_Engine& __e, size_t __w)
7676
: __e_(__e), __w_(__w) {
77-
__n_ = __w_ / __m + (__w_ % __m != 0);
78-
__w0_ = __w_ / __n_;
79-
if (_Rp == 0)
80-
__y0_ = _Rp;
81-
else if (__w0_ < _WDt)
82-
__y0_ = (_Rp >> __w0_) << __w0_;
83-
else
84-
__y0_ = 0;
85-
if (_Rp - __y0_ > __y0_ / __n_) {
86-
++__n_;
77+
if (__w_ <= __m) {
78+
__n_ = __n0_ = 1;
79+
__w0_ = __w_;
80+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
81+
__y0_ = __y1_ = _Rp & ~__mask0_;
82+
} else {
83+
__n_ = (__w_ + __m - 1) / __m;
8784
__w0_ = __w_ / __n_;
88-
if (__w0_ < _WDt)
89-
__y0_ = (_Rp >> __w0_) << __w0_;
90-
else
91-
__y0_ = 0;
85+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
86+
__y0_ = __y1_ = _Rp & ~__mask0_;
87+
if _LIBCPP_CONSTEXPR_SINCE_CXX17 ((_Rp & (_Rp - 1)) != 0) {
88+
if (_Rp - __y0_ > __y0_ / __n_) {
89+
++__n_;
90+
__w0_ = __w_ / __n_;
91+
__mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
92+
__y0_ = __y1_ = _Rp & ~__mask0_;
93+
}
94+
}
95+
size_t __n1 = __w_ % __n_;
96+
__n0_ = __n_ - __n1;
97+
if (__n1 > 0) {
98+
__mask1_ = ~_Engine_result_type(0) >> (_EDt - (__w0_ + 1));
99+
__y1_ = _Rp & ~__mask1_;
100+
}
92101
}
93-
__n0_ = __n_ - __w_ % __n_;
94-
if (__w0_ < _WDt - 1)
95-
__y1_ = (_Rp >> (__w0_ + 1)) << (__w0_ + 1);
96-
else
97-
__y1_ = 0;
98-
__mask0_ = __w0_ > 0 ? _Engine_result_type(~0) >> (_EDt - __w0_) : _Engine_result_type(0);
99-
__mask1_ = __w0_ < _EDt - 1 ? _Engine_result_type(~0) >> (_EDt - (__w0_ + 1)) : _Engine_result_type(~0);
100102
}
101103

102104
template <class _Engine, class _UIntType>
103105
inline _UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(false_type) {
104-
return static_cast<result_type>(__e_() & __mask0_);
106+
result_type __sp = (__e_() - _Engine::min()) & __mask0_;
107+
for (size_t __k = 1; __k < __n0_; ++__k) {
108+
__sp <<= __w0_;
109+
__sp += (__e_() - _Engine::min()) & __mask0_;
110+
}
111+
for (size_t __k = __n0_; __k < __n_; ++__k) {
112+
__sp <<= __w0_ + 1;
113+
__sp += (__e_() - _Engine::min()) & __mask1_;
114+
}
115+
return __sp;
105116
}
106117

107118
template <class _Engine, class _UIntType>
108119
_UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(true_type) {
109-
const size_t __w_rt = numeric_limits<result_type>::digits;
110-
result_type __sp = 0;
111-
for (size_t __k = 0; __k < __n0_; ++__k) {
120+
result_type __sp;
121+
{
122+
_Engine_result_type __u;
123+
do {
124+
__u = __e_() - _Engine::min();
125+
} while (__u >= __y0_);
126+
__sp = __u & __mask0_;
127+
}
128+
for (size_t __k = 1; __k < __n0_; ++__k) {
112129
_Engine_result_type __u;
113130
do {
114131
__u = __e_() - _Engine::min();
115132
} while (__u >= __y0_);
116-
if (__w0_ < __w_rt)
117-
__sp <<= __w0_;
118-
else
119-
__sp = 0;
133+
__sp <<= __w0_;
120134
__sp += __u & __mask0_;
121135
}
122136
for (size_t __k = __n0_; __k < __n_; ++__k) {
123137
_Engine_result_type __u;
124138
do {
125139
__u = __e_() - _Engine::min();
126140
} while (__u >= __y1_);
127-
if (__w0_ < __w_rt - 1)
128-
__sp <<= __w0_ + 1;
129-
else
130-
__sp = 0;
141+
__sp <<= __w0_ + 1;
131142
__sp += __u & __mask1_;
132143
}
133144
return __sp;
@@ -218,9 +229,9 @@ typename uniform_int_distribution<_IntType>::result_type uniform_int_distributio
218229
typedef __independent_bits_engine<_URNG, _UIntType> _Eng;
219230
if (__rp == 0)
220231
return static_cast<result_type>(_Eng(__g, __dt)());
221-
size_t __w = __dt - std::__countl_zero(__rp) - 1;
222-
if ((__rp & (numeric_limits<_UIntType>::max() >> (__dt - __w))) != 0)
223-
++__w;
232+
size_t __w = __dt - std::__countl_zero(__rp);
233+
if ((__rp & (__rp - 1)) == 0)
234+
return static_cast<result_type>(_Eng(__g, __w - 1)() + __p.a());
224235
_Eng __e(__g, __w);
225236
_UIntType __u;
226237
do {

libcxx/test/std/numerics/rand/rand.dist/rand.dist.uni/rand.dist.uni.int/output.pass.cpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,10 @@
1111
// <random>
1212

1313
#include <array>
14-
#include <random>
1514
#include <cassert>
15+
#include <cstddef>
16+
#include <cstdint>
17+
#include <random>
1618

1719
#include "test_macros.h"
1820

@@ -630,7 +632,7 @@ constexpr std::array<std::uint64_t, 256> mt19937_high_results = {
630632
61396520889854792,
631633
};
632634

633-
template <typename Eng, size_t Shift, typename Result, size_t N>
635+
template <typename Eng, std::size_t Shift, typename Result, std::size_t N>
634636
void test_results(const std::array<Result, N>& results) {
635637
Eng eng;
636638
std::uniform_int_distribution<Result> dist;

0 commit comments

Comments
 (0)