Skip to content

Commit 49c8e29

Browse files
authored
Merge pull request #19 from FrederikWarburg/master
Extend nnj functionalities
2 parents 9280b54 + f510a87 commit 49c8e29

File tree

7 files changed

+756
-51
lines changed

7 files changed

+756
-51
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ If you want to cite the framework feel free to use this (but only if you loved i
125125
```bibtex
126126
@article{software:stochman,
127127
title={StochMan},
128-
author={Nicki S. Detlefsen and Alison Pouplin and Cilie W. Feldager and Cong Geng and Dimitris Kalatzis and Helene Hauschultz and Miguel González Duque and Søren Hauberg},
128+
author={Nicki S. Detlefsen and Alison Pouplin and Cilie W. Feldager and Cong Geng and Dimitris Kalatzis and Helene Hauschultz and Miguel González Duque and Frederik Warburg and Marco Miani and Søren Hauberg},
129129
journal={GitHub. Note: https://github.com/MachineLearningLifeScience/stochman/},
130130
year={2021}
131131
}

stochman/curves.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,7 @@ def __init__(
3939
_begin = begin.detach() # BxD
4040
_end = end.detach() # BxD
4141
else:
42-
raise ValueError(
43-
"BasicCurve.__init__ requires begin and end points to have "
44-
"the same shape"
45-
)
42+
raise ValueError("BasicCurve.__init__ requires begin and end points to have " "the same shape")
4643

4744
# register begin and end as buffers
4845
self.register_buffer("begin", _begin) # BxD
@@ -82,6 +79,7 @@ def plot(
8279
"""
8380
with torch.no_grad():
8481
import matplotlib.pyplot as plt
82+
8583
t = torch.linspace(t0, t1, N, dtype=self.begin.dtype, device=self.device)
8684
points = self(t) # NxD or BxNxD
8785

@@ -126,7 +124,7 @@ def euclidean_length(self, t0: float = 0.0, t1: float = 1.0, N: int = 100) -> to
126124
if not is_batched:
127125
points = points.unsqueeze(0) # 1xNxD
128126
delta = points[:, 1:] - points[:, :-1] # Bx(N-1)xD
129-
energies = (delta ** 2).sum(dim=2) # Bx(N-1)
127+
energies = (delta**2).sum(dim=2) # Bx(N-1)
130128
lengths = energies.sqrt().sum(dim=1) # B
131129
return lengths
132130

stochman/manifold.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def curve_energy(self, curve: BasicCurve, reduction: Optional[str] = "sum", dt=N
444444
emb_curve = self.embed(curve) # BxNxD
445445
B, N, D = emb_curve.shape
446446
delta = emb_curve[:, 1:, :] - emb_curve[:, :-1, :] # Bx(N-1)xD
447-
energy = (delta ** 2).sum((1, 2)) * dt # B
447+
energy = (delta**2).sum((1, 2)) * dt # B
448448
return tensor_reduction(energy, reduction)
449449

450450
def curve_length(self, curve: BasicCurve, dt=None):
@@ -541,7 +541,7 @@ def __init__(self, data, sigma, rho, device=None):
541541
"""
542542
super().__init__()
543543
self.data = data
544-
self.sigma2 = sigma ** 2
544+
self.sigma2 = sigma**2
545545
self.rho = rho
546546
self.device = device
547547

@@ -586,7 +586,7 @@ def metric(self, c, return_deriv=False):
586586
if return_deriv:
587587
weighted_delta = (w_p / sigma2).reshape(-1, 1).expand(-1, D) * delta # NxD
588588
dSdc = 2.0 * torch.diag(w_p.mm(delta).flatten()) - weighted_delta.t().mm(delta2) # DxD
589-
dM = dSdc.t() * (m ** 2).reshape(-1, 1).expand(-1, D) # DxD
589+
dM = dSdc.t() * (m**2).reshape(-1, 1).expand(-1, D) # DxD
590590
dMdc.append(dM.reshape(1, D, D))
591591

592592
if return_deriv:

0 commit comments

Comments
 (0)