|
6 | 6 | import numba.np.unsafe.ndarray as numba_ndarray
|
7 | 7 | import numpy as np
|
8 | 8 | from numba import types
|
9 |
| -from numba.extending import overload |
| 9 | +from numba.extending import overload, overload_method, register_jitable |
| 10 | +from numba.np.random.distributions import random_beta, random_standard_gamma |
| 11 | +from numba.np.random.generator_methods import check_size, check_types, is_nonelike |
10 | 12 |
|
11 | 13 | import aesara.tensor.random.basic as aer
|
12 | 14 | from aesara.graph.basic import Apply
|
@@ -296,3 +298,78 @@ def dirichlet_rv(rng, size, dtype, alphas):
|
296 | 298 | return (rng, rng.dirichlet(alphas, size))
|
297 | 299 |
|
298 | 300 | return dirichlet_rv
|
| 301 | + |
| 302 | + |
| 303 | +@register_jitable |
| 304 | +def random_dirichlet(bitgen, alpha, size): |
| 305 | + """ |
| 306 | + This implementation is straight from ``numpy/random/_generator.pyx``. |
| 307 | + """ |
| 308 | + |
| 309 | + k = len(alpha) |
| 310 | + alpha_arr = np.asarray(alpha, dtype=np.float64) |
| 311 | + |
| 312 | + if np.any(np.less_equal(alpha_arr, 0)): |
| 313 | + raise ValueError("alpha <= 0") |
| 314 | + |
| 315 | + shape = size + (k,) |
| 316 | + |
| 317 | + diric = np.zeros(shape, np.float64) |
| 318 | + |
| 319 | + i = 0 |
| 320 | + totsize = diric.size |
| 321 | + |
| 322 | + if (k > 0) and (alpha_arr.max() < 0.1): |
| 323 | + alpha_csum_arr = np.empty_like(alpha_arr) |
| 324 | + csum = 0.0 |
| 325 | + for j in range(k - 1, -1, -1): |
| 326 | + csum += alpha_arr[j] |
| 327 | + alpha_csum_arr[j] = csum |
| 328 | + |
| 329 | + while i < totsize: |
| 330 | + acc = 1.0 |
| 331 | + for j in range(k - 1): |
| 332 | + v = random_beta(bitgen, alpha_arr[j], alpha_csum_arr[j + 1]) |
| 333 | + diric[i + j] = acc * v |
| 334 | + acc *= 1.0 - v |
| 335 | + diric[i + k - 1] = acc |
| 336 | + i = i + k |
| 337 | + |
| 338 | + else: |
| 339 | + while i < totsize: |
| 340 | + acc = 0.0 |
| 341 | + for j in range(k): |
| 342 | + diric[i + j] = random_standard_gamma(bitgen, alpha_arr[j]) |
| 343 | + acc = acc + diric[i + j] |
| 344 | + invacc = 1.0 / acc |
| 345 | + for j in range(k): |
| 346 | + diric[i + j] = diric[i + j] * invacc |
| 347 | + i = i + k |
| 348 | + |
| 349 | + return diric |
| 350 | + |
| 351 | + |
| 352 | +@overload_method(types.NumPyRandomGeneratorType, "dirichlet") |
| 353 | +def NumPyRandomGeneratorType_dirichlet(inst, alphas, size=None): |
| 354 | + check_types(alphas, [types.Array, types.List], "alphas") |
| 355 | + |
| 356 | + if isinstance(size, types.Omitted): |
| 357 | + size = size.value |
| 358 | + |
| 359 | + if is_nonelike(size): |
| 360 | + |
| 361 | + def impl(inst, alphas, size=None): |
| 362 | + return random_dirichlet(inst.bit_generator, alphas, ()) |
| 363 | + |
| 364 | + elif isinstance(size, (int, types.Integer)): |
| 365 | + |
| 366 | + def impl(inst, alphas, size=None): |
| 367 | + return random_dirichlet(inst.bit_generator, alphas, (size,)) |
| 368 | + |
| 369 | + else: |
| 370 | + check_size(size) |
| 371 | + |
| 372 | + def impl(inst, alphas, size=None): |
| 373 | + return random_dirichlet(inst.bit_generator, alphas, size) |
| 374 | + |
| 375 | + return impl |
0 commit comments