diff --git a/libcxx/include/__random/uniform_int_distribution.h b/libcxx/include/__random/uniform_int_distribution.h index fa2c33755b739..2d06808cda5ab 100644 --- a/libcxx/include/__random/uniform_int_distribution.h +++ b/libcxx/include/__random/uniform_int_distribution.h @@ -9,6 +9,7 @@ #ifndef _LIBCPP___RANDOM_UNIFORM_INT_DISTRIBUTION_H #define _LIBCPP___RANDOM_UNIFORM_INT_DISTRIBUTION_H +#include <__assert> #include <__bit/countl.h> #include <__config> #include <__cstddef/size_t.h> @@ -64,7 +65,7 @@ class __independent_bits_engine { _LIBCPP_HIDE_FROM_ABI __independent_bits_engine(_Engine& __e, size_t __w); // generating functions - _LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant()); } + _LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant()); } private: _LIBCPP_HIDE_FROM_ABI result_type __eval(false_type); @@ -74,49 +75,66 @@ class __independent_bits_engine { template __independent_bits_engine<_Engine, _UIntType>::__independent_bits_engine(_Engine& __e, size_t __w) : __e_(__e), __w_(__w) { - __n_ = __w_ / __m + (__w_ % __m != 0); - __w0_ = __w_ / __n_; - if (_Rp == 0) - __y0_ = _Rp; - else if (__w0_ < _WDt) - __y0_ = (_Rp >> __w0_) << __w0_; - else - __y0_ = 0; - if (_Rp - __y0_ > __y0_ / __n_) { - ++__n_; + _LIBCPP_ASSERT_INTERNAL( + __w_ <= numeric_limits::digits, "cannot sample more bits than result_type can hold"); + _LIBCPP_ASSERT_INTERNAL(__w_ > 0, "must sample a positive number of bits"); + if (__w_ <= __m) { + __n_ = __n0_ = 1; + __w0_ = __w_; + __mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_); + __y0_ = __y1_ = _Rp & ~__mask0_; + } else { + __n_ = (__w_ + __m - 1) / __m; __w0_ = __w_ / __n_; - if (__w0_ < _WDt) - __y0_ = (_Rp >> __w0_) << __w0_; - else - __y0_ = 0; + __mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_); + __y0_ = __y1_ = _Rp & ~__mask0_; + if _LIBCPP_CONSTEXPR_SINCE_CXX17 ((_Rp & (_Rp - 1)) != 0) { + if (_Rp - __y0_ > __y0_ / __n_) { + ++__n_; + __w0_ 
= __w_ / __n_; + __mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_); + __y0_ = __y1_ = _Rp & ~__mask0_; + } + } + size_t __n1 = __w_ % __n_; + __n0_ = __n_ - __n1; + if (__n1 > 0) { + __mask1_ = ~_Engine_result_type(0) >> (_EDt - (__w0_ + 1)); + __y1_ = _Rp & ~__mask1_; + } } - __n0_ = __n_ - __w_ % __n_; - if (__w0_ < _WDt - 1) - __y1_ = (_Rp >> (__w0_ + 1)) << (__w0_ + 1); - else - __y1_ = 0; - __mask0_ = __w0_ > 0 ? _Engine_result_type(~0) >> (_EDt - __w0_) : _Engine_result_type(0); - __mask1_ = __w0_ < _EDt - 1 ? _Engine_result_type(~0) >> (_EDt - (__w0_ + 1)) : _Engine_result_type(~0); } template inline _UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(false_type) { - return static_cast(__e_() & __mask0_); + result_type __sp = (__e_() - _Engine::min()) & __mask0_; + for (size_t __k = 1; __k < __n0_; ++__k) { + __sp <<= __w0_; + __sp += (__e_() - _Engine::min()) & __mask0_; + } + for (size_t __k = __n0_; __k < __n_; ++__k) { + __sp <<= __w0_ + 1; + __sp += (__e_() - _Engine::min()) & __mask1_; + } + return __sp; } template _UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(true_type) { - const size_t __w_rt = numeric_limits::digits; - result_type __sp = 0; - for (size_t __k = 0; __k < __n0_; ++__k) { + result_type __sp; + { + _Engine_result_type __u; + do { + __u = __e_() - _Engine::min(); + } while (__u >= __y0_); + __sp = __u & __mask0_; + } + for (size_t __k = 1; __k < __n0_; ++__k) { _Engine_result_type __u; do { __u = __e_() - _Engine::min(); } while (__u >= __y0_); - if (__w0_ < __w_rt) - __sp <<= __w0_; - else - __sp = 0; + __sp <<= __w0_; __sp += __u & __mask0_; } for (size_t __k = __n0_; __k < __n_; ++__k) { @@ -124,10 +142,7 @@ _UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(true_type) { do { __u = __e_() - _Engine::min(); } while (__u >= __y1_); - if (__w0_ < __w_rt - 1) - __sp <<= __w0_ + 1; - else - __sp = 0; + __sp <<= __w0_ + 1; __sp += __u & __mask1_; } return __sp; @@ -218,9 +233,9 @@ 
typename uniform_int_distribution<_IntType>::result_type uniform_int_distributio
   typedef __independent_bits_engine<_URNG, _UIntType> _Eng;
   if (__rp == 0)
     return static_cast(_Eng(__g, __dt)());
-  size_t __w = __dt - std::__countl_zero(__rp) - 1;
-  if ((__rp & (numeric_limits<_UIntType>::max() >> (__dt - __w))) != 0)
-    ++__w;
+  size_t __w = __dt - std::__countl_zero(__rp); // __dt minus the leading zeros of __rp (presumably __rp's bit width; __dt is defined outside this hunk)
+  if ((__rp & (__rp - 1)) == 0) // __rp has exactly one bit set here (__rp == 0 already returned above)
+    return static_cast(_Eng(__g, __w - 1)() + __p.a()); // single __w - 1 bit sample offset by a(), bypassing the do-loop below. NOTE(review): needs __w > 1, i.e. __rp == 1 handled before this hunk - confirm
   _Eng __e(__g, __w);
   _UIntType __u;
   do {
diff --git a/libcxx/test/benchmarks/numeric/rand.uni.int.bench.cpp b/libcxx/test/benchmarks/numeric/rand.uni.int.bench.cpp
new file mode 100644
index 0000000000000..eb9a76835853d
--- /dev/null
+++ b/libcxx/test/benchmarks/numeric/rand.uni.int.bench.cpp
@@ -0,0 +1,58 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// UNSUPPORTED: c++03
+
+#include
+#include
+
+#include
+
+template
+static void bm_uniform_int_distribution(benchmark::State& state) { // times one dist(eng) call per benchmark iteration
+  Eng eng; // default-constructed engine
+  std::uniform_int_distribution dist(1ull, Max); // distribution constructed with bounds 1 and Max
+  for (auto _ : state) {
+    benchmark::DoNotOptimize(dist(eng)); // keep the sampled value observable so the call is not optimized away
+  }
+}
+
+// n = 1
+// Best Case
+BENCHMARK(bm_uniform_int_distribution);
+BENCHMARK(bm_uniform_int_distribution);
+// Worst Case
+BENCHMARK(bm_uniform_int_distribution);
+BENCHMARK(bm_uniform_int_distribution);
+// Median Case
+BENCHMARK(bm_uniform_int_distribution);
+BENCHMARK(bm_uniform_int_distribution);
+
+// n = 2, n0 = 2
+// Best Case
+BENCHMARK(bm_uniform_int_distribution);
+BENCHMARK(bm_uniform_int_distribution);
+// Worst Case
+BENCHMARK(bm_uniform_int_distribution);
+BENCHMARK(bm_uniform_int_distribution);
+// Median Case
+BENCHMARK(bm_uniform_int_distribution);
+BENCHMARK(bm_uniform_int_distribution);
+
+// n = 2, n0 = 1
+// Best Case
+BENCHMARK(bm_uniform_int_distribution);
+BENCHMARK(bm_uniform_int_distribution);
+// Worst Case
+BENCHMARK(bm_uniform_int_distribution);
+BENCHMARK(bm_uniform_int_distribution);
+// Median Case
+BENCHMARK(bm_uniform_int_distribution);
+BENCHMARK(bm_uniform_int_distribution);
+
+BENCHMARK_MAIN();