|
1 | 1 | import torch |
2 | 2 |
|
3 | 3 |
|
4 | | -def latin_hypercube_sampling_standard(n: int, d: int, device: torch.device, scramble: bool = True): |
| 4 | +def latin_hypercube_sampling_standard(n: int, d: int, device: torch.device, smooth: bool = True): |
5 | 5 | """Generate Latin Hypercube samples in the unit hypercube. |
6 | 6 |
|
7 | 7 | :param n: The number of sample points to generate. |
8 | 8 | :param d: The dimensionality of the samples. |
9 | 9 | :param device: The device on which to generate the samples. |
10 | | - :param scramble: Whether to scramble the order of the samples. Defaults to True. |
| 10 | + :param smooth: Whether to generate sample in random positions in the cells or not. Defaults to True. |
11 | 11 |
|
12 | 12 | :return: A tensor of shape (n, d), where each row represents a sample point and each column represents a dimension. |
13 | 13 | """ |
14 | | - samples = (torch.arange(0, n, device=device).view(-1, 1) + torch.rand(n, d, device=device)) / n |
15 | | - if scramble: |
16 | | - samples = samples[torch.randperm(n)] |
| 14 | + cells = torch.arange(0, n, device=device).view(-1, 1).expand(n, d).contiguous() |
| 15 | + cells_perms = torch.rand(n, d, device=device).argsort(dim=0) |
| 16 | + cells = cells.gather(0, cells_perms) |
| 17 | + if smooth: |
| 18 | + samples = (cells + torch.rand(n, d, device=device)) / n |
| 19 | + else: |
| 20 | + samples = (cells + 0.5) / n |
17 | 21 | return samples |
18 | 22 |
|
19 | 23 |
|
20 | | -def latin_hypercube_sampling(n: int, d: int, lb: torch.Tensor, ub: torch.Tensor, scramble: bool = True): |
| 24 | +def latin_hypercube_sampling(n: int, lb: torch.Tensor, ub: torch.Tensor, smooth: bool = True): |
21 | 25 | """Generate Latin Hypercube samples in the given hypercube defined by `lb` and `ub`. |
22 | 26 |
|
23 | 27 | :param n: The number of sample points to generate. |
24 | | - :param d: The dimensionality of the samples. |
25 | | - :param lb: The lower bounds of the hypercube. Must be a 1D tensor with same shape, dtype, and device as `ub`. |
26 | | - :param ub: The upper bounds of the hypercube. Must be a 1D tensor with same shape, dtype, and device as `lb`. |
27 | | - :param scramble: Whether to scramble the order of the samples. Defaults to True. |
| 28 | + :param lb: The lower bounds of the hypercube. Must be a 1D tensor of size `d` with same shape, dtype, and device as `ub`. |
| 29 | + :param ub: The upper bounds of the hypercube. Must be a 1D tensor of size `d` with same shape, dtype, and device as `lb`. |
| 30 | + :param smooth: Whether to generate sample in random positions in the cells or not. Defaults to True. |
28 | 31 |
|
29 | 32 | :return: A tensor of shape (n, d), where each row represents a sample point and each column represents a dimension whose device is the same as `lb` and `ub`. |
30 | 33 | """ |
31 | 34 | assert lb.device == ub.device and lb.dtype == ub.dtype and lb.ndim == 1 and ub.ndim == 1 and lb.size() == ub.size() |
32 | | - samples = latin_hypercube_sampling_standard(n, d, lb.device, scramble) |
| 35 | + samples = latin_hypercube_sampling_standard(n, lb.size(0), lb.device, smooth) |
33 | 36 | lb = lb[None, :] |
34 | 37 | ub = ub[None, :] |
35 | 38 | return lb + samples * (ub - lb) |
0 commit comments