 Contributor: eibarolle
 """

-import copy
-import numpy as np
-from numpy.random import binomial
 import torch
 import torch.nn as nn
-import matplotlib.pyplot as plts
-# %matplotlib inline
 from botorch.models.model import Model
 from botorch.posteriors import GPyTorchPosterior
 from botorch.acquisition.objective import PosteriorTransform
-from sklearn.gaussian_process import GaussianProcessRegressor
-from sklearn.gaussian_process.kernels import (RBF, Matern, RationalQuadratic,
-                                              ExpSineSquared, DotProduct,
-                                              ConstantKernel)
 from typing import Callable, List, Optional, Tuple
-from torch.nn import Module, ModuleDict, ModuleList
-from sklearn import preprocessing
-from scipy.stats import multivariate_normal
+from torch.nn import Module
 from gpytorch.distributions import MultivariateNormal

-device = torch.device("cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 # Account for different acquisitions

 # reference: https://chrisorm.github.io/NGP.html
@@ -59,21 +48,21 @@ def __init__(
         prev_dim = input_dim

         for hidden_dim in hidden_dims:
-            layer = nn.Linear(prev_dim, hidden_dim)
+            layer = nn.Linear(prev_dim, hidden_dim).to(device)
             if init_func is not None:
                 init_func(layer.weight)
             layers.append(layer)
             layers.append(activation())
             prev_dim = hidden_dim

-        final_layer = nn.Linear(prev_dim, output_dim)
+        final_layer = nn.Linear(prev_dim, output_dim).to(device)
         if init_func is not None:
             init_func(final_layer.weight)
         layers.append(final_layer)
-        self.model = nn.Sequential(*layers)
+        self.model = nn.Sequential(*layers).to(device)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.model(x)
+        return self.model(x.to(device))


 class REncoder(nn.Module):
@@ -95,12 +84,9 @@ def __init__(
             init_func: A function initializing the weights, defaults to nn.init.normal_.
         """
         super().__init__()
-        self.mlp = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func)
+        self.mlp = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device)

-    def forward(
-        self,
-        inputs: torch.Tensor,
-    ) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for representation encoder.

         Args:
@@ -109,7 +95,7 @@ def forward(
         Returns:
             torch.Tensor: Encoded representations
         """
-        return self.mlp(inputs)
+        return self.mlp(inputs.to(device))

 class ZEncoder(nn.Module):
     def __init__(self,
@@ -130,13 +116,10 @@ def __init__(self,
             init_func: A function initializing the weights, defaults to nn.init.normal_.
         """
         super().__init__()
-        self.mean_net = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func)
-        self.logvar_net = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func)
+        self.mean_net = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device)
+        self.logvar_net = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device)

-    def forward(
-        self,
-        inputs: torch.Tensor,
-    ) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for latent encoder.

         Args:
@@ -147,6 +130,7 @@ def forward(
                 - Mean of the latent Gaussian distribution.
                 - Log variance of the latent Gaussian distribution.
         """
+        inputs = inputs.to(device)
         return self.mean_net(inputs), self.logvar_net(inputs)

 class Decoder(torch.nn.Module):
@@ -168,23 +152,21 @@ def __init__(
             init_func: A function initializing the weights, defaults to nn.init.normal_.
         """
         super().__init__()
-        self.mlp = MLP(input_dim, output_dim, hidden_dims, activation=activation, init_func=init_func)
+        self.mlp = MLP(input_dim=input_dim, output_dim=output_dim, hidden_dims=hidden_dims, activation=activation, init_func=init_func).to(device)

-    def forward(
-        self,
-        x_pred: torch.Tensor,
-        z: torch.Tensor,
-    ) -> torch.Tensor:
+    def forward(self, x_pred: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
         r"""Forward pass for decoder.

         Args:
-            x_pred: No. of data points, by x_dim
-            z: No. of samples, by z_dim
+            x_pred: Input points of shape (n x d_x), representing # of data points by x_dim.
+            z: Latent encoding of shape (num_samples x d_z), representing # of samples by z_dim.

         Returns:
-            torch.Tensor: Predicted target values.
+            torch.Tensor: Predicted target values of shape (n x y_dim), representing # of data points by y_dim.
         """
-        z_expanded = z.unsqueeze(0).expand(x_pred.size(0), -1)
+        z = z.to(device)
+        z_expanded = z.unsqueeze(0).expand(x_pred.size(0), -1).to(device)
+        x_pred = x_pred.to(device)
         xz = torch.cat([x_pred, z_expanded], dim=-1)
         return self.mlp(xz)
@@ -231,16 +213,14 @@ def __init__(
             init_func: A function initializing the weights, defaults to nn.init.normal_.
         """
         super().__init__()
-        self.r_encoder = REncoder(x_dim + y_dim, r_dim, r_hidden_dims, activation=activation, init_func=init_func)
-        self.z_encoder = ZEncoder(r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func)
-        self.decoder = Decoder(x_dim + z_dim, y_dim, decoder_hidden_dims, activation=activation, init_func=init_func)
+        self.r_encoder = REncoder(x_dim + y_dim, r_dim, r_hidden_dims, activation=activation, init_func=init_func).to(device)
+        self.z_encoder = ZEncoder(r_dim, z_dim, z_hidden_dims, activation=activation, init_func=init_func).to(device)
+        self.decoder = Decoder(x_dim + z_dim, y_dim, decoder_hidden_dims, activation=activation, init_func=init_func).to(device)
         self.z_dim = z_dim
         self.z_mu_all = None
         self.z_logvar_all = None
         self.z_mu_context = None
         self.z_logvar_context = None
-        # self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)  # Look at BoTorch native versions
-        # self.train(n_epochs, x_train, y_train)

     def data_to_z_params(
         self,
@@ -264,18 +244,20 @@ def data_to_z_params(
             - x_t: Target input data.
             - y_t: Target output data.
         """
-        xy = torch.cat([x, y], dim=xy_dim)
+        x = x.to(device)
+        y = y.to(device)
+        xy = torch.cat([x, y], dim=xy_dim).to(device)
         rs = self.r_encoder(xy)
-        r_agg = rs.mean(dim=r_dim)
+        r_agg = rs.mean(dim=r_dim).to(device)
         return self.z_encoder(r_agg)

     def sample_z(
         self,
         mu: torch.Tensor,
         logvar: torch.Tensor,
         n: int = 1,
-        min_std: float = 0.1,
-        scaler: float = 0.9
+        min_std: float = 0.01,
+        scaler: float = 0.5
     ) -> torch.Tensor:
         r"""Reparameterization trick for z's latent distribution.
@@ -291,12 +273,15 @@ def sample_z(
         """
         if min_std <= 0 or scaler <= 0:
             raise ValueError()
+
+        shape = [n, self.z_dim]
         if n == 1:
-            eps = torch.autograd.Variable(logvar.data.new(self.z_dim).normal_()).to(device)
-        else:
-            eps = torch.autograd.Variable(logvar.data.new(n, self.z_dim).normal_()).to(device)
+            shape = shape[1:]
+        eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(device)

         std = min_std + scaler * torch.sigmoid(logvar)
+        std = std.to(device)
+        mu = mu.to(device)
         return mu + std * eps

     def KLD_gaussian(
@@ -316,10 +301,10 @@ def KLD_gaussian(
         if min_std <= 0 or scaler <= 0:
             raise ValueError()
-        std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all)
-        std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context)
-        p = torch.distributions.Normal(self.z_mu_context, std_p)
-        q = torch.distributions.Normal(self.z_mu_all, std_q)
+        std_q = min_std + scaler * torch.sigmoid(self.z_logvar_all).to(device)
+        std_p = min_std + scaler * torch.sigmoid(self.z_logvar_context).to(device)
+        p = torch.distributions.Normal(self.z_mu_context.to(device), std_p)
+        q = torch.distributions.Normal(self.z_mu_all.to(device), std_q)
         return torch.distributions.kl_divergence(p, q).sum()

     def posterior(
@@ -343,7 +328,8 @@ def posterior(
             GPyTorchPosterior: The posterior distribution object
                 utilizing MultivariateNormal.
         """
-        mean = self.decoder(X, self.sample_z(self.z_mu_all, self.z_logvar_all))
+        X = X.to(device)
+        mean = self.decoder(X, self.sample_z(self.z_mu_all, self.z_logvar_all))
         covariance = torch.eye(X.size(0)) * covariance_multiplier
         if observation_noise:
             covariance = covariance + observation_constant
@@ -352,20 +338,6 @@ def posterior(
         if posterior_transform is not None:
             posterior = posterior_transform(posterior)
         return posterior
-
-    def load_state_dict(
-        self,
-        state_dict: dict,
-        strict: bool = True
-    ) -> None:
-        """
-        Initialize the fully Bayesian model before loading the state dict.
-
-        Args:
-            state_dict (dict): A dictionary containing the parameters.
-            strict (bool): Case matching strictness.
-        """
-        super().load_state_dict(state_dict, strict=strict)

     def transform_inputs(
         self,
@@ -381,6 +353,7 @@ def transform_inputs(
         Returns:
             torch.Tensor: A tensor of transformed inputs
         """
+        X = X.to(device)
         if input_transform is not None:
             input_transform.to(X)
             return input_transform(X)
@@ -420,6 +393,11 @@ def forward(
         if y_c.size(1 - target_dim) != y_t.size(1 - target_dim):
             raise ValueError()

+        x_t = x_t.to(device)
+        x_c = x_c.to(device)
+        y_c = y_c.to(device)
+        y_t = y_t.to(device)
+
         self.z_mu_all, self.z_logvar_all = self.data_to_z_params(torch.cat([x_c, x_t], dim=input_dim), torch.cat([y_c, y_t], dim=target_dim))
         self.z_mu_context, self.z_logvar_context = self.data_to_z_params(x_c, y_c)
         z = self.sample_z(self.z_mu_all, self.z_logvar_all)
@@ -447,12 +425,12 @@ def random_split_context_target(
             - x_t: Target input data.
             - y_t: Target output data.
         """
-        ind = np.arange(x.shape[0])
-        mask = np.random.choice(ind, size=n_context, replace=False)
-        x_c = torch.from_numpy(x[mask])
-        y_c = torch.from_numpy(y[mask])
-        x_t = torch.from_numpy(np.delete(x, mask, axis=0))
-        y_t = torch.from_numpy(np.delete(y, mask, axis=0))
-
+        mask = torch.randperm(x.shape[0])[:n_context]
+        x_c = torch.from_numpy(x[mask]).to(device)
+        y_c = torch.from_numpy(y[mask]).to(device)
+        splitter = torch.zeros(x.shape[0], dtype=torch.bool)
+        splitter[mask] = True
+        x_t = torch.from_numpy(x[~splitter]).to(device)
+        y_t = torch.from_numpy(y[~splitter]).to(device)
         return x_c, y_c, x_t, y_t
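For orientation, below is a minimal usage sketch of how the components in this diff fit together. It is not part of the change: the dimensions are invented, the components are constructed with the same argument names NeuralProcessModel.__init__ uses (activation and init_func passed explicitly), and torch.randn_like stands in for the Variable-based draw in sample_z, using the new defaults min_std=0.01 and scaler=0.5.

    import torch
    import torch.nn as nn

    # Hypothetical dimensions, for illustration only.
    x_dim, y_dim, r_dim, z_dim = 2, 1, 8, 8
    hidden = [16, 16]

    # Same constructor arguments as NeuralProcessModel.__init__ passes above.
    r_encoder = REncoder(x_dim + y_dim, r_dim, hidden, activation=nn.ReLU, init_func=nn.init.normal_)
    z_encoder = ZEncoder(r_dim, z_dim, hidden, activation=nn.ReLU, init_func=nn.init.normal_)
    decoder = Decoder(x_dim + z_dim, y_dim, hidden, activation=nn.ReLU, init_func=nn.init.normal_)

    x = torch.rand(20, x_dim, device=device)  # 20 observed (x, y) pairs
    y = torch.rand(20, y_dim, device=device)

    rs = r_encoder(torch.cat([x, y], dim=-1))  # (20, r_dim) per-pair representations
    r_agg = rs.mean(dim=0)                     # (r_dim,) permutation-invariant summary
    mu, logvar = z_encoder(r_agg)              # latent Gaussian parameters, each (z_dim,)

    # Reparameterization as in sample_z with n=1 and the new defaults.
    std = 0.01 + 0.5 * torch.sigmoid(logvar)
    z = mu + std * torch.randn_like(std)

    y_pred = decoder(torch.rand(5, x_dim, device=device), z)  # (5, y_dim) predictions

Averaging rs over the context points is what makes the summary invariant to their ordering, and the sigmoid bound keeps the sampled standard deviation inside (min_std, min_std + scaler).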