Skip to content

Commit fd93d21

Browse files
committed
changed check_positive to check_nonnegative in RNGs for multinomial and multinomial_logit. Updated comments and docstrings, added unit tests
1 parent 9202f1f commit fd93d21

File tree

5 files changed

+50
-7
lines changed

5 files changed

+50
-7
lines changed

stan/math/prim/prob/dirichlet_multinomial_rng.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ inline std::vector<int> dirichlet_multinomial_rng(
4141
check_positive_finite(function, "prior size variable", alpha_ref);
4242
check_nonnegative(function, "number of trials variables", N);
4343

44-
// special case N = 0 would lead to an exception thrown by multinomial_rng
44+
// special case N = 0
4545
if (N == 0) {
4646
return std::vector<int>(alpha.size(), 0);
4747
}

stan/math/prim/prob/multinomial_logit_rng.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@ namespace math {
1313

1414
/** \ingroup multivar_dists
1515
* Return a draw from a Multinomial distribution given a
16-
* a vector of unnormalized log probabilities and a pseudo-random
17-
* number generator.
16+
* vector of unnormalized log probabilities, a total count,
17+
* and a pseudo-random number generator.
1818
*
1919
* @tparam RNG Type of pseudo-random number generator.
2020
* @param beta Vector of unnormalized log probabilities.
21-
* @param N Total count
21+
* @param N Total count.
2222
* @param rng Pseudo-random number generator.
23-
* @return Multinomial random variate
23+
* @return Multinomial random variate.
24+
* @throw std::domain_error if any element of beta is not finite.
25+
* @throw std::domain_error is N is less than 0.
2426
*/
2527
template <class RNG, typename T_beta,
2628
require_eigen_col_vector_t<T_beta>* = nullptr>
@@ -29,7 +31,7 @@ inline std::vector<int> multinomial_logit_rng(const T_beta& beta, int N,
2931
static constexpr const char* function = "multinomial_logit_rng";
3032
const auto& beta_ref = to_ref(beta);
3133
check_finite(function, "Log-probabilities parameter", beta_ref);
32-
check_positive(function, "number of trials variables", N);
34+
check_nonnegative(function, "number of trials variables", N);
3335

3436
plain_type_t<T_beta> theta = softmax(beta_ref);
3537
std::vector<int> result(theta.size(), 0);

stan/math/prim/prob/multinomial_rng.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,26 @@
1010
namespace stan {
1111
namespace math {
1212

13+
/** \ingroup multivar_dists
14+
* Return a draw from a Multinomial distribution given a
15+
* probability simplex, a total count, and a pseudo-random
16+
* number generator.
17+
*
18+
* @tparam RNG Type of pseudo-random number generator.
19+
* @param theta Vector of normalized probabilities.
20+
* @param N Total count.
21+
* @param rng Pseudo-random number generator.
22+
* @return Multinomial random variate.
23+
* @throw std::domain_error if theta is not a simplex.
24+
* @throw std::domain_error is N is less than 0.
25+
*/
1326
template <class T_theta, class RNG,
1427
require_eigen_col_vector_t<T_theta>* = nullptr>
1528
inline std::vector<int> multinomial_rng(const T_theta& theta, int N, RNG& rng) {
1629
static constexpr const char* function = "multinomial_rng";
1730
const auto& theta_ref = to_ref(theta);
1831
check_simplex(function, "Probabilities parameter", theta_ref);
19-
check_positive(function, "number of trials variables", N);
32+
check_nonnegative(function, "number of trials variables", N);
2033

2134
std::vector<int> result(theta.size(), 0);
2235
double mass_left = 1.0;

test/unit/math/prim/prob/multinomial_logit_test.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,19 @@
88
using Eigen::Dynamic;
99
using Eigen::Matrix;
1010

11+
TEST(ProbDistributionsMultinomialLogit, RNGZero) {
12+
boost::random::mt19937 rng;
13+
Matrix<double, Dynamic, 1> beta(3);
14+
beta << 1.3, 0.1, -2.6;
15+
// bug in 4.8.1: RNG does not allow a zero total count
16+
EXPECT_NO_THROW(stan::math::multinomial_logit_rng(beta, 0, rng));
17+
// when the total count is zero, the sample should be a zero array
18+
std::vector<int> sample = stan::math::multinomial_logit_rng(beta, 0, rng);
19+
for (int k : sample) {
20+
EXPECT_EQ(0, k);
21+
}
22+
}
23+
1124
TEST(ProbDistributionsMultinomialLogit, RNGSize) {
1225
boost::random::mt19937 rng;
1326
Matrix<double, Dynamic, 1> beta(5);

test/unit/math/prim/prob/multinomial_test.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,21 @@
55
#include <limits>
66
#include <vector>
77

8+
TEST(ProbDistributionsMultinomial, RNGZero) {
9+
using Eigen::Dynamic;
10+
using Eigen::Matrix;
11+
boost::random::mt19937 rng;
12+
Matrix<double, Dynamic, 1> theta(3);
13+
theta << 0.3, 0.1, 0.6;
14+
// bug in 4.8.1: RNG does not allow a zero total count
15+
EXPECT_NO_THROW(stan::math::multinomial_rng(theta, 0, rng));
16+
// when the total count is zero, the sample should be a zero array
17+
std::vector<int> sample = stan::math::multinomial_rng(theta, 0, rng);
18+
for (int k : sample) {
19+
EXPECT_EQ(0, k);
20+
}
21+
}
22+
823
TEST(ProbDistributionsMultinomial, RNGSize) {
924
using Eigen::Dynamic;
1025
using Eigen::Matrix;

0 commit comments

Comments
 (0)