Skip to content

Commit 9ea7a99

Browse files
authored
expose prob as positional argument for bernoulli ops (#1949)
1 parent 6c357b8 commit 9ea7a99

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

mindnlp/core/ops/random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
# bernoulli
1313
has_bernoulli = hasattr(mindspore.mint, 'bernoulli')
14-
def bernoulli(input, *, generator=None):
14+
def bernoulli(input, *, generator=None, p=0.5):
1515
if use_pyboost() and has_bernoulli:
1616
return mindspore.mint.bernoulli(input, generator=generator)
1717
random_numbers = rand(*input.shape, dtype=mindspore.float32)
18-
samples = random_numbers < 0.5
18+
samples = random_numbers < p
1919
samples = samples.int()
2020
return samples
2121

0 commit comments

Comments
 (0)