 from gpytorch.distributions import MultivariateNormal
 from gpytorch.likelihoods import GaussianLikelihood
 from gpytorch.likelihoods.likelihood import Likelihood
+from gpytorch.models.gp import GP
 from torch.nn import Module

-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-# Account for different acquisitions
-

 # reference: https://chrisorm.github.io/NGP.html
 class MLP(nn.Module):
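The hunk above removes the module-level CUDA global, so the MLP building blocks below stay device-agnostic; later hunks move placement into the model, which infers its device from `train_X`. A minimal sketch of that convention (`SmallNet` is a hypothetical stand-in):

```python
import torch
import torch.nn as nn


class SmallNet(nn.Module):  # hypothetical stand-in for the MLP below
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 1)  # built without any .to(device)


train_X = torch.rand(8, 4)           # the data dictates the device...
net = SmallNet().to(train_X.device)  # ...applied once, at the top level
```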
@@ -56,21 +54,21 @@ def __init__(
         prev_dim = input_dim

         for hidden_dim in hidden_dims:
-            layer = nn.Linear(prev_dim, hidden_dim).to(device)
+            layer = nn.Linear(prev_dim, hidden_dim)
             if init_func is not None:
                 init_func(layer.weight)
             layers.append(layer)
             layers.append(activation())
             prev_dim = hidden_dim

-        final_layer = nn.Linear(prev_dim, output_dim).to(device)
+        final_layer = nn.Linear(prev_dim, output_dim)
         if init_func is not None:
             init_func(final_layer.weight)
         layers.append(final_layer)
-        self.model = nn.Sequential(*layers).to(device)
+        self.model = nn.Sequential(*layers)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.model(x.to(device))
+        return self.model(x)


 class REncoder(nn.Module):
@@ -101,7 +99,7 @@ def __init__(
             hidden_dims=hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
+        )

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for representation encoder.
@@ -112,7 +110,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: Encoded representations
         """
-        return self.mlp(inputs.to(device))
+        return self.mlp(inputs)


 class ZEncoder(nn.Module):
@@ -144,14 +142,14 @@ def __init__(
             hidden_dims=hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
+        )
         self.logvar_net = MLP(
             input_dim=input_dim,
             output_dim=output_dim,
             hidden_dims=hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
+        )

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for latent encoder.
@@ -164,7 +162,6 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
             - Mean of the latent Gaussian distribution.
             - Log variance of the latent Gaussian distribution.
         """
-        inputs = inputs.to(device)
         return self.mean_net(inputs), self.logvar_net(inputs)

@@ -197,7 +194,7 @@ def __init__(
             hidden_dims=hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
+        )

     def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for decoder.
@@ -212,14 +209,16 @@ def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
             torch.Tensor: Predicted target values of shape (n x z_dim),
             representing # of data points by z_dim.
         """
-        z = z.to(device)
-        z_expanded = z.unsqueeze(0).expand(x_pred.size(0), -1).to(device)
-        x_pred = x_pred.to(device)
+        if z.dim() == 1:
+            z = z.unsqueeze(0)
+        if z.dim() == 3:
+            z = z.squeeze(0)
+        z_expanded = z.expand(x_pred.size(0), -1)
         xz = torch.cat([x_pred, z_expanded], dim=-1)
         return self.mlp(xz)


-class NeuralProcessModel(Model):
+class NeuralProcessModel(Model, GP):
     def __init__(
         self,
         train_X: torch.Tensor,
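The reshaping change above lets the decoder accept `z` as a bare vector, a `(1 x z_dim)` batch, or a singleton-batched sample, normalizing each to two dimensions before broadcasting across the prediction points. A small sketch of the intended shapes (values are illustrative):

```python
import torch

z = torch.rand(4)              # a bare z_dim=4 vector
if z.dim() == 1:
    z = z.unsqueeze(0)         # -> shape (1, 4)
if z.dim() == 3:
    z = z.squeeze(0)           # a (1, m, 4) sample would collapse to (m, 4)
z_expanded = z.expand(10, -1)  # broadcast across n=10 prediction points
print(z_expanded.shape)        # torch.Size([10, 4])
```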
@@ -262,35 +262,38 @@ def __init__(
                 forward pass.
         """
         super().__init__()
+        self.device = train_X.device
+
+        # self._validate_tensor_args(X=train_X, Y=train_Y)
         self.r_encoder = REncoder(
             x_dim + y_dim,
             r_dim,
             r_hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
+        ).to(self.device)
         self.z_encoder = ZEncoder(
             r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func
-        ).to(device)
+        ).to(self.device)
         self.decoder = Decoder(
             x_dim + z_dim,
             y_dim,
             decoder_hidden_dims,
             activation=activation,
             init_func=init_func,
-        ).to(device)
-        self.train_X = train_X.to(device)
-        self.train_Y = train_Y.to(device)
+        ).to(self.device)
+        self.train_X = train_X.to(self.device)
+        self.train_Y = train_Y.to(self.device)
         self.n_context = n_context
         self.z_dim = z_dim
         self.z_mu_all = None
         self.z_logvar_all = None
         self.z_mu_context = None
         self.z_logvar_context = None
         if likelihood is None:
-            self.likelihood = GaussianLikelihood().to(device)
+            self.likelihood = GaussianLikelihood().to(self.device)
         else:
-            self.likelihood = likelihood.to(device)
+            self.likelihood = likelihood.to(self.device)
         self.input_transform = input_transform

     def data_to_z_params(
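With `self.device = train_X.device`, placement now follows the training data. A hedged construction sketch (the import path and hyperparameter values are assumptions; the keyword names come from the `__init__` body above):

```python
import torch

from botorch_community.models.np_regression import NeuralProcessModel  # path assumed

train_X = torch.rand(20, 1)  # CPU tensors -> CPU model; CUDA tensors -> CUDA model
train_Y = torch.sin(6.0 * train_X)
model = NeuralProcessModel(
    train_X,
    train_Y,
    r_hidden_dims=[16, 16],      # illustrative sizes, not defaults
    z_hidden_dims=[32, 32],
    decoder_hidden_dims=[16, 16],
    x_dim=1,
    y_dim=1,
    r_dim=8,
    z_dim=8,
    n_context=10,
)
```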
@@ -310,11 +313,11 @@ def data_to_z_params(
             - x_t: Target input data.
             - y_t: Target outcome data.
         """
-        x = x.to(device)
-        y = y.to(device)
-        xy = torch.cat([x, y], dim=-1).to(device)
+        x = x.to(self.device)
+        y = y.to(self.device)
+        xy = torch.cat([x, y], dim=-1).to(self.device)
         rs = self.r_encoder(xy)
-        r_agg = rs.mean(dim=r_dim).to(device)
+        r_agg = rs.mean(dim=r_dim).to(self.device)
         return self.z_encoder(r_agg)

     def sample_z(
@@ -344,11 +347,11 @@ def sample_z(
         shape = [n, self.z_dim]
         if n == 1:
             shape = shape[1:]
-        eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(device)
+        eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(self.device)

         std = min_std + scaler * torch.sigmoid(logvar)
-        std = std.to(device)
-        mu = mu.to(device)
+        std = std.to(self.device)
+        mu = mu.to(self.device)
         return mu + std * eps

     def KLD_gaussian(self, min_std: float = 0.01, scaler: float = 0.5) -> torch.Tensor:
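`sample_z` draws `z` via the reparameterization trick, so gradients flow to `mu` and `logvar` through the returned sample; only the noise `eps` is gradient-free. An equivalent minimal sketch in modern PyTorch (the `torch.autograd.Variable` wrapper above is legacy but harmless):

```python
import torch

mu = torch.zeros(8, requires_grad=True)
logvar = torch.zeros(8, requires_grad=True)

eps = torch.randn_like(logvar)            # N(0, 1) noise, no gradient of its own
std = 0.01 + 0.5 * torch.sigmoid(logvar)  # min_std + scaler * sigmoid(logvar)
z = mu + std * eps                        # differentiable in mu and logvar
z.sum().backward()                        # gradients reach both parameters
```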
@@ -365,10 +368,10 @@ def KLD_gaussian(self, min_std: float = 0.01, scaler: float = 0.5) -> torch.Tensor:

         if min_std <= 0 or scaler <= 0:
             raise ValueError()
-        std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(device)
-        std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context).to(device)
-        p = torch.distributions.Normal(self.z_mu_context.to(device), std_p)
-        q = torch.distributions.Normal(self.z_mu_all.to(device), std_q)
+        std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(self.device)
+        std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context).to(self.device)
+        p = torch.distributions.Normal(self.z_mu_context.to(self.device), std_p)
+        q = torch.distributions.Normal(self.z_mu_all.to(self.device), std_q)
         return torch.distributions.kl_divergence(p, q).sum()

     def posterior(
@@ -378,7 +381,7 @@ def posterior(
         observation_noise: bool = False,
         posterior_transform: PosteriorTransform | None = None,
     ) -> GPyTorchPosterior:
-        r"""Computes the model's posterior distribution for given input tensors.
+        r"""Computes the model's posterior for given input tensors.

         Args:
             X: Input Tensor
@@ -391,20 +394,19 @@ def posterior(
                 defaults to None.

         Returns:
-            GPyTorchPosterior: The posterior distribution object
-                utilizing MultivariateNormal.
+            GPyTorchPosterior: The posterior utilizing MultivariateNormal.
         """
         X = self.transform_inputs(X)
-        X = X.to(device)
+        X = X.to(self.device)
         mean = self.decoder(
-            X.to(device), self.sample_z(self.z_mu_all, self.z_logvar_all)
+            X.to(self.device), self.sample_z(self.z_mu_all, self.z_logvar_all)
         )
         z_var = torch.exp(self.z_logvar_all)
-        covariance = torch.eye(X.size(0)).to(device) * z_var.mean()
+        covariance = torch.eye(X.size(0)).to(self.device) * z_var.mean()
         if observation_noise:
             covariance = covariance + self.likelihood.noise * torch.eye(
                 covariance.size(0)
-            ).to(device)
+            ).to(self.device)
         mvn = MultivariateNormal(mean, covariance)
         posterior = GPyTorchPosterior(mvn)
         if posterior_transform is not None:
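Once `forward` has populated `z_mu_all` and `z_logvar_all`, `posterior` plugs into the usual BoTorch interfaces. A short usage sketch, continuing the construction example above (that `forward` has already been run is an assumption):

```python
import torch

test_X = torch.rand(5, 1)
post = model.posterior(test_X, observation_noise=True)  # adds likelihood noise
print(post.mean.shape)      # predictive mean at the 5 query points
print(post.variance.shape)  # diagonal of the isotropic covariance
```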
@@ -425,7 +427,7 @@ def transform_inputs(
         Returns:
             torch.Tensor: A tensor of transformed inputs
         """
-        X = X.to(device)
+        X = X.to(self.device)
         if input_transform is not None:
             input_transform.to(X)
             return input_transform(X)
@@ -454,10 +456,10 @@ def forward(
         x_c, y_c, x_t, y_t = self.random_split_context_target(
             train_X, train_Y, self.n_context, axis=axis
         )
-        x_t = x_t.to(device)
-        x_c = x_c.to(device)
-        y_c = y_c.to(device)
-        y_t = y_t.to(device)
+        x_t = x_t.to(self.device)
+        x_c = x_c.to(self.device)
+        y_c = y_c.to(self.device)
+        y_t = y_t.to(self.device)
         self.z_mu_all, self.z_logvar_all = self.data_to_z_params(
             self.train_X, self.train_Y
         )
@@ -486,9 +488,9 @@ def random_split_context_target(
         self.n_context = n_context
         mask = torch.randperm(x.shape[axis])[:n_context]
         splitter = torch.zeros(x.shape[axis], dtype=torch.bool)
-        x_c = x[mask].to(device)
-        y_c = y[mask].to(device)
+        x_c = x[mask].to(self.device)
+        y_c = y[mask].to(self.device)
         splitter[mask] = True
-        x_t = x[~splitter].to(device)
-        y_t = y[~splitter].to(device)
+        x_t = x[~splitter].to(self.device)
+        y_t = y[~splitter].to(self.device)
         return x_c, y_c, x_t, y_t
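For completeness, a rough single-step training sketch: the neural-process loss pairs a reconstruction term with `KLD_gaussian`. Treat this as a sketch only; the exact return type of `model.forward` is not shown in this diff and is assumed distribution-like here:

```python
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

optimizer.zero_grad()
pred = model(model.train_X, model.train_Y)    # also fills z_mu_all / z_mu_context
recon = -pred.log_prob(model.train_Y).mean()  # assumed: pred behaves like a distribution
loss = recon + model.KLD_gaussian()           # reconstruction + KL regularizer
loss.backward()
optimizer.step()
```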