Skip to content

Commit 64b1aef

Browse files
committed
fix bug in function latin_hypercube_sampling_standard.
1 parent e4ab098 commit 64b1aef

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed
Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,38 @@
11
import torch
22

33

4-
def latin_hypercube_sampling_standard(n: int, d: int, device: torch.device, scramble: bool = True):
4+
def latin_hypercube_sampling_standard(n: int, d: int, device: torch.device, smooth: bool = True):
55
"""Generate Latin Hypercube samples in the unit hypercube.
66
77
:param n: The number of sample points to generate.
88
:param d: The dimensionality of the samples.
99
:param device: The device on which to generate the samples.
10-
:param scramble: Whether to scramble the order of the samples. Defaults to True.
10+
:param smooth: Whether to generate sample in random positions in the cells or not. Defaults to True.
1111
1212
:return: A tensor of shape (n, d), where each row represents a sample point and each column represents a dimension.
1313
"""
14-
samples = (torch.arange(0, n, device=device).view(-1, 1) + torch.rand(n, d, device=device)) / n
15-
if scramble:
16-
samples = samples[torch.randperm(n)]
14+
cells = torch.arange(0, n, device=device).view(-1, 1).expand(n, d).contiguous()
15+
cells_perms = torch.rand(n, d, device=device).argsort(dim=0)
16+
cells = cells.gather(0, cells_perms)
17+
if smooth:
18+
samples = (cells + torch.rand(n, d, device=device)) / n
19+
else:
20+
samples = (cells + 0.5) / n
1721
return samples
1822

1923

20-
def latin_hypercube_sampling(n: int, d: int, lb: torch.Tensor, ub: torch.Tensor, scramble: bool = True):
24+
def latin_hypercube_sampling(n: int, lb: torch.Tensor, ub: torch.Tensor, smooth: bool = True):
2125
"""Generate Latin Hypercube samples in the given hypercube defined by `lb` and `ub`.
2226
2327
:param n: The number of sample points to generate.
24-
:param d: The dimensionality of the samples.
25-
:param lb: The lower bounds of the hypercube. Must be a 1D tensor with same shape, dtype, and device as `ub`.
26-
:param ub: The upper bounds of the hypercube. Must be a 1D tensor with same shape, dtype, and device as `lb`.
27-
:param scramble: Whether to scramble the order of the samples. Defaults to True.
28+
:param lb: The lower bounds of the hypercube. Must be a 1D tensor of size `d` with same shape, dtype, and device as `ub`.
29+
:param ub: The upper bounds of the hypercube. Must be a 1D tensor of size `d` with same shape, dtype, and device as `lb`.
30+
:param smooth: Whether to generate sample in random positions in the cells or not. Defaults to True.
2831
2932
:return: A tensor of shape (n, d), where each row represents a sample point and each column represents a dimension whose device is the same as `lb` and `ub`.
3033
"""
3134
assert lb.device == ub.device and lb.dtype == ub.dtype and lb.ndim == 1 and ub.ndim == 1 and lb.size() == ub.size()
32-
samples = latin_hypercube_sampling_standard(n, d, lb.device, scramble)
35+
samples = latin_hypercube_sampling_standard(n, lb.size(0), lb.device, smooth)
3336
lb = lb[None, :]
3437
ub = ub[None, :]
3538
return lb + samples * (ub - lb)

0 commit comments

Comments
 (0)