Skip to content

Commit 7232725

Browse files
authored
Updated NPR Parameters
1 parent 8e33fc4 commit 7232725

File tree

1 file changed

+40
-6
lines changed

1 file changed

+40
-6
lines changed

botorch_community/models/np_regression.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def __init__(
231231
y_dim: int = 1,
232232
r_dim: int = 64,
233233
z_dim: int = 8,
234+
n_context: int = 20,
234235
activation: Callable = nn.Sigmoid,
235236
init_func: Optional[Callable] = torch.nn.init.normal_,
236237
likelihood: Likelihood | None = None,
@@ -248,6 +249,7 @@ def __init__(
248249
y_dim: Int dimensionality of target data y.
249250
r_dim: Int dimensionality of representation r.
250251
z_dim: Int dimensionality of latent variable z.
252+
n_context (int): Number of context points, defaults to 20.
251253
activation: Activation function applied between layers, defaults to nn.
252254
Sigmoid.
253255
init_func: A function initializing the weights,
@@ -276,6 +278,7 @@ def __init__(
276278
).to(device)
277279
self.train_X = train_X.to(device)
278280
self.train_Y = train_Y.to(device)
281+
self.n_context = n_context
279282
self.z_dim = z_dim
280283
self.z_mu_all = None
281284
self.z_logvar_all = None
@@ -430,40 +433,38 @@ def transform_inputs(
430433
return input_transform(X)
431434
try:
432435
return self.input_transform(X)
433-
except AttributeError:
436+
except (AttributeError, TypeError):
434437
return X
435438

436439
def forward(
437440
self,
438441
train_X: torch.Tensor,
439442
train_Y: torch.Tensor,
440-
n_context: int,
441443
axis: int = 0,
442444
) -> MultivariateNormal:
443445
r"""Forward pass for the model.
444446
445447
Args:
446448
train_X: A `batch_shape x n x d` tensor of training features.
447449
train_Y: A `batch_shape x n x m` tensor of training observations.
448-
n_context (int): Number of context points.
449450
axis: Dimension axis as int, defaulted as 0.
450451
451452
Returns:
452453
MultivariateNormal: Predicted target distribution.
453454
"""
454455
train_X = self.transform_inputs(train_X)
455456
x_c, y_c, x_t, y_t = self.random_split_context_target(
456-
train_X, train_Y, n_context, axis=axis
457+
train_X, train_Y, self.n_context, axis=axis
457458
)
458459
x_t = x_t.to(device)
459460
x_c = x_c.to(device)
460461
y_c = y_c.to(device)
461462
y_t = y_t.to(device)
462463
self.z_mu_all, self.z_logvar_all = self.data_to_z_params(
463-
self.train_X, self.train_Y, dim=axis
464+
self.train_X, self.train_Y
464465
)
465466
self.z_mu_context, self.z_logvar_context = self.data_to_z_params(
466-
x_c, y_c, dim=axis
467+
x_c, y_c
467468
)
468469
x_t = self.transform_inputs(x_t)
469470
return self.posterior(x_t).distribution
@@ -486,6 +487,7 @@ def random_split_context_target(
486487
- x_t: Target input data.
487488
- y_t: Target target data.
488489
"""
490+
self.n_context = n_context
489491
mask = torch.randperm(x.shape[axis])[:n_context]
490492
x_c = torch.from_numpy(x[mask]).to(device)
491493
y_c = torch.from_numpy(y[mask]).to(device)
@@ -494,3 +496,35 @@ def random_split_context_target(
494496
x_t = torch.from_numpy(x[~splitter]).to(device)
495497
y_t = torch.from_numpy(y[~splitter]).to(device)
496498
return x_c, y_c, x_t, y_t
499+
500+
def random_split_context_target(
501+
self,
502+
x: torch.Tensor,
503+
y: torch.Tensor,
504+
n_context: int = 20,
505+
axis: int = 0
506+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
507+
r"""Helper function to split randomly into context and target.
508+
509+
Args:
510+
x: A `batch_shape x n x d` tensor of training features.
511+
y: A `batch_shape x n x m` tensor of training observations.
512+
n_context (int): Number of context points.
513+
axis: Dimension axis as int
514+
515+
Returns:
516+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
517+
- x_c: Context input data.
518+
- y_c: Context target data.
519+
- x_t: Target input data.
520+
- y_t: Target target data.
521+
"""
522+
self.n_context = n_context
523+
mask = torch.randperm(x.shape[axis])[:n_context]
524+
splitter = torch.zeros(x.shape[axis], dtype=torch.bool)
525+
x_c = x[mask].to(device)
526+
y_c = y[mask].to(device)
527+
splitter[mask] = True
528+
x_t = x[~splitter].to(device)
529+
y_t = y[~splitter].to(device)
530+
return x_c, y_c, x_t, y_t

0 commit comments

Comments
 (0)