Skip to content

Commit 8da6847

Browse files
Add a Numba implementation for Generator.dirichlet
1 parent 2d84709 commit 8da6847

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

aesara/link/numba/dispatch/random.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import numba.np.unsafe.ndarray as numba_ndarray
77
import numpy as np
88
from numba import types
9-
from numba.extending import overload
9+
from numba.extending import overload, overload_method, register_jitable
10+
from numba.np.random.distributions import random_beta, random_standard_gamma
11+
from numba.np.random.generator_methods import check_size, check_types, is_nonelike
1012

1113
import aesara.tensor.random.basic as aer
1214
from aesara.graph.basic import Apply
@@ -296,3 +298,78 @@ def dirichlet_rv(rng, size, dtype, alphas):
296298
return (rng, rng.dirichlet(alphas, size))
297299

298300
return dirichlet_rv
301+
302+
303+
@register_jitable
304+
def random_dirichlet(bitgen, alpha, size):
305+
"""
306+
This implementation is straight from ``numpy/random/_generator.pyx``.
307+
"""
308+
309+
k = len(alpha)
310+
alpha_arr = np.asarray(alpha, dtype=np.float64)
311+
312+
if np.any(np.less_equal(alpha_arr, 0)):
313+
raise ValueError("alpha <= 0")
314+
315+
shape = size + (k,)
316+
317+
diric = np.zeros(shape, np.float64)
318+
319+
i = 0
320+
totsize = diric.size
321+
322+
if (k > 0) and (alpha_arr.max() < 0.1):
323+
alpha_csum_arr = np.empty_like(alpha_arr)
324+
csum = 0.0
325+
for j in range(k - 1, -1, -1):
326+
csum += alpha_arr[j]
327+
alpha_csum_arr[j] = csum
328+
329+
while i < totsize:
330+
acc = 1.0
331+
for j in range(k - 1):
332+
v = random_beta(bitgen, alpha_arr[j], alpha_csum_arr[j + 1])
333+
diric[i + j] = acc * v
334+
acc *= 1.0 - v
335+
diric[i + k - 1] = acc
336+
i = i + k
337+
338+
else:
339+
while i < totsize:
340+
acc = 0.0
341+
for j in range(k):
342+
diric[i + j] = random_standard_gamma(bitgen, alpha_arr[j])
343+
acc = acc + diric[i + j]
344+
invacc = 1.0 / acc
345+
for j in range(k):
346+
diric[i + j] = diric[i + j] * invacc
347+
i = i + k
348+
349+
return diric
350+
351+
352+
@overload_method(types.NumPyRandomGeneratorType, "dirichlet")
353+
def NumPyRandomGeneratorType_dirichlet(inst, alphas, size=None):
354+
check_types(alphas, [types.Array, types.List], "alphas")
355+
356+
if isinstance(size, types.Omitted):
357+
size = size.value
358+
359+
if is_nonelike(size):
360+
361+
def impl(inst, alphas, size=None):
362+
return random_dirichlet(inst.bit_generator, alphas, ())
363+
364+
elif isinstance(size, (int, types.Integer)):
365+
366+
def impl(inst, alphas, size=None):
367+
return random_dirichlet(inst.bit_generator, alphas, (size,))
368+
369+
else:
370+
check_size(size)
371+
372+
def impl(inst, alphas, size=None):
373+
return random_dirichlet(inst.bit_generator, alphas, size)
374+
375+
return impl

tests/link/numba/test_random.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,6 @@ def test_CategoricalRV(dist_args, size, cm):
520520
)
521521

522522

523-
@pytest.mark.skip(reason="Not yet supported in Numba via `Generator`s")
524523
@pytest.mark.parametrize(
525524
"a, size, cm",
526525
[

0 commit comments

Comments
 (0)