Skip to content

Commit 589260b

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Fix random seed handling in sample_hypersphere (#2688)
Summary: Pull Request resolved: #2688 Fixes #2685 The seed was being ignored when `d=1`. Reviewed By: esantorella Differential Revision: D68443701 fbshipit-source-id: 2c8a5828170c8b955a452a7446d775c171db1baf
1 parent d147bd5 commit 589260b

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

botorch/utils/sampling.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,16 +163,15 @@ def sample_hypersphere(
163163
>>> sample_hypersphere(d=5, n=10)
164164
"""
165165
if d == 1:
166-
rnd = torch.randint(0, 2, (n, 1), device=device, dtype=dtype)
166+
with manual_seed(seed=seed):
167+
rnd = torch.randint(0, 2, (n, 1), device=device, dtype=dtype)
167168
return 2 * rnd - 1
168169
if qmc:
169170
rnd = draw_sobol_normal_samples(d=d, n=n, device=device, dtype=dtype, seed=seed)
170171
else:
171172
with manual_seed(seed=seed):
172-
rnd = torch.randn(n, d, dtype=dtype)
173+
rnd = torch.randn(n, d, device=device, dtype=dtype)
173174
samples = rnd / torch.linalg.norm(rnd, dim=-1, keepdim=True)
174-
if device is not None:
175-
samples = samples.to(device)
176175
return samples
177176

178177

test/utils/test_sampling.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ def test_draw_sobol_samples(self):
9393
self.assertTrue(torch.all(samples <= bounds[1]))
9494
self.assertEqual(samples.device.type, self.device.type)
9595
self.assertEqual(samples.dtype, dtype)
96+
if seed is not None:
97+
# Check that seed reproduces the same samples.
98+
samples2 = draw_sobol_samples(
99+
bounds=bounds, n=n, q=q, batch_shape=batch_shape, seed=seed
100+
)
101+
self.assertTrue(torch.equal(samples, samples2))
96102

97103
def test_sample_simplex(self):
98104
for d, n, qmc, seed, dtype in itertools.product(
@@ -107,6 +113,12 @@ def test_sample_simplex(self):
107113
self.assertTrue(torch.max((samples.sum(dim=-1) - 1).abs()) < 1e-5)
108114
self.assertEqual(samples.device.type, self.device.type)
109115
self.assertEqual(samples.dtype, dtype)
116+
if seed is not None:
117+
# Check that seed reproduces the same samples.
118+
samples2 = sample_simplex(
119+
d=d, n=n, qmc=qmc, seed=seed, device=self.device, dtype=dtype
120+
)
121+
self.assertTrue(torch.equal(samples, samples2))
110122

111123
def test_sample_hypersphere(self):
112124
for d, n, qmc, seed, dtype in itertools.product(
@@ -119,6 +131,12 @@ def test_sample_hypersphere(self):
119131
self.assertTrue(torch.max((samples.pow(2).sum(dim=-1) - 1).abs()) < 1e-5)
120132
self.assertEqual(samples.device.type, self.device.type)
121133
self.assertEqual(samples.dtype, dtype)
134+
if seed is not None:
135+
# Check that seed reproduces the same samples.
136+
samples2 = sample_hypersphere(
137+
d=d, n=n, qmc=qmc, seed=seed, device=self.device, dtype=dtype
138+
)
139+
self.assertTrue(torch.equal(samples, samples2))
122140

123141
def test_batched_multinomial(self):
124142
num_categories = 5

0 commit comments

Comments
 (0)