Open
Description
🐛 Bug
In `_MultitaskGaussianLikelihoodBase`, the call to `self._shaped_noise_covar` does not forward additional arguments (such as `noise=...`) to the method, unlike the corresponding call in `_GaussianLikelihoodBase`.
I want to implement a `FixedNoiseGaussianLikelihood` for a multitask GP. As with the single-task likelihood, it should be possible to pass a noise tensor matching the shape of the outputs to the likelihood, as in this example:
# Single-task reference usage of FixedNoiseGaussianLikelihood: a per-point noise
# tensor is supplied at construction for the training inputs, and a different one
# is passed via the ``noise=`` keyword when evaluating on test inputs.
# NOTE(review): ``gp_model`` is not defined in this snippet; it is assumed to be an
# existing GP model.
train_x = torch.randn(55, 2)
noises = torch.ones(55) * 0.01  # one noise value per training point (55 points)
likelihood = FixedNoiseGaussianLikelihood(noise=noises, learn_additional_noise=True)
pred_y = likelihood(gp_model(train_x))
test_x = torch.randn(21, 2)
test_noises = torch.ones(21) * 0.02  # one noise value per test point (21 points)
pred_y = likelihood(gp_model(test_x), noise=test_noises)
To reproduce
This is my class (work in progress), based on the solutions discussed in #901:
class FixedTaskNoiseMultitaskLikelihood(_MultitaskGaussianLikelihoodBase):
    """Multitask Gaussian likelihood with fixed (non-learned) observation noise.

    Work in progress. The data noise comes from a ``FixedGaussianNoise`` module
    built from ``noise``. When ``has_task_noise`` is True, the noise covariance is
    instead either a user-supplied diagonal (``task_noise``) or the Kronecker
    product of the data noise with a diagonal of per-task factors
    (``task_noise_factor``).

    Args:
        noise: Fixed noise values handed to ``FixedGaussianNoise``.
        has_task_noise: Whether ``task_noise`` / ``task_noise_factor`` is used.
        task_noise: Full diagonal of the noise covariance. Takes precedence over
            ``task_noise_factor`` when both are given.
        task_noise_factor: Per-task diagonal combined with the data noise via a
            Kronecker product.
        *args, **kwargs: Forwarded to ``_MultitaskGaussianLikelihoodBase.__init__``.

    Raises:
        ValueError: If ``has_task_noise`` is True but neither ``task_noise`` nor
            ``task_noise_factor`` is supplied.
    """

    def __init__(
        self,
        noise: torch.Tensor,
        has_task_noise: bool = False,
        task_noise: torch.Tensor = None,
        task_noise_factor: torch.Tensor = None,
        *args,
        **kwargs
    ) -> None:
        noise_covar = FixedGaussianNoise(noise=noise)
        # Star-args placed before keywords: same call, but no longer relies on the
        # discouraged "unpacking after a keyword argument" form.
        super().__init__(*args, noise_covar=noise_covar, **kwargs)
        self.has_global_noise = False
        self.has_task_noise = has_task_noise
        # Always define both attributes so later reads never hit AttributeError,
        # even when task noise is disabled.
        self.task_noise = None
        self.task_noise_factor = None
        if self.has_task_noise:
            if task_noise is not None:
                self.task_noise = task_noise
            elif task_noise_factor is not None:
                self.task_noise_factor = task_noise_factor
            else:
                raise ValueError("Must supply task noise or task noise factor")

    def _shaped_noise_covar(self, base_shape: torch.Size, *params, add_noise=True, **kwargs):
        """Build the noise covariance for a multitask output of shape ``base_shape``.

        Args:
            base_shape: Output shape. ``base_shape[-2]`` is used as the number of
                data points (presumably ``(..., n, num_tasks)``; confirm).
            *params: Forwarded to ``self.noise_covar``.
            add_noise: Accepted for interface compatibility; not used here.
            **kwargs: Forwarded to ``self.noise_covar``. When a full ``task_noise``
                diagonal is in use, a non-None ``noise`` entry replaces it (e.g. test-time
                noise).

        Returns:
            A ``DiagLinearOperator`` (full task-noise mode), or the Kronecker
            product of the data noise with a per-task diagonal.
        """
        if self.has_task_noise and self.task_noise is not None:
            # The original test here was ``'noise' in kwargs is not None``. That is a
            # chained comparison, so an explicit ``noise=None`` crashed inside
            # DiagLinearOperator. Fall back to the stored diagonal instead.
            noise = kwargs.get("noise")
            return DiagLinearOperator(self.task_noise if noise is None else noise)

        data_noise = self.noise_covar(*params, shape=torch.Size((base_shape[-2],)), **kwargs)
        if not self.has_task_noise:
            # Identity over tasks: every task shares the data noise unchanged.
            eye = torch.ones(1, device=data_noise.device, dtype=data_noise.dtype)
            task_noise = ConstantDiagLinearOperator(eye, diag_shape=torch.Size((self.num_tasks,)))
        else:
            # Only reachable with task_noise_factor set (guaranteed by __init__).
            task_noise_factor = self.task_noise_factor.to(
                device=data_noise.device, dtype=data_noise.dtype
            )
            task_noise = DiagLinearOperator(task_noise_factor)
        return KroneckerProductLinearOperator(data_noise, task_noise)
Calling
# Per-point, per-task training noise flattened into a single diagonal.
task_noise = torch.ones_like(train_y).flatten()
# ``noise`` is passed positionally (it is the first parameter of __init__). The
# original call put it after keyword arguments, which is a SyntaxError.
likelihood = FixedTaskNoiseMultitaskLikelihood(
    noise,
    num_tasks=num_tasks,
    rank=num_tasks,
    has_task_noise=True,
    task_noise=task_noise,
)
# NOTE(review): the test noise is built task-major (num_tasks, n) before flattening,
# while train_y.flatten() above is presumably (n, num_tasks)-ordered. The order only
# matters for non-uniform noise; confirm what the likelihood expects.
# NOTE(review): ``model`` here and ``gp_model`` below presumably refer to the same
# model; confirm.
test_noises = torch.ones(torch.Size((model.num_tasks, test_x.shape[0]))).flatten()
likelihood(gp_model(test_x), noise=test_noises)
should pass `noise` through to `_shaped_noise_covar`, but the base class currently drops it.