@@ -231,6 +231,7 @@ def __init__(
231
231
y_dim : int = 1 ,
232
232
r_dim : int = 64 ,
233
233
z_dim : int = 8 ,
234
+ n_context : int = 20 ,
234
235
activation : Callable = nn .Sigmoid ,
235
236
init_func : Optional [Callable ] = torch .nn .init .normal_ ,
236
237
likelihood : Likelihood | None = None ,
@@ -248,6 +249,7 @@ def __init__(
248
249
y_dim: Int dimensionality of target data y.
249
250
r_dim: Int dimensionality of representation r.
250
251
z_dim: Int dimensionality of latent variable z.
252
+ n_context (int): Number of context points, defaults to 20.
251
253
activation: Activation function applied between layers, defaults to nn.
252
254
Sigmoid.
253
255
init_func: A function initializing the weights,
@@ -276,6 +278,7 @@ def __init__(
276
278
).to (device )
277
279
self .train_X = train_X .to (device )
278
280
self .train_Y = train_Y .to (device )
281
+ self .n_context = n_context
279
282
self .z_dim = z_dim
280
283
self .z_mu_all = None
281
284
self .z_logvar_all = None
@@ -430,40 +433,38 @@ def transform_inputs(
430
433
return input_transform (X )
431
434
try :
432
435
return self .input_transform (X )
433
- except AttributeError :
436
+ except ( AttributeError , TypeError ) :
434
437
return X
435
438
436
439
def forward (
437
440
self ,
438
441
train_X : torch .Tensor ,
439
442
train_Y : torch .Tensor ,
440
- n_context : int ,
441
443
axis : int = 0 ,
442
444
) -> MultivariateNormal :
443
445
r"""Forward pass for the model.
444
446
445
447
Args:
446
448
train_X: A `batch_shape x n x d` tensor of training features.
447
449
train_Y: A `batch_shape x n x m` tensor of training observations.
448
- n_context (int): Number of context points.
449
450
axis: Dimension axis as int, defaulted as 0.
450
451
451
452
Returns:
452
453
MultivariateNormal: Predicted target distribution.
453
454
"""
454
455
train_X = self .transform_inputs (train_X )
455
456
x_c , y_c , x_t , y_t = self .random_split_context_target (
456
- train_X , train_Y , n_context , axis = axis
457
+ train_X , train_Y , self . n_context , axis = axis
457
458
)
458
459
x_t = x_t .to (device )
459
460
x_c = x_c .to (device )
460
461
y_c = y_c .to (device )
461
462
y_t = y_t .to (device )
462
463
self .z_mu_all , self .z_logvar_all = self .data_to_z_params (
463
- self .train_X , self .train_Y , dim = axis
464
+ self .train_X , self .train_Y
464
465
)
465
466
self .z_mu_context , self .z_logvar_context = self .data_to_z_params (
466
- x_c , y_c , dim = axis
467
+ x_c , y_c
467
468
)
468
469
x_t = self .transform_inputs (x_t )
469
470
return self .posterior (x_t ).distribution
@@ -486,6 +487,7 @@ def random_split_context_target(
486
487
- x_t: Target input data.
487
488
- y_t: Target target data.
488
489
"""
490
+ self .n_context = n_context
489
491
mask = torch .randperm (x .shape [axis ])[:n_context ]
490
492
x_c = torch .from_numpy (x [mask ]).to (device )
491
493
y_c = torch .from_numpy (y [mask ]).to (device )
@@ -494,3 +496,35 @@ def random_split_context_target(
494
496
x_t = torch .from_numpy (x [~ splitter ]).to (device )
495
497
y_t = torch .from_numpy (y [~ splitter ]).to (device )
496
498
return x_c , y_c , x_t , y_t
499
+
500
+ def random_split_context_target (
501
+ self ,
502
+ x : torch .Tensor ,
503
+ y : torch .Tensor ,
504
+ n_context : int = 20 ,
505
+ axis : int = 0
506
+ ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
507
+ r"""Helper function to split randomly into context and target.
508
+
509
+ Args:
510
+ x: A `batch_shape x n x d` tensor of training features.
511
+ y: A `batch_shape x n x m` tensor of training observations.
512
+ n_context (int): Number of context points.
513
+ axis: Dimension axis as int
514
+
515
+ Returns:
516
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
517
+ - x_c: Context input data.
518
+ - y_c: Context target data.
519
+ - x_t: Target input data.
520
+ - y_t: Target target data.
521
+ """
522
+ self .n_context = n_context
523
+ mask = torch .randperm (x .shape [axis ])[:n_context ]
524
+ splitter = torch .zeros (x .shape [axis ], dtype = torch .bool )
525
+ x_c = x [mask ].to (device )
526
+ y_c = y [mask ].to (device )
527
+ splitter [mask ] = True
528
+ x_t = x [~ splitter ].to (device )
529
+ y_t = y [~ splitter ].to (device )
530
+ return x_c , y_c , x_t , y_t
0 commit comments