@@ -291,18 +291,13 @@ def __init__(
291
291
self .input_transform = input_transform
292
292
293
293
def data_to_z_params (
294
- self ,
295
- x : torch .Tensor ,
296
- y : torch .Tensor ,
297
- xy_dim : int = 1 ,
298
- r_dim : int = 0 ,
294
+ self , x : torch .Tensor , y : torch .Tensor , r_dim : int = 0
299
295
) -> Tuple [torch .Tensor , torch .Tensor ]:
300
296
r"""Compute latent parameters from inputs as a latent distribution.
301
297
302
298
Args:
303
299
x: Input tensor
304
300
y: Target tensor
305
- xy_dim: Combined Input Dimension as int, defaults as 1
306
301
r_dim: Combined Target Dimension as int, defaults as 0.
307
302
308
303
Returns:
@@ -314,7 +309,7 @@ def data_to_z_params(
314
309
"""
315
310
x = x .to (device )
316
311
y = y .to (device )
317
- xy = torch .cat ([x , y ], dim = xy_dim ).to (device ).to (device )
312
+ xy = torch .cat ([x , y ], dim = - 1 ).to (device ).to (device )
318
313
rs = self .r_encoder (xy )
319
314
r_agg = rs .mean (dim = r_dim ).to (device )
320
315
return self .z_encoder (r_agg )
@@ -463,54 +458,20 @@ def forward(
463
458
self .z_mu_all , self .z_logvar_all = self .data_to_z_params (
464
459
self .train_X , self .train_Y
465
460
)
466
- self .z_mu_context , self .z_logvar_context = self .data_to_z_params (
467
- x_c , y_c
468
- )
461
+ self .z_mu_context , self .z_logvar_context = self .data_to_z_params (x_c , y_c )
469
462
x_t = self .transform_inputs (x_t )
470
463
return self .posterior (x_t ).distribution
471
464
472
465
def random_split_context_target (
473
- self , x : torch .Tensor , y : torch .Tensor , n_context : int , axis : int
474
- ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
475
- r"""Helper function to split randomly into context and target.
476
-
477
- Args:
478
- x: A `batch_shape x n x d` tensor of training features.
479
- y: A `batch_shape x n x m` tensor of training observations.
480
- n_context (int): Number of context points.
481
- axis: Dimension axis as int
482
-
483
- Returns:
484
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
485
- - x_c: Context input data.
486
- - y_c: Context target data.
487
- - x_t: Target input data.
488
- - y_t: Target target data.
489
- """
490
- self .n_context = n_context
491
- mask = torch .randperm (x .shape [axis ])[:n_context ]
492
- x_c = torch .from_numpy (x [mask ]).to (device )
493
- y_c = torch .from_numpy (y [mask ]).to (device )
494
- splitter = torch .zeros (x .shape [axis ], dtype = torch .bool )
495
- splitter [mask ] = True
496
- x_t = torch .from_numpy (x [~ splitter ]).to (device )
497
- y_t = torch .from_numpy (y [~ splitter ]).to (device )
498
- return x_c , y_c , x_t , y_t
499
-
500
- def random_split_context_target (
501
- self ,
502
- x : torch .Tensor ,
503
- y : torch .Tensor ,
504
- n_context : int = 20 ,
505
- axis : int = 0
466
+ self , x : torch .Tensor , y : torch .Tensor , n_context , axis : int = 0
506
467
) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
507
468
r"""Helper function to split randomly into context and target.
508
469
509
470
Args:
510
471
x: A `batch_shape x n x d` tensor of training features.
511
472
y: A `batch_shape x n x m` tensor of training observations.
512
473
n_context (int): Number of context points.
513
- axis: Dimension axis as int
474
+ axis: Dimension axis as int, defaults to 0.
514
475
515
476
Returns:
516
477
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
0 commit comments