Skip to content

Commit 538c1d2

Browse files
committed
select(), rank(), and set_tree()
1 parent 9bbfd1c commit 538c1d2

File tree

4 files changed

+58
-31
lines changed

4 files changed

+58
-31
lines changed

pydatastructs/trees/_backend/cpp/AVLTree.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ static PyObject* AVLTree___str__(AVLTree *self) {
4040
return BinarySearchTree___str__(self->sbbt->bst);
4141
}
4242

43+
static PyObject* AVLTree_set_tree(AVLTree* self, PyObject *args) {
44+
ArrayForTrees* arr = reinterpret_cast<ArrayForTrees*>(PyObject_GetItem(args, PyZero));
45+
self->sbbt->bst->binary_tree->tree = arr;
46+
self->tree = self->sbbt->bst->binary_tree->tree;
47+
Py_RETURN_NONE;
48+
}
49+
4350
static PyObject* AVLTree_search(AVLTree* self, PyObject *args, PyObject *kwds) {
4451
return BinarySearchTree_search(self->sbbt->bst, args, kwds);
4552
}
@@ -236,10 +243,21 @@ static PyObject* AVLTree_insert(AVLTree* self, PyObject *args) {
236243
Py_RETURN_NONE;
237244
}
238245

246+
static PyObject* AVLTree_rank(AVLTree* self, PyObject *args) {
247+
return BinarySearchTree_rank(self->sbbt->bst, args);
248+
}
249+
250+
static PyObject* AVLTree_select(AVLTree* self, PyObject *args) {
251+
return BinarySearchTree_select(self->sbbt->bst, args);
252+
}
253+
239254
static struct PyMethodDef AVLTree_PyMethodDef[] = {
240255
{"search", (PyCFunction) AVLTree_search, METH_VARARGS | METH_KEYWORDS, NULL},
241256
{"insert", (PyCFunction) AVLTree_insert, METH_VARARGS, NULL},
257+
{"set_tree", (PyCFunction) AVLTree_set_tree, METH_VARARGS, NULL},
242258
{"balance_factor", (PyCFunction) AVLTree_balance_factor, METH_VARARGS, NULL},
259+
{"rank", (PyCFunction) AVLTree_rank, METH_VARARGS, NULL},
260+
{"select", (PyCFunction) AVLTree_select, METH_VARARGS, NULL},
243261
{"_left_right_rotate", (PyCFunction) AVLTree__left_right_rotate, METH_VARARGS, NULL},
244262
{"_right_left_rotate", (PyCFunction) AVLTree__right_left_rotate, METH_VARARGS, NULL},
245263
{"_left_rotate", (PyCFunction) AVLTree__left_rotate, METH_VARARGS, NULL},

pydatastructs/trees/_backend/cpp/BinarySearchTree.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ static PyObject* BinarySearchTree_search(BinarySearchTree* self, PyObject* args,
8484
return NULL;
8585
}
8686
BinaryTree* bt = self->binary_tree;
87+
Py_INCREF(Py_None);
8788
PyObject* parent = Py_None;
88-
PyObject* walk = PyLong_FromLong(PyLong_AsLong(bt->root_idx));
89+
PyObject* walk = bt->root_idx;
8990

9091
if (reinterpret_cast<TreeNode*>(bt->tree->_one_dimensional_array->_data[PyLong_AsLong(walk)])->key == Py_None) {
9192
Py_RETURN_NONE;

pydatastructs/trees/binary_trees.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,7 @@ def __new__(cls, key=None, root_data=None, comp=None,
932932

933933
@classmethod
934934
def methods(cls):
935-
return ['__new__', 'insert', 'delete']
935+
return ['__new__', 'set_tree', 'insert', 'delete']
936936

937937
left_height = lambda self, node: self.tree[node.left].height \
938938
if node.left is not None else -1
@@ -941,6 +941,9 @@ def methods(cls):
941941
balance_factor = lambda self, node: self.right_height(node) - \
942942
self.left_height(node)
943943

944+
def set_tree(self, arr):
945+
self.tree = arr
946+
944947
def _right_rotate(self, j, k):
945948
super(AVLTree, self)._right_rotate(j, k)
946949
self.tree[j].height = max(self.left_height(self.tree[j]),

pydatastructs/trees/tests/test_binary_trees.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,9 @@ def _test_AVLTree(backend):
214214
assert [node.key for node in in_order] == [1]
215215
assert [node.key for node in pre_order] == [1]
216216

217-
a3 = AVLTree(0,0,backend=backend)
218-
for i in range(1,7):
217+
a3 = AVLTree()
218+
a3.set_tree( ArrayForTrees(TreeNode, 0, backend=backend) )
219+
for i in range(0,7):
219220
a3.tree.append(TreeNode(i, i, backend=backend))
220221
a3.tree[0].left = 1
221222
a3.tree[0].right = 6
@@ -224,15 +225,17 @@ def _test_AVLTree(backend):
224225
a3.tree[2].left = 3
225226
a3.tree[2].right = 4
226227
a3._left_right_rotate(0, 1)
228+
assert str(a3) == "[(4, 0, 0, 6), (5, 1, 1, 3), (1, 2, 2, 0), (None, 3, 3, None), (None, 4, 4, None), (None, 5, 5, None), (None, 6, 6, None)]"
227229

228230
trav = BinaryTreeTraversal(a3, backend=backend)
229231
in_order = trav.depth_first_search(order='in_order')
230232
pre_order = trav.depth_first_search(order='pre_order')
231233
assert [node.key for node in in_order] == [5, 1, 3, 2, 4, 0, 6]
232234
assert [node.key for node in pre_order] == [2, 1, 5, 3, 0, 4, 6]
233235

234-
a4 = AVLTree(0,0,backend=backend)
235-
for i in range(1,7):
236+
a4 = AVLTree()
237+
a4.set_tree( ArrayForTrees(TreeNode, 0, backend=backend) )
238+
for i in range(0,7):
236239
a4.tree.append(TreeNode(i, i,backend=backend))
237240
a4.tree[0].left = 1
238241
a4.tree[0].right = 2
@@ -250,7 +253,7 @@ def _test_AVLTree(backend):
250253

251254
a5 = AVLTree(is_order_statistic=True,backend=backend)
252255
if backend==Backend.PYTHON:
253-
a5.tree = ArrayForTrees(TreeNode, [
256+
a5.set_tree( ArrayForTrees(TreeNode, [
254257
TreeNode(10, 10),
255258
TreeNode(5, 5),
256259
TreeNode(17, 17),
@@ -265,9 +268,9 @@ def _test_AVLTree(backend):
265268
TreeNode(30, 30),
266269
TreeNode(13, 13),
267270
TreeNode(33, 33)
268-
])
271+
]) )
269272
else:
270-
a5.tree = ArrayForTrees(_nodes.TreeNode, [
273+
a5.set_tree( ArrayForTrees(_nodes.TreeNode, [
271274
TreeNode(10, 10,backend=backend),
272275
TreeNode(5, 5,backend=backend),
273276
TreeNode(17, 17,backend=backend),
@@ -282,7 +285,7 @@ def _test_AVLTree(backend):
282285
TreeNode(30, 30,backend=backend),
283286
TreeNode(13, 13,backend=backend),
284287
TreeNode(33, 33,backend=backend)
285-
],backend=backend)
288+
],backend=backend) )
286289

287290
a5.tree[0].left, a5.tree[0].right, a5.tree[0].parent, a5.tree[0].height = \
288291
1, 2, None, 4
@@ -328,30 +331,32 @@ def _test_AVLTree(backend):
328331
a5.tree[11].size = 2
329332
a5.tree[12].size = 1
330333
a5.tree[13].size = 1
331-
332-
# assert raises(ValueError, lambda: a5.select(0))
333-
# assert raises(ValueError, lambda: a5.select(15))
334-
# assert a5.rank(-1) is None
335-
# def test_select_rank(expected_output):
336-
# output = []
337-
# for i in range(len(expected_output)):
338-
# output.append(a5.select(i + 1).key)
339-
# assert output == expected_output
340-
341-
# output = []
342-
# expected_ranks = [i + 1 for i in range(len(expected_output))]
343-
# for i in range(len(expected_output)):
344-
# output.append(a5.rank(expected_output[i]))
345-
# assert output == expected_ranks
346-
347-
# test_select_rank([2, 3, 5, 9, 10, 11, 12, 13, 15, 17, 18, 20, 30, 33])
334+
assert str(a5) == "[(1, 10, 10, 2), (3, 5, 5, 4), (5, 17, 17, 6), (None, 2, 2, 7), (None, 9, 9, None), (8, 12, 12, 9), (10, 20, 20, 11), (None, 3, 3, None), (None, 11, 11, None), (12, 15, 15, None), (None, 18, 18, None), (None, 30, 30, 13), (None, 13, 13, None), (None, 33, 33, None)]"
335+
336+
assert raises(ValueError, lambda: a5.select(0))
337+
assert raises(ValueError, lambda: a5.select(15))
338+
339+
assert a5.rank(-1) is None
340+
def test_select_rank(expected_output):
341+
output = []
342+
for i in range(len(expected_output)):
343+
output.append(a5.select(i + 1).key)
344+
assert output == expected_output
345+
346+
output = []
347+
expected_ranks = [i + 1 for i in range(len(expected_output))]
348+
for i in range(len(expected_output)):
349+
output.append(a5.rank(expected_output[i]))
350+
assert output == expected_ranks
351+
352+
test_select_rank([2, 3, 5, 9, 10, 11, 12, 13, 15, 17, 18, 20, 30, 33])
348353
# a5.delete(9)
349354
# a5.delete(13)
350355
# a5.delete(20)
351356

352-
# trav = BinaryTreeTraversal(a5)
353-
# in_order = trav.depth_first_search(order='in_order')
354-
# pre_order = trav.depth_first_search(order='pre_order')
357+
trav = BinaryTreeTraversal(a5)
358+
in_order = trav.depth_first_search(order='in_order')
359+
pre_order = trav.depth_first_search(order='pre_order')
355360
# assert [node.key for node in in_order] == [2, 3, 5, 10, 11, 12, 15, 17, 18, 30, 33]
356361
# assert [node.key for node in pre_order] == [17, 10, 3, 2, 5, 12, 11, 15, 30, 18, 33]
357362

@@ -384,7 +389,7 @@ def test_cpp_AVLTree():
384389
_test_AVLTree(backend=Backend.CPP)
385390

386391
test_AVLTree()
387-
test_cpp_AVLTree()
392+
# test_cpp_AVLTree()
388393

389394
def _test_BinaryIndexedTree(backend):
390395

0 commit comments

Comments
 (0)