First of all thanks for this incredibly great library, it has been a lifesaver! (and hi @gpleiss from jlg :)
🐛 Bug
When adding A (a DenseLinearOperator) to v (a LowRankRootLinearOperator), the result of inv_quad_logdet() is not what it should be. (Other cases work -- for instance, if A is a DiagLinearOperator, things are fine.)
Below, I've included a test case based on computing normal likelihoods so that I can produce some SciPy numbers for ground truth.
To reproduce
I've broken up the code into a few cases, showing ways to get the right answer and the add_low_rank case which breaks.
Imports...
import torch
import numpy as np
from scipy.stats import multivariate_normal
import linear_operator
from linear_operator import operators
torch.__version__, linear_operator.__version__, np.__version__
# => ('2.5.1.post3', '0.5.3', '1.26.4')
Simulate some test data, and get SciPy log liks:
N = 2 ** 12
D = 8
mean = np.zeros(D)
# test case: normal log likelihoods
# using scipy for reference point
rg = np.random.default_rng(0)
# make a low rank + identity covariance
v = rg.normal(size=(D, 2))
cov = np.eye(D) + v @ v.T
# draws...
y = rg.multivariate_normal(mean=mean, cov=cov, size=N)
# log liks
rv = multivariate_normal(mean=mean, cov=cov)
scipy_lls = rv.logpdf(y)
# convert to torch for below
mean = torch.asarray(mean, dtype=torch.float)
v = torch.asarray(v, dtype=torch.float)
cov = torch.asarray(cov, dtype=torch.float)
y = torch.asarray(y, dtype=torch.float)
# helper function for getting log liks via inv_quad_logdet
log2pi = torch.log(torch.tensor(2 * np.pi))
def ll_via_inv_quad(cov, y):
    inv_quad, logdet = linear_operator.inv_quad_logdet(cov, y.T, logdet=True, reduce_inv_quad=False)
    ll = -0.5 * (inv_quad + logdet + log2pi * y.shape[1])
    return ll
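(For reference, with a zero mean this computes the standard Gaussian log density per sample, ll_i = -0.5 * (y_iᵀ Σ⁻¹ y_i + log det Σ + D log 2π), where inv_quad supplies the per-sample quadratic term and logdet the log determinant.)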
Cases which behave as expected
The isclose checks here are all True.
# logliks via dense operator
dense_cov = operators.DenseLinearOperator(cov)
dense_lls = ll_via_inv_quad(dense_cov, y)
np.isclose(scipy_lls, dense_lls).all() # => True
# logliks via diag low rank
diag_eye = operators.DiagLinearOperator(torch.ones(D))
root = operators.LowRankRootLinearOperator(v)
diag_root_cov = operators.LowRankRootAddedDiagLinearOperator(diag_eye, root)
diag_root_lls = ll_via_inv_quad(diag_root_cov, y)
np.isclose(scipy_lls, diag_root_lls).all() # => True
Failing case
If we use a dense linear operator and land in .add_low_rank(), the isclose() is False here:
# logliks via dense add_low_rank
dense_eye = operators.DenseLinearOperator(torch.eye(D))
root = operators.LowRankRootLinearOperator(v)
dense_root_cov = dense_eye + root
dense_root_lls = ll_via_inv_quad(dense_root_cov, y)
np.isclose(scipy_lls, dense_root_lls).all() # => False
The differences are substantial -- the max absolute difference from SciPy was 876303.6523659548 in this case, and the median absolute difference was 25754.86743231282.
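As a quick (hypothetical, not part of the original report) way to localize the problem, the two terms returned by inv_quad_logdet() for the correct dense operator and for the dense + low-rank sum can be compared directly:
# compare the quadratic-form and log-determinant terms of the two operators
iq_dense, ld_dense = linear_operator.inv_quad_logdet(dense_cov, y.T, logdet=True, reduce_inv_quad=False)
iq_sum, ld_sum = linear_operator.inv_quad_logdet(dense_root_cov, y.T, logdet=True, reduce_inv_quad=False)
print(torch.isclose(ld_dense, ld_sum))        # do the log determinant terms agree?
print(torch.isclose(iq_dense, iq_sum).all())  # do the per-sample quadratic terms agree?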
Workaround
I wanted a way to use Woodbury with a dense operator, so I wrote a quick implementation of a LowRankRootSumLinearOperator which is basically identical to LowRankRootAddedDiagLinearOperator -- it makes a Cholesky of the capacitance matrix. My code is here in case it is helpful at all: https://github.com/cwindolf/dartsort/blob/main/src/dartsort/util/more_operators.py (a rough sketch of the underlying computation follows the snippet below).
# log liks via alternative to the dense add_low_rank
from dartsort.util import more_operators
alt_root_cov = more_operators.LowRankRootSumLinearOperator(dense_eye, root)
alt_root_lls = ll_via_inv_quad(alt_root_cov, y)
np.isclose(scipy_lls, alt_root_lls).all() # => True
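For reference, here is a minimal sketch in plain torch of the Woodbury / matrix determinant lemma computation the workaround relies on for cov = A + v @ v.T (just an illustration under that assumption, not the actual dartsort implementation):
def woodbury_inv_quad_logdet(A, v, y):
    # A: (D, D) dense PSD term, v: (D, k) low-rank root, y: (N, D) observations
    A_chol = torch.linalg.cholesky(A)
    Ainv_v = torch.cholesky_solve(v, A_chol)        # A^{-1} v
    Ainv_yT = torch.cholesky_solve(y.T, A_chol)     # A^{-1} y^T
    # capacitance matrix C = I_k + v^T A^{-1} v and its Cholesky factor
    C = torch.eye(v.shape[1]) + v.T @ Ainv_v
    C_chol = torch.linalg.cholesky(C)
    # quadratic form: y_i^T A^{-1} y_i - (v^T A^{-1} y_i)^T C^{-1} (v^T A^{-1} y_i)
    vt_Ainv_yT = v.T @ Ainv_yT
    inv_quad = (y.T * Ainv_yT).sum(0) - (vt_Ainv_yT * torch.cholesky_solve(vt_Ainv_yT, C_chol)).sum(0)
    # matrix determinant lemma: log|A + v v^T| = log|A| + log|C|
    logdet = 2 * (A_chol.diagonal().log().sum() + C_chol.diagonal().log().sum())
    return inv_quad, logdet
Feeding these two quantities into the same -0.5 * (inv_quad + logdet + D * log2pi) expression as above should reproduce the SciPy values for this covariance.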
Expected Behavior
inv_quad_logdet() should produce results that match SciPy in all cases.
System information
Please complete the following information:
- LinearOperator version: 0.5.3
- PyTorch version: 2.5.1.post3
- Computer OS: Mac