Skip to content

Commit 8dbd494

Browse files
committed
Fix unittest
1 parent f5de05b commit 8dbd494

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

test/test.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
# This is required to import the functions we need to test
1818
sys.path.append(ROOT_DIR)
1919

20-
# Finally import the space, vector, arithmetic operators, model, and parser utilities
21-
from hdlib.space import Space, Vector
20+
from hdlib import Space, Vector
2221
from hdlib.arithmetic import bundle, bind, permute
2322
from hdlib.model import ClassificationModel, GraphModel
2423

@@ -182,7 +181,7 @@ def test_mlmodel(self):
182181
# Collect the accuracy scores computed on each fold
183182
scores = list()
184183

185-
for y_indices, y_pred, _, _ in predictions:
184+
for y_indices, y_pred, _, _, _ in predictions:
186185
y_true = [label for position, label in enumerate(classes) if position in y_indices]
187186
accuracy = accuracy_score(y_true, y_pred)
188187

@@ -233,17 +232,17 @@ def test_graph(self):
233232
graph.fit(edges)
234233

235234
# Compute the error rate of the graph model based on its set of edge
236-
error_rate, _, _ = graph.error_rate(edges)
235+
error_rate, _ = graph.error_rate()
237236

238237
if error_rate > 0.0:
239238
# Mitigate the error rate, up to 10 iterations
240-
graph.error_mitigation(edges, max_iter=10)
239+
graph.error_mitigation(max_iter=10)
241240

242241
# Define the distance threshold to establish whether an edge exists in the graph model
243242
threshold = 0.7
244243

245244
# Check whether the edge <2, 3> exists
246-
edge_exists, dist = graph.edge_exists("2", "3", 0.2, threshold=threshold)
245+
edge_exists, _, _ = graph.edge_exists("2", "3", 0.2, threshold=threshold)
247246

248247
self.assertTrue(edge_exists)
249248

0 commit comments

Comments
 (0)