Commit d16119d

Author: Frederik Rahbaek Warburg (committed)
fixed a bug
1 parent cc35fd6 commit d16119d

File tree

6 files changed (+63, -54 lines changed)


stochman/curves.py

Lines changed: 3 additions & 5 deletions
@@ -39,10 +39,7 @@ def __init__(
             _begin = begin.detach()  # BxD
             _end = end.detach()  # BxD
         else:
-            raise ValueError(
-                "BasicCurve.__init__ requires begin and end points to have "
-                "the same shape"
-            )
+            raise ValueError("BasicCurve.__init__ requires begin and end points to have " "the same shape")

         # register begin and end as buffers
         self.register_buffer("begin", _begin)  # BxD
@@ -82,6 +79,7 @@ def plot(
         """
         with torch.no_grad():
             import matplotlib.pyplot as plt
+
             t = torch.linspace(t0, t1, N, dtype=self.begin.dtype, device=self.device)
             points = self(t)  # NxD or BxNxD

@@ -126,7 +124,7 @@ def euclidean_length(self, t0: float = 0.0, t1: float = 1.0, N: int = 100) -> torch.Tensor:
         if not is_batched:
             points = points.unsqueeze(0)  # 1xNxD
         delta = points[:, 1:] - points[:, :-1]  # Bx(N-1)xD
-        energies = (delta ** 2).sum(dim=2)  # Bx(N-1)
+        energies = (delta**2).sum(dim=2)  # Bx(N-1)
         lengths = energies.sqrt().sum(dim=1)  # B
         return lengths

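
Note: the euclidean_length hunk above only changes formatting (black's spacing around **); the computation is the usual polyline approximation of curve length: sample the curve, take consecutive differences, and sum the segment norms. A minimal standalone sketch of that discretization follows; the helper name polyline_length and the example curve are illustrative, not part of stochman's API.

import torch

def polyline_length(points: torch.Tensor) -> torch.Tensor:
    # points: BxNxD samples along each curve
    delta = points[:, 1:] - points[:, :-1]  # Bx(N-1)xD segment vectors
    energies = (delta**2).sum(dim=2)  # Bx(N-1) squared segment lengths
    return energies.sqrt().sum(dim=1)  # B: total length per curve

# straight line from (0, 0) to (1, 1); the discrete length is close to sqrt(2)
t = torch.linspace(0, 1, 100).view(1, -1, 1)
print(polyline_length(torch.cat([t, t], dim=2)))  # tensor([1.4142])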

stochman/manifold.py

Lines changed: 3 additions & 3 deletions
@@ -444,7 +444,7 @@ def curve_energy(self, curve: BasicCurve, reduction: Optional[str] = "sum", dt=None):
         emb_curve = self.embed(curve)  # BxNxD
         B, N, D = emb_curve.shape
         delta = emb_curve[:, 1:, :] - emb_curve[:, :-1, :]  # Bx(N-1)xD
-        energy = (delta ** 2).sum((1, 2)) * dt  # B
+        energy = (delta**2).sum((1, 2)) * dt  # B
         return tensor_reduction(energy, reduction)

     def curve_length(self, curve: BasicCurve, dt=None):
@@ -541,7 +541,7 @@ def __init__(self, data, sigma, rho, device=None):
         """
         super().__init__()
         self.data = data
-        self.sigma2 = sigma ** 2
+        self.sigma2 = sigma**2
         self.rho = rho
         self.device = device

@@ -586,7 +586,7 @@ def metric(self, c, return_deriv=False):
             if return_deriv:
                 weighted_delta = (w_p / sigma2).reshape(-1, 1).expand(-1, D) * delta  # NxD
                 dSdc = 2.0 * torch.diag(w_p.mm(delta).flatten()) - weighted_delta.t().mm(delta2)  # DxD
-                dM = dSdc.t() * (m ** 2).reshape(-1, 1).expand(-1, D)  # DxD
+                dM = dSdc.t() * (m**2).reshape(-1, 1).expand(-1, D)  # DxD
                 dMdc.append(dM.reshape(1, D, D))

         if return_deriv:
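
Note: as with euclidean_length, the curve_energy hunk is formatting only; the quantity being computed is the discrete curve energy, the sum of squared segment norms of the embedded curve scaled by the step size dt. Minimizing the energy (rather than the length) is what drives curves toward constant-speed geodesics. A short sketch under the same BxNxD convention; the function name is hypothetical.

import torch

def discrete_curve_energy(emb_curve: torch.Tensor, dt: float) -> torch.Tensor:
    # emb_curve: BxNxD points of the embedded curve
    delta = emb_curve[:, 1:, :] - emb_curve[:, :-1, :]  # Bx(N-1)xD
    return (delta**2).sum((1, 2)) * dt  # B: one energy per curve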

stochman/nnj.py

Lines changed: 50 additions & 37 deletions
@@ -77,8 +77,10 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
     def _jacobian_wrt_input_transpose_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
         return F.linear(jac_in.movedim(1, -1), self.weight.T, bias=None).movedim(-1, 1)

-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
-
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
+
         b, c = x.shape
         diag_elements = torch.diagonal(tmp, dim1=1, dim2=2)
         feat_k2 = (x**2).unsqueeze(1)
@@ -125,7 +127,9 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
             .movedim(dims2, dims1)
         )

-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         # non parametric, so return empty
         return None

@@ -137,17 +141,16 @@ def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         weight = torch.ones(c2, c1, int(self.scale_factor), int(self.scale_factor), device=x.device)

         tmp = F.conv2d(
-                tmp.reshape(-1, c2, h2, w2),
-                weight=weight,
-                bias=None,
-                stride=int(self.scale_factor),
-                padding=0,
-                dilation=1,
-                groups=1,
-        )
-
-        return tmp.reshape(b, c1*h1*w1)
+            tmp.reshape(-1, c2, h2, w2),
+            weight=weight,
+            bias=None,
+            stride=int(self.scale_factor),
+            padding=0,
+            dilation=1,
+            groups=1,
+        )

+        return tmp.reshape(b, c1 * h1 * w1)


 class Conv1d(AbstractJacobian, nn.Conv1d):
@@ -416,36 +419,36 @@ def _jacobian_wrt_weight_T_mult_right(
         return Jt_tmp

     def _jacobian_wrt_weight_mult_left(
-            self, x: Tensor, val: Tensor, tmp: Tensor, use_less_memory: bool = True
-        ) -> Tensor:
+        self, x: Tensor, val: Tensor, tmp: Tensor, use_less_memory: bool = True
+    ) -> Tensor:
         b, c1, h1, w1 = x.shape
         c2, h2, w2 = val.shape[1:]
         kernel_h, kernel_w = self.kernel_size
         num_of_rows = tmp.shape[-2]

         # expand rows as cubes [(output channel)x(output height)x(output width)]
-        tmp_rows = tmp.movedim(-1,-2).reshape(b, c2, h2, w2, num_of_rows)
+        tmp_rows = tmp.movedim(-1, -2).reshape(b, c2, h2, w2, num_of_rows)
         # see rows as columns of the transposed matrix
         tmpt_cols = tmp_rows
         # transpose the images in (output height)x(output width)
         tmpt_cols = torch.flip(tmpt_cols, [-3, -2])
         # switch batch size and output channel
-        tmpt_cols = tmpt_cols.movedim(0,1)
+        tmpt_cols = tmpt_cols.movedim(0, 1)

         if use_less_memory:

-            tmp_J = torch.zeros(b, c2*c1*kernel_h*kernel_w, num_of_rows, device=x.device)
+            tmp_J = torch.zeros(b, c2 * c1 * kernel_h * kernel_w, num_of_rows, device=x.device)
             for i in range(b):
                 # set the weight to the convolution
-                input_single_batch = x[i:i+1,:,:,:]
-                reversed_input_single_batch = torch.flip(input_single_batch, [-2,-1]).movedim(0,1)
-
-                tmp_single_batch = tmpt_cols[:,i:i+1,:,:,:]
+                input_single_batch = x[i : i + 1, :, :, :]
+                reversed_input_single_batch = torch.flip(input_single_batch, [-2, -1]).movedim(0, 1)
+
+                tmp_single_batch = tmpt_cols[:, i : i + 1, :, :, :]

                 # convolve each column
                 tmp_J_single_batch = (
                     F.conv2d(
-                        tmpt_cols.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, 1, h2, w2),
+                        tmp_single_batch.movedim((1, 2, 3), (-3, -2, -1)).reshape(-1, 1, h2, w2),
                         weight=reversed_input_single_batch,
                         bias=None,
                         stride=self.stride,
@@ -458,14 +461,14 @@ def _jacobian_wrt_weight_mult_left(
                 )

                 # reshape as a (num of weights)x(num of column) matrix
-                tmp_J_single_batch = tmp_J_single_batch.reshape(c2*c1*kernel_h*kernel_w, num_of_rows)
+                tmp_J_single_batch = tmp_J_single_batch.reshape(c2 * c1 * kernel_h * kernel_w, num_of_rows)
                 tmp_J[i, :, :] = tmp_J_single_batch

             # transpose
-            tmp_J = tmp_J.movedim(-1,-2)
-        else:
+            tmp_J = tmp_J.movedim(-1, -2)
+        else:
             # set the weight to the convolution
-            reversed_inputs = torch.flip(x, [-2,-1]).movedim(0,1)
+            reversed_inputs = torch.flip(x, [-2, -1]).movedim(0, 1)

             # convolve each column
             Jt_tmptt_cols = (
@@ -483,9 +486,9 @@ def _jacobian_wrt_weight_mult_left(
             )

             # reshape as a (num of input)x(num of output) matrix, one for each batch size
-            Jt_tmptt_cols = Jt_tmptt_cols.reshape(c2*c1*kernel_h*kernel_w,num_of_rows)
+            Jt_tmptt_cols = Jt_tmptt_cols.reshape(c2 * c1 * kernel_h * kernel_w, num_of_rows)
             # transpose
-            tmp_J = Jt_tmptt_cols.movedim(0,1)
+            tmp_J = Jt_tmptt_cols.movedim(0, 1)

         return tmp

@@ -495,7 +498,9 @@ def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         else:
             return self._jacobian_wrt_input_full_sandwich(x, val, tmp)

-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         if diag:
             return self._jacobian_wrt_weight_diag_sandwich(x, val, tmp)
         else:
@@ -652,7 +657,9 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
     def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         return tmp

-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         return None


@@ -675,7 +682,9 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
     def _jacobian_wrt_input_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
         return tmp

-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         return None


@@ -785,7 +794,9 @@ def _jacobian_wrt_input_mult_left_vec(self, x: Tensor, val: Tensor, jac_in: Tensor) -> Tensor:
         jac_in = jac_in[arange_repeated, idx, :, :, :].reshape(*val.shape, *jac_in_orig_shape[4:])
         return jac_in

-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         # non parametric, so return empty
         return None

@@ -800,9 +811,9 @@ def _jacobian_wrt_input_diag_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor) -> Tensor:
         new_tmp = new_tmp.reshape(b * c1, h1 * w1)
         idx = self.idx.reshape(b * c2, h2 * w2)
         arange_repeated = torch.repeat_interleave(torch.arange(b * c1), h2 * w2).long()
-        arange_repeated = arange_repeated.reshape(b*c2, h2*w2)
-
-        new_tmp[arange_repeated, idx] = tmp.reshape(b*c2, h2*w2)
+        arange_repeated = arange_repeated.reshape(b * c2, h2 * w2)
+
+        new_tmp[arange_repeated, idx] = tmp.reshape(b * c2, h2 * w2)

         return new_tmp.reshape(b, c1 * h1 * w1)

@@ -892,7 +903,9 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
         jac = 1.0 - val**2
         return jac

-    def _jacobian_wrt_weight_sandwich(self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False) -> Tensor:
+    def _jacobian_wrt_weight_sandwich(
+        self, x: Tensor, val: Tensor, tmp: Tensor, diag: bool = False
+    ) -> Tensor:
         # non parametric, so return empty
         return None
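
Note: most of the nnj.py hunks are mechanical reformatting (multi-line signatures, spacing around operators and slices). The one behavioural change visible in the diff, and presumably the bug named in the commit message, sits in the use_less_memory branch of _jacobian_wrt_weight_mult_left: the per-sample convolution now reads from the pre-sliced tmp_single_batch instead of the full tmpt_cols, so each loop iteration only convolves the columns that belong to sample i. The sketch below illustrates that per-sample slicing pattern with hypothetical names and shapes; it is not stochman's implementation.

import torch
import torch.nn.functional as F

def convolve_per_sample(cols: torch.Tensor, kernels: torch.Tensor) -> torch.Tensor:
    # cols: BxNx1xHxW column stacks, kernels: Bx1x1xkHxkW per-sample kernels (illustrative shapes)
    outs = []
    for i in range(cols.shape[0]):
        cols_i = cols[i]  # Nx1xHxW, only the current sample's columns
        kernel_i = kernels[i]  # 1x1xkHxkW
        # passing `cols` here instead of `cols_i` would feed every sample into
        # every iteration, which is the kind of slip the commit corrects
        outs.append(F.conv2d(cols_i, weight=kernel_i, bias=None))
    return torch.stack(outs)  # BxNx1x(H-kH+1)x(W-kW+1)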

stochman/utils.py

Lines changed: 4 additions & 6 deletions
@@ -11,7 +11,7 @@ def forward(ctx, M, p0: torch.Tensor, p1: torch.Tensor):
         dist2 = dist**2

         lm0 = C.deriv(torch.zeros(1, device=p0.device)).squeeze(1)  # log(p0, p1); Bx(d)
-        lm1 = -C.deriv(torch.ones(1, device=p0.device)).squeeze(1) # log(p1, p0); Bx(d)
+        lm1 = -C.deriv(torch.ones(1, device=p0.device)).squeeze(1)  # log(p1, p0); Bx(d)
         G0 = M.metric(p0)  # Bx(d)x(d) or Bx(d)
         G1 = M.metric(p1)  # Bx(d)x(d) or Bx(d)
         if G0.ndim == 3:  # metric is square
@@ -32,9 +32,7 @@ def forward(ctx, M, p0: torch.Tensor, p1: torch.Tensor):
     @staticmethod
     def backward(ctx, grad_output):
         Glm0, Glm1 = ctx.saved_tensors
-        return (None,
-                2.0 * grad_output.view(-1, 1) * Glm0,
-                2.0 * grad_output.view(-1, 1) * Glm1)
+        return (None, 2.0 * grad_output.view(-1, 1) * Glm0, 2.0 * grad_output.view(-1, 1) * Glm1)


 def squared_manifold_distance(manifold, p0: torch.Tensor, p1: torch.Tensor):
@@ -53,9 +51,9 @@ def squared_manifold_distance(manifold, p0: torch.Tensor, p1: torch.Tensor):


 def tensor_reduction(x: torch.Tensor, reduction: str):
-    if reduction == 'sum':
+    if reduction == "sum":
         return x.sum()
-    elif reduction == 'mean':
+    elif reduction == "mean":
         return x.mean()
     elif reduction is None or reduction == "none":
         return x
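
Note, for context on the backward change: the saved tensors Glm0 and Glm1 are metric-times-logmap products computed in forward, and backward just scales them by 2.0 * grad_output (the leading None is the gradient slot for the manifold argument M, which gets no gradient). Up to the sign convention used for the logarithm map, this follows the standard identity for the squared geodesic distance, sketched here in LaTeX:

\frac{\partial\, d^2(p_0, p_1)}{\partial p_0} \propto 2\, G(p_0)\, \operatorname{Log}_{p_0}(p_1),
\qquad
\frac{\partial\, d^2(p_0, p_1)}{\partial p_1} \propto 2\, G(p_1)\, \operatorname{Log}_{p_1}(p_0)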

tests/test_curves.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def test_constant_speed(self, curve_class):
         assert isinstance(curve_length, torch.Tensor)
         assert new_t.shape == (batch_size, timesteps)
         assert Ct.shape == (batch_size, timesteps, dim)
-        assert curve_length.shape == (batch_size, )
+        assert curve_length.shape == (batch_size,)

     def test_plotting_in_axis(self, curve_class):
         batch_size = 5

tests/test_nnj.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@


 def _compare_jacobian(f: Callable, x: torch.Tensor) -> torch.Tensor:
-    """ Use pytorch build-in jacobian function to compare for correctness of computations"""
+    """Use pytorch build-in jacobian function to compare for correctness of computations"""
     out = f(x)
     output = torch.autograd.functional.jacobian(f, x)
     m = out.ndim
@@ -153,7 +153,7 @@ def test_jacobians(self, model, input_shape, device, dtype):

     @pytest.mark.parametrize("return_jac", [True, False])
     def test_jac_return(self, model, input_shape, device, return_jac):
-        """ Test that all models returns the jacobian output if asked for it """
+        """Test that all models returns the jacobian output if asked for it"""
         if "cuda" in device and not torch.cuda.is_available():
             pytest.skip("Test requires cuda support")

0 commit comments
