
[Bug] _MultitaskGaussianLikelihoodBase does not pass additional arguments in marginal #2630

Open
@adrianLepp


🐛 Bug

In _MultitaskGaussianLikelihoodBase, the call to self._shaped_noise_covar in marginal does not pass additional arguments on to the method, as is done in _GaussianLikelihoodBase.
I wanted to implement FixedNoiseGaussianLikelihood for a multitask GP. It should be possible to pass a noise tensor matching the shape of the outputs to the likelihood, as in the single-output example:

import torch
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood

train_x = torch.randn(55, 2)
noises = torch.ones(55) * 0.01
likelihood = FixedNoiseGaussianLikelihood(noise=noises, learn_additional_noise=True)
pred_y = likelihood(gp_model(train_x))

test_x = torch.randn(21, 2)
test_noises = torch.ones(21) * 0.02
pred_y = likelihood(gp_model(test_x), noise=test_noises)
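
In the single-output example, the noise= keyword reaches the noise model because _GaussianLikelihoodBase forwards extra arguments to _shaped_noise_covar. The sketch below illustrates the difference with hypothetical, dependency-free stand-in classes. They are not GPyTorch's source; the add_noise and interleaved keywords are only borrowed to mimic a multitask call. The "buggy" multitask marginal accepts extra keyword arguments but drops them, while the "fixed" one forwards them.

```python
class SingleTaskBase:
    """Hypothetical stand-in for a single-output likelihood base class."""

    has_global_noise = False

    def _shaped_noise_covar(self, shape, *params, **kwargs):
        # Instead of building a covariance, report what arrived here.
        return {"shape": shape, "params": params, "kwargs": kwargs}

    def marginal(self, mean_shape, *params, **kwargs):
        # Extra positional and keyword arguments (e.g. noise=...) are forwarded.
        return self._shaped_noise_covar(mean_shape, *params, **kwargs)


class MultitaskBaseBuggy(SingleTaskBase):
    """Hypothetical multitask variant whose marginal() drops extra arguments."""

    def marginal(self, mean_shape, *params, **kwargs):
        # *params and **kwargs are accepted but never passed on.
        return self._shaped_noise_covar(
            mean_shape, add_noise=self.has_global_noise, interleaved=True
        )


class MultitaskBaseFixed(SingleTaskBase):
    """Hypothetical multitask variant that forwards extra arguments."""

    def marginal(self, mean_shape, *params, **kwargs):
        # Same fixed keywords as above, plus everything the caller supplied.
        return self._shaped_noise_covar(
            mean_shape, *params, add_noise=self.has_global_noise, interleaved=True, **kwargs
        )
```

Calling MultitaskBaseBuggy().marginal((3, 2), noise=...) returns kwargs without a noise entry, whereas MultitaskBaseFixed keeps it. A fix in GPyTorch would presumably follow the same forwarding pattern.
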

To reproduce

This is my class (WIP), based on the solutions considered in #901:

from typing import Optional

import torch
from gpytorch.likelihoods.multitask_gaussian_likelihood import _MultitaskGaussianLikelihoodBase
from gpytorch.likelihoods.noise_models import FixedGaussianNoise
from linear_operator.operators import (
    ConstantDiagLinearOperator,
    DiagLinearOperator,
    KroneckerProductLinearOperator,
)


class FixedTaskNoiseMultitaskLikelihood(_MultitaskGaussianLikelihoodBase):
    def __init__(
        self,
        noise: torch.Tensor,
        has_task_noise: bool = False,
        task_noise: Optional[torch.Tensor] = None,
        task_noise_factor: Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ) -> None:
        # fixed (non-learned) observation noise, one value per data point
        noise_covar = FixedGaussianNoise(noise=noise)
        super().__init__(*args, noise_covar=noise_covar, **kwargs)
        self.has_global_noise = False
        self.has_task_noise = has_task_noise
        self.task_noise = None
        self.task_noise_factor = None

        if self.has_task_noise:
            if task_noise is not None:
                # full fixed noise vector (one entry per point and task)
                self.task_noise = task_noise
            elif task_noise_factor is not None:
                # per-task scaling of the data noise
                self.task_noise_factor = task_noise_factor
            else:
                raise ValueError("Must supply task noise or task noise factor")

    def _shaped_noise_covar(self, base_shape: torch.Size, *params, add_noise=True, **kwargs):
        if self.has_task_noise and self.task_noise is not None:
            # noise passed at call time (e.g. for test points) takes precedence
            if kwargs.get('noise') is not None:
                return DiagLinearOperator(kwargs['noise'])
            return DiagLinearOperator(self.task_noise)

        # diagonal data noise over the n = base_shape[-2] points
        data_noise = self.noise_covar(*params, shape=torch.Size((base_shape[-2],)), **kwargs)

        if not self.has_task_noise:
            # identity over tasks: every task gets the same data noise
            eye = torch.ones(1, device=data_noise.device, dtype=data_noise.dtype)
            task_noise = ConstantDiagLinearOperator(eye, diag_shape=torch.Size((self.num_tasks,)))
        else:  # task_noise_factor
            task_noise_factor = self.task_noise_factor.to(device=data_noise.device, dtype=data_noise.dtype)
            task_noise = DiagLinearOperator(task_noise_factor)

        return KroneckerProductLinearOperator(data_noise, task_noise)
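
In the branch that does not use a fixed task_noise vector, _shaped_noise_covar returns KroneckerProductLinearOperator(data_noise, task_noise), the Kronecker product of two diagonal operators. As a plain-Python sanity check (nested lists instead of linear operators; diag, kron_dense and kron_diag are hypothetical helpers), such a product is itself diagonal, and the noise of point i for task j sits at flat index i * num_tasks + j.

```python
def diag(values):
    """Dense square matrix with `values` on the diagonal (nested lists)."""
    n = len(values)
    return [[values[r] if r == c else 0.0 for c in range(n)] for r in range(n)]


def kron_dense(a, b):
    """Dense Kronecker product of two square nested-list matrices."""
    n, m = len(a), len(b)
    out = [[0.0] * (n * m) for _ in range(n * m)]
    for i in range(n):
        for j in range(n):
            for k in range(m):
                for q in range(m):
                    # Block (i, j) of the result is a[i][j] * b.
                    out[i * m + k][j * m + q] = a[i][j] * b[k][q]
    return out


def kron_diag(data_diag, task_diag):
    """Diagonal of kron(diag(data_diag), diag(task_diag)) without building it.

    Point i, task j lands at flat index i * len(task_diag) + j.
    """
    return [d * s for d in data_diag for s in task_diag]
```

For example, kron_diag([1.0, 2.0], [1.0, 10.0]) gives [1.0, 10.0, 2.0, 20.0]: all tasks of the first point, then all tasks of the second. Assuming the default interleaved layout of the multitask distribution uses this same point-major order, a fixed task_noise vector passed to DiagLinearOperator would need that order as well.
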

With the following setup:

task_noise = torch.ones_like(train_y).flatten()
likelihood = FixedTaskNoiseMultitaskLikelihood(noise=noise, num_tasks=num_tasks, rank=num_tasks, has_task_noise=True, task_noise=task_noise)

test_noises = torch.ones(torch.Size((num_tasks, test_x.shape[0]))).flatten()
likelihood(gp_model(test_x), noise=test_noises)

the last call should pass noise on to _shaped_noise_covar, but the marginal method of _MultitaskGaussianLikelihoodBase does not forward it.
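
The two noise vectors in this setup are flattened in different orders. torch.ones((num_tasks, n)).flatten() lists every test point of task 0 before task 1 (task-major), whereas torch.ones_like(train_y).flatten() lists all tasks of the first point first (point-major), assuming train_y has the usual shape (n, num_tasks). With all-ones noise the order is invisible, but with task-specific noise levels it is not. The sketch below uses hypothetical helpers on nested lists standing in for tensors to show the two orderings.

```python
def flatten_task_major(noise_by_task):
    """Flatten a [num_tasks][n] nested list row by row: task 0's points, then task 1's, ..."""
    return [v for row in noise_by_task for v in row]


def flatten_point_major(noise_by_task):
    """Flatten the same [num_tasks][n] data point by point: all tasks of point 0, then point 1, ..."""
    num_tasks, n = len(noise_by_task), len(noise_by_task[0])
    return [noise_by_task[j][i] for i in range(n) for j in range(num_tasks)]
```

With noise_by_task = [[0.5, 0.5, 0.5], [2.0, 2.0, 2.0]], the task-major result is [0.5, 0.5, 0.5, 2.0, 2.0, 2.0] and the point-major result is [0.5, 2.0, 0.5, 2.0, 0.5, 2.0]. For a (num_tasks, n) tensor, the point-major order could be obtained in PyTorch with .t().flatten().
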
