Skip to content

Commit 046f609

Browse files
authored
NPR Updated Parameters
1 parent bf95d41 commit 046f609

File tree

1 file changed

+5
-44
lines changed

1 file changed

+5
-44
lines changed

botorch_community/models/np_regression.py

Lines changed: 5 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -291,18 +291,13 @@ def __init__(
291291
self.input_transform = input_transform
292292

293293
def data_to_z_params(
294-
self,
295-
x: torch.Tensor,
296-
y: torch.Tensor,
297-
xy_dim: int = 1,
298-
r_dim: int = 0,
294+
self, x: torch.Tensor, y: torch.Tensor, r_dim: int = 0
299295
) -> Tuple[torch.Tensor, torch.Tensor]:
300296
r"""Compute latent parameters from inputs as a latent distribution.
301297
302298
Args:
303299
x: Input tensor
304300
y: Target tensor
305-
xy_dim: Combined Input Dimension as int, defaults as 1
306301
r_dim: Combined Target Dimension as int, defaults as 0.
307302
308303
Returns:
@@ -314,7 +309,7 @@ def data_to_z_params(
314309
"""
315310
x = x.to(device)
316311
y = y.to(device)
317-
xy = torch.cat([x, y], dim=xy_dim).to(device).to(device)
312+
xy = torch.cat([x, y], dim=-1).to(device).to(device)
318313
rs = self.r_encoder(xy)
319314
r_agg = rs.mean(dim=r_dim).to(device)
320315
return self.z_encoder(r_agg)
@@ -463,54 +458,20 @@ def forward(
463458
self.z_mu_all, self.z_logvar_all = self.data_to_z_params(
464459
self.train_X, self.train_Y
465460
)
466-
self.z_mu_context, self.z_logvar_context = self.data_to_z_params(
467-
x_c, y_c
468-
)
461+
self.z_mu_context, self.z_logvar_context = self.data_to_z_params(x_c, y_c)
469462
x_t = self.transform_inputs(x_t)
470463
return self.posterior(x_t).distribution
471464

472465
def random_split_context_target(
473-
self, x: torch.Tensor, y: torch.Tensor, n_context: int, axis: int
474-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
475-
r"""Helper function to split randomly into context and target.
476-
477-
Args:
478-
x: A `batch_shape x n x d` tensor of training features.
479-
y: A `batch_shape x n x m` tensor of training observations.
480-
n_context (int): Number of context points.
481-
axis: Dimension axis as int
482-
483-
Returns:
484-
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
485-
- x_c: Context input data.
486-
- y_c: Context target data.
487-
- x_t: Target input data.
488-
- y_t: Target target data.
489-
"""
490-
self.n_context = n_context
491-
mask = torch.randperm(x.shape[axis])[:n_context]
492-
x_c = torch.from_numpy(x[mask]).to(device)
493-
y_c = torch.from_numpy(y[mask]).to(device)
494-
splitter = torch.zeros(x.shape[axis], dtype=torch.bool)
495-
splitter[mask] = True
496-
x_t = torch.from_numpy(x[~splitter]).to(device)
497-
y_t = torch.from_numpy(y[~splitter]).to(device)
498-
return x_c, y_c, x_t, y_t
499-
500-
def random_split_context_target(
501-
self,
502-
x: torch.Tensor,
503-
y: torch.Tensor,
504-
n_context: int = 20,
505-
axis: int = 0
466+
self, x: torch.Tensor, y: torch.Tensor, n_context, axis: int = 0
506467
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
507468
r"""Helper function to split randomly into context and target.
508469
509470
Args:
510471
x: A `batch_shape x n x d` tensor of training features.
511472
y: A `batch_shape x n x m` tensor of training observations.
512473
n_context (int): Number of context points.
513-
axis: Dimension axis as int
474+
axis: Dimension axis as int, defaults to 0.
514475
515476
Returns:
516477
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

0 commit comments

Comments
 (0)