Skip to content

Commit cfe0596

Browse files
committed
fix pairwise option
1 parent ab50aba commit cfe0596

File tree

1 file changed

+11
-24
lines changed

1 file changed

+11
-24
lines changed

GraphRicciCurvature/OllivierRicci.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,17 @@ def _distribute_densities(source, target):
149149
# construct the cost dictionary from x to y
150150
t0 = time.time()
151151

152-
d = _apsp[np.ix_(source_topknbr, target_topknbr)] # transportation matrix
152+
if _shortest_path == "pairwise":
153+
d = []
154+
for src in source_topknbr:
155+
tmp = []
156+
for tgt in target_topknbr:
157+
tmp.append(_source_target_shortest_path(src, tgt))
158+
d.append(tmp)
159+
d = np.array(d)
160+
else: # all_pairs
161+
d = _apsp[np.ix_(source_topknbr, target_topknbr)] # transportation matrix
162+
153163
x = np.array([x]).T # the mass that source neighborhood initially owned
154164
y = np.array([y]).T # the mass that target neighborhood needs to received
155165

@@ -181,29 +191,6 @@ def _source_target_shortest_path(source, target):
181191
return length
182192

183193

184-
def _get_pairwise_sp(source, target):
185-
"""Compute pairwise shortest path from `source` to `target`.
186-
187-
Parameters
188-
----------
189-
source : int
190-
Source node index in Networkit graph `_Gk`.
191-
target : int
192-
Target node index in Networkit graph `_Gk`.
193-
194-
Returns
195-
-------
196-
length : float
197-
Pairwise shortest path length.
198-
199-
"""
200-
201-
if _shortest_path == "pairwise":
202-
return _source_target_shortest_path(source, target)
203-
204-
return _apsp[source][target]
205-
206-
207194
def _get_all_pairs_shortest_path():
208195
"""Pre-compute all pairs shortest paths of the assigned graph `_Gk`."""
209196
logger.info("Start to compute all pair shortest path.")

0 commit comments

Comments
 (0)