# 🐛 Bug
When sampling from a prior that has been moved to the GPU, only some priors produce samples on the correct device, even though the `state_dict` has been updated correctly (as of #2550, which this issue seems related to, although as far as I can tell no regression was introduced):
```python
from gpytorch import priors

for prior in (
    priors.NormalPrior(1.0, 1.0),
    priors.GammaPrior(1.0, 1.0),
    priors.HalfCauchyPrior(1.0, 1.0),
    priors.HalfNormalPrior(1.0, 1.0),
    priors.LogNormalPrior(1.0, 1.0),
    priors.UniformPrior(1.0, 2.0),
):
    prior.to("cuda:0")
    samples = prior.rsample()
    print(f"{str(prior):<35} {str(samples.device):<8} {dict(prior.state_dict())}")
```
```
NormalPrior()                       cuda:0   {'loc': tensor(1., device='cuda:0'), 'scale': tensor(1., device='cuda:0')}
GammaPrior()                        cuda:0   {'concentration': tensor(1., device='cuda:0'), 'rate': tensor(1., device='cuda:0')}
HalfCauchyPrior()                   cpu      {'_transformed_scale': tensor(1., device='cuda:0')}
HalfNormalPrior()                   cpu      {'_transformed_scale': tensor(1., device='cuda:0')}
LogNormalPrior()                    cpu      {'_transformed_loc': tensor(1., device='cuda:0'), '_transformed_scale': tensor(1., device='cuda:0')}
UniformPrior(low: 1.0, high: 2.0)   cpu      {}
```
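Poking at the `LogNormalPrior` case, the buffer names in the `state_dict` output above suggest a possible mechanism (this is just my reading, not confirmed): `Module.to()` moves the registered `_transformed_*` buffers, but sampling goes through the wrapped `torch.distributions` object, whose parameters are plain attributes that `.to()` never touches:

```python
from gpytorch.priors import LogNormalPrior

# Suspected mechanism (assumption, not confirmed): the registered buffers
# shown in the state_dict above are moved, but LogNormal is a
# TransformedDistribution whose base Normal keeps plain CPU tensors.
prior = LogNormalPrior(0.5, 0.5)
prior.to("cuda:0")
print(prior._transformed_loc.device)  # cuda:0 -- registered buffer, moved
print(prior.base_dist.loc.device)     # cpu    -- plain attribute, left behind
print(prior.rsample().device)         # cpu    -- samples follow base_dist
```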
This manifests in BoTorch when a `LogNormal` prior is in use. If the fit fails on the first attempt, new initial hyperparameter values are sampled from the prior, which results in a device mismatch. In the reproducible example below, I trigger this manually by setting `optimizer_kwargs` so that a warning is raised and `warning_handler` so that any warning triggers a retry.
## To reproduce

**Code snippet to reproduce**
```python
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from gpytorch import kernels, priors
from gpytorch.mlls import ExactMarginalLogLikelihood

n_inputs = 4
n_outputs = 2
n_train = 256
device = torch.device("cuda:0")

train_x = torch.rand(n_train, n_inputs, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, n_outputs, dtype=torch.float64, device=device)

model = SingleTaskGP(
    train_x,
    train_y,
    input_transform=Normalize(n_inputs),
    outcome_transform=Standardize(m=n_outputs),
    covar_module=kernels.ScaleKernel(
        base_kernel=kernels.MaternKernel(
            nu=2.5,
            ard_num_dims=n_inputs,
            batch_shape=torch.Size([n_outputs]),
            lengthscale_prior=priors.LogNormalPrior(0.5, 0.5),
        ),
        outputscale_prior=priors.GammaPrior(2.0, 0.15),
        batch_shape=torch.Size([n_outputs]),
    ),
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(
    mll,
    optimizer_kwargs={"timeout_sec": 1e-3},
    warning_handler=lambda _: False,
)
```
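For a shorter trigger, the failing call from the stack trace below can also be hit directly (assuming `sample_all_priors` is importable from `botorch.optim.utils`, as its path in the trace suggests):

```python
from botorch.optim.utils import sample_all_priors

# With the model from the snippet above, this should raise the same
# RuntimeError: the LogNormal lengthscale prior samples on CPU while
# raw_lengthscale and its constraint bounds live on cuda:0.
sample_all_priors(model)
```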
**Stack trace/error message**

```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[2], line 34
16 model = SingleTaskGP(
17 train_x,
18 train_y,
(...)
30 )
31 )
33 mll = ExactMarginalLogLikelihood(model.likelihood, model)
---> 34 fit_gpytorch_mll(
35 mll,
36 optimizer_kwargs={"timeout_sec": 1e-3},
37 warning_handler=lambda _: False,
38 )
File .../python3.10/site-packages/botorch/fit.py:104, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
101 if optimizer is not None: # defer to per-method defaults
102 kwargs["optimizer"] = optimizer
--> 104 return FitGPyTorchMLL(
105 mll,
106 type(mll.likelihood),
107 type(mll.model),
108 closure=closure,
109 closure_kwargs=closure_kwargs,
110 optimizer_kwargs=optimizer_kwargs,
111 **kwargs,
112 )
File .../python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
91 func = self.__getitem__(types=types)
92 try:
---> 93 return func(*args, **kwargs)
94 except MDNotImplementedError:
95 # Traverses registered methods in order, yields whenever a match is found
96 funcs = self.dispatch_iter(*types)
File .../python3.10/site-packages/botorch/fit.py:198, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, pick_best_of_all_attempts, warning_handler, caught_exception_types, **ignore)
195 ckpt_nograd = {name: ckpt[name] for name in params_nograd}
197 with parameter_rollback_ctx(params_nograd, checkpoint=ckpt_nograd):
--> 198 sample_all_priors(mll.model)
200 try:
201 # Fit the model
202 with catch_warnings(record=True) as warning_list, debug(True):
File .../python3.10/site-packages/botorch/optim/utils/model_utils.py:191, in sample_all_priors(model, max_retries)
186 raise RuntimeError(
187 "Failed to sample a feasible parameter value "
188 f"from the prior after {max_retries} attempts."
189 )
190 else:
--> 191 raise e
File .../python3.10/site-packages/botorch/optim/utils/model_utils.py:171, in sample_all_priors(model, max_retries)
166 prior_shape = prior._extended_shape()
167 if prior_shape.numel() == 1:
168 # For a univariate prior we can sample the size of the closure.
169 # Otherwise we will sample exactly the same value for all
170 # lengthscales where we commonly specify a univariate prior.
--> 171 setting_closure(module, prior.sample(closure(module).shape))
172 else:
173 closure_shape = closure(module).shape
File .../python3.10/site-packages/gpytorch/kernels/kernel.py:221, in Kernel._lengthscale_closure(self, m, v)
219 def _lengthscale_closure(self, m: Kernel, v: Tensor) -> Tensor:
220 # Used by the lengthscale_prior
--> 221 return m._set_lengthscale(v)
File .../python3.10/site-packages/gpytorch/kernels/kernel.py:231, in Kernel._set_lengthscale(self, value)
228 if not torch.is_tensor(value):
229 value = torch.as_tensor(value).to(self.raw_lengthscale)
--> 231 self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value))
File .../python3.10/site-packages/gpytorch/module.py:103, in Module.initialize(self, **kwargs)
101 elif torch.is_tensor(val):
102 constraint = self.constraint_for_parameter_name(name)
--> 103 if constraint is not None and constraint.enforced and not constraint.check_raw(val):
104 raise RuntimeError(
105 "Attempting to manually set a parameter value that is out of bounds of "
106 f"its current constraints, {constraint}. "
107 "Most likely, you want to do the following:\n likelihood = GaussianLikelihood"
108 "(noise_constraint=gpytorch.constraints.GreaterThan(better_lower_bound))"
109 )
110 try:
File .../python3.10/site-packages/gpytorch/constraints/constraints.py:90, in Interval.check_raw(self, tensor)
88 def check_raw(self, tensor) -> bool:
89 return bool(
---> 90 torch.all((self.transform(tensor) <= self.upper_bound))
91 and torch.all(self.transform(tensor) >= self.lower_bound)
92 )
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
## Expected Behavior

Sampling from a prior that has been moved to a device should return samples on that device for all prior types, as it already does for `NormalPrior` and `GammaPrior` in the snippet above.
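In the meantime, one workaround that seems to sidestep the broken `.to()` path (a sketch, not tested beyond the obvious case) is to construct the prior from tensors that already live on the target device, so the wrapped distribution is built on-device from the start:

```python
import torch
from gpytorch import priors

device = torch.device("cuda:0")

# Building the prior from on-device tensors avoids relying on .to():
lengthscale_prior = priors.LogNormalPrior(
    torch.tensor(0.5, device=device),
    torch.tensor(0.5, device=device),
)
print(lengthscale_prior.rsample().device)  # cuda:0
```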
## System information

Please complete the following information:
- GPyTorch Version: 1.14.dev2+g83332c2c (latest main)
- PyTorch Version: 2.0.1+cu117
- Computer OS: Linux