forked from nimarb/pytorch_influence_functions
Thanks for the nice implementation!
I noticed that there is a unit test file, `test_hvp_grad.py`. Is it supposed to pass? When I ran it, both subtests failed (`test_s_test_sample` and `test_s_test_cg`, though I'm more interested in the former), while all the other tests (`test_cg.py`, `test_leave_one.py`) passed. Is this a known issue, or am I missing something? Below is the log:
$ python test_hvp_grad.py
Grads:
tensor([ 1.1013e+00, -6.6031e-01, -4.0760e-01, 4.7483e+00, -1.8438e+00,
9.2940e-01, 9.0911e-01, 1.4016e+00, 3.9964e-01, 2.0680e-02,
-1.0921e+00, 6.5484e-01, 4.0422e-01, -4.7089e+00, 1.8285e+00,
-9.2169e-01, -9.0158e-01, -1.3900e+00, -3.9632e-01, -2.0509e-02,
-9.1265e-03, 5.4722e-03, 3.3779e-03, -3.9351e-02, 1.5280e-02,
-7.7022e-03, -7.5341e-03, -1.1616e-02, -3.3119e-03, -1.7138e-04,
-6.9659e-01, 6.9082e-01, 5.7729e-03], device='cuda:0')
Hessian:
tensor([[ 0.3261, -0.0991, -0.0574, ..., -0.0672, 0.0277, 0.0395],
[-0.0991, 0.3469, 0.0993, ..., 0.0359, -0.0204, -0.0154],
[-0.0574, 0.0993, 0.3803, ..., 0.0130, -0.0178, 0.0048],
...,
[-0.0672, 0.0359, 0.0130, ..., 0.0831, -0.0442, -0.0389],
[ 0.0277, -0.0204, -0.0178, ..., -0.0442, 0.1195, -0.0754],
[ 0.0395, -0.0154, 0.0048, ..., -0.0389, -0.0754, 0.1143]],
device='cuda:0')
Inverse Hessian
tensor([[ 1.8762e+07, 9.4299e+07, 1.3685e+08, ..., -6.7908e+07,
-6.7908e+07, -6.7908e+07],
[ 4.0619e+06, 1.8239e+07, -2.2055e+06, ..., -5.3456e+06,
-5.3456e+06, -5.3456e+06],
[ 1.3800e+08, 5.3143e+07, 3.8250e+07, ..., 1.2657e+08,
1.2657e+08, 1.2657e+08],
...,
[-2.5811e+07, 1.3880e+08, 2.7709e+08, ..., -1.3207e+08,
-1.3207e+08, -1.3207e+08],
[-2.5811e+07, 1.3880e+08, 2.7709e+08, ..., -1.3207e+08,
-1.3207e+08, -1.3207e+08],
[-2.5811e+07, 1.3880e+08, 2.7709e+08, ..., -1.3207e+08,
-1.3207e+08, -1.3207e+08]], device='cuda:0')
Real IHVP
tensor([ 106.3438, 4.5254, -53.4375, 6.1758, 35.1250, 25.2969,
-9.5625, 2.5459, -83.9531, 10.5859, 86.7500, 16.6602,
124.1797, -12.3633, 41.1875, 161.6406, 168.1250, 4.7598,
-127.6562, 8.8418, 87.3438, 26.9141, -57.0625, -18.2109,
38.3750, -126.5000, 153.6875, 3.5601, -201.8750, 3.6348,
154.3125, 162.1250, 168.6875], device='cuda:0')
Conjugate function value: -53.6090087890625, lin: -107.21803283691406, quad: 53.60902404785156
Conjugate function value: -58.024658203125, lin: -113.10745239257812, quad: 55.082794189453125
Conjugate function value: -58.946067810058594, lin: -117.88719177246094, quad: 58.941123962402344
Conjugate function value: -59.0511474609375, lin: -117.72232818603516, quad: 58.671180725097656
Conjugate function value: -59.05691146850586, lin: -117.94234466552734, quad: 58.885433197021484
Conjugate function value: -59.058937072753906, lin: -118.12271118164062, quad: 59.06377410888672
Conjugate function value: -59.05912399291992, lin: -118.10481262207031, quad: 59.04568862915039
Conjugate function value: -59.05913162231445, lin: -118.11825561523438, quad: 59.05912399291992
Conjugate function value: -59.05912780761719, lin: -118.11825561523438, quad: 59.05912780761719
Conjugate function value: -59.05912780761719, lin: -118.11825561523438, quad: 59.05912780761719
Conjugate function value: -59.05912780761719, lin: -118.11825561523438, quad: 59.05912780761719
Warning: Desired error not necessarily achieved due to precision loss.
Current function value: -59.059135
Iterations: 11
Function evaluations: 37
Gradient evaluations: 25
Hessian evaluations: 250
CG
tensor([ 0.6166, -0.3767, -1.2328, 10.8453, -3.1541, 0.7188, 2.4674, 0.3484,
2.8233, 0.0836, -0.5915, 0.7006, 0.8041, -8.1836, 3.3949, -0.7271,
-2.1705, -1.8215, -0.7727, 0.2280, -0.4861, -0.3473, 0.3244, -2.0964,
0.0632, -0.1725, -0.0668, 1.1967, -2.0442, -0.3139, -3.9365, 0.5366,
3.0261], device='cuda:0')
real / estimate
tensor([ 1.7247e+02, -1.2014e+01, 4.3346e+01, 5.6945e-01, -1.1136e+01,
3.5191e+01, -3.8756e+00, 7.3077e+00, -2.9736e+01, 1.2668e+02,
-1.4667e+02, 2.3781e+01, 1.5444e+02, 1.5107e+00, 1.2132e+01,
-2.2230e+02, -7.7458e+01, -2.6131e+00, 1.6520e+02, 3.8782e+01,
-1.7967e+02, -7.7504e+01, -1.7589e+02, 8.6866e+00, 6.0728e+02,
7.3351e+02, -2.2992e+03, 2.9749e+00, 9.8755e+01, -1.1580e+01,
-3.9200e+01, 3.0214e+02, 5.5743e+01], device='cuda:0')
L-2 difference: 539.9702758789062
L-infty difference: 199.83079528808594
IHVP sample 0: 100%|███████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 789.01it/s, est_norm=644]
IHVP sample 1: 100%|███████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 789.82it/s, est_norm=414]
IHVP sample 2: 100%|███████████████████████████████████████████████████████████████| 10000/10000 [00:11<00:00, 851.64it/s, est_norm=537]
IHVP sample 3: 100%|███████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 824.54it/s, est_norm=381]
IHVP sample 4: 100%|███████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 821.73it/s, est_norm=634]
IHVP sample 5: 100%|███████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 819.52it/s, est_norm=587]
IHVP sample 6: 100%|███████████████████████████████████████████████████████████████| 10000/10000 [00:11<00:00, 840.29it/s, est_norm=625]
IHVP sample 7: 100%|███████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 798.59it/s, est_norm=793]
IHVP sample 8: 100%|███████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 810.49it/s, est_norm=509]
IHVP sample 9: 100%|███████████████████████████████████████████████████████████████| 10000/10000 [00:12<00:00, 815.49it/s, est_norm=688]
LiSSA
tensor([ 0.7166, 0.3238, -1.7510, 9.5551, -2.1528, 0.2947, 1.0022, -0.0366,
1.7486, -0.3762, -0.3074, 0.0676, 2.0942, -7.6356, 3.1385, 0.1560,
-0.9661, -1.2637, 0.4168, 0.7077, -0.4091, -0.3911, -0.3433, -1.9201,
-0.9841, -0.4505, -0.0366, 1.3006, -2.1651, -0.3313, -3.2533, -0.0115,
3.2655], device='cuda:0')
real / estimate
tensor([ 1.4840e+02, 1.3978e+01, 3.0518e+01, 6.4634e-01, -1.6316e+01,
8.5832e+01, -9.5412e+00, -6.9511e+01, -4.8012e+01, -2.8143e+01,
-2.8219e+02, 2.4652e+02, 5.9297e+01, 1.6192e+00, 1.3123e+01,
1.0358e+03, -1.7403e+02, -3.7664e+00, -3.0625e+02, 1.2494e+01,
-2.1348e+02, -6.8813e+01, 1.6621e+02, 9.4842e+00, -3.8995e+01,
2.8077e+02, -4.1982e+03, 2.7373e+00, 9.3240e+01, -1.0972e+01,
-4.7432e+01, -1.4123e+04, 5.1657e+01], device='cuda:0')
L-2 difference: 538.7401123046875
L-infty difference: 199.70989990234375
F
======================================================================
FAIL: test_s_test_cg (__main__.TestIHVPGrad)
----------------------------------------------------------------------
Traceback (most recent call last):
File "test_hvp_grad.py", line 144, in test_s_test_cg
self.assertTrue(self.check_estimation(estimated_ihvp))
AssertionError: False is not true
======================================================================
FAIL: test_s_test_sample (__main__.TestIHVPGrad)
----------------------------------------------------------------------
Traceback (most recent call last):
File "test_hvp_grad.py", line 162, in test_s_test_sample
self.assertTrue(self.check_estimation(flat_estimated_ihvp))
AssertionError: False is not true
----------------------------------------------------------------------
Ran 2 tests in 343.955s
FAILED (failures=2)
Also, the results of `test_s_test_cg` and `test_s_test_sample` (the CG and LiSSA estimates in the log) are close to each other, but both differ greatly from `real_ihvp`. Maybe the real IHVP isn't computed correctly?
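One sanity check that might settle this (a minimal sketch in NumPy on a toy stand-in, not the repository's code): a correct IHVP `x` must satisfy `H @ x ≈ g`, so the relative residual can be computed for `real_ihvp` and for each estimate. The names `ihvp_residual` and `damped_ihvp` and the damping value are my own assumptions for illustration. The toy Hessian below is that of a softmax cross-entropy loss with respect to the logits, which is exactly singular, so an explicit inverse is unreliable there; the entries around 1e8 in the inverse printed in the log may point to a similar conditioning problem.

```python
import numpy as np


def ihvp_residual(hessian, grad, ihvp):
    """Relative residual ||H @ ihvp - grad|| / ||grad||; near 0 for a valid IHVP."""
    return np.linalg.norm(hessian @ ihvp - grad) / np.linalg.norm(grad)


def damped_ihvp(hessian, grad, damping=1e-6):
    """Solve (H + damping * I) x = grad without forming an explicit inverse.

    The damping term is a hypothetical choice for this sketch.
    """
    eye = np.eye(hessian.shape[0])
    return np.linalg.solve(hessian + damping * eye, grad)


# Toy stand-in: Hessian of softmax cross-entropy w.r.t. the logits,
# H = diag(p) - p p^T. Its rows sum to zero, so H is singular.
p = np.array([0.5, 0.3, 0.2])
hessian = np.diag(p) - np.outer(p, p)

# Pick a gradient in the range of H so that H x = grad has a solution.
grad = hessian @ np.array([1.0, -2.0, 0.5])

# A huge (or infinite) condition number means inverting H directly is unsafe.
cond = np.linalg.cond(hessian)

# The damped solve still yields a vector with a small residual.
estimate = damped_ihvp(hessian, grad)
residual = ihvp_residual(hessian, grad, estimate)

print(f"condition number: {cond:.3e}")
print(f"relative residual of damped solve: {residual:.3e}")
```

In the test, the same residual could be evaluated with the Hessian and gradient tensors printed above; whichever vector gives a residual near zero is the one to trust.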
Thank you!