Skip to content

Commit 4567245

Browse files
authored
MAINT cython typedefs in _quad_tree (scikit-learn#27351)
1 parent e9b3d1c commit 4567245

File tree

2 files changed

+63
-67
lines changed

2 files changed

+63
-67
lines changed

sklearn/neighbors/_quad_tree.pxd

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44
# See quad_tree.pyx for details.
55

66
cimport numpy as cnp
7-
8-
ctypedef cnp.npy_float32 DTYPE_t # Type of X
9-
ctypedef cnp.npy_intp SIZE_t # Type for indices and counters
10-
ctypedef cnp.npy_int32 INT32_t # Signed 32 bit integer
11-
ctypedef cnp.npy_uint32 UINT32_t # Unsigned 32 bit integer
7+
from ..utils._typedefs cimport float32_t, intp_t
128

139
# This is effectively an ifdef statement in Cython
1410
# It allows us to write printf debugging lines
@@ -25,26 +21,26 @@ cdef struct Cell:
2521
# Base storage structure for cells in a QuadTree object
2622

2723
# Tree structure
28-
SIZE_t parent # Parent cell of this cell
29-
SIZE_t[8] children # Array pointing to children of this cell
24+
intp_t parent # Parent cell of this cell
25+
intp_t[8] children # Array pointing to children of this cell
3026

3127
# Cell description
32-
SIZE_t cell_id # Id of the cell in the cells array in the Tree
33-
SIZE_t point_index # Index of the point at this cell (only defined
34-
# # in non empty leaf)
35-
bint is_leaf # Does this cell have children?
36-
DTYPE_t squared_max_width # Squared value of the maximum width w
37-
SIZE_t depth # Depth of the cell in the tree
38-
SIZE_t cumulative_size # Number of points included in the subtree with
39-
# # this cell as a root.
28+
intp_t cell_id # Id of the cell in the cells array in the Tree
29+
intp_t point_index # Index of the point at this cell (only defined
30+
# # in non empty leaf)
31+
bint is_leaf # Does this cell have children?
32+
float32_t squared_max_width # Squared value of the maximum width w
33+
intp_t depth # Depth of the cell in the tree
34+
intp_t cumulative_size # Number of points included in the subtree with
35+
# # this cell as a root.
4036

4137
# Internal constants
42-
DTYPE_t[3] center # Store the center for quick split of cells
43-
DTYPE_t[3] barycenter # Keep track of the center of mass of the cell
38+
float32_t[3] center # Store the center for quick split of cells
39+
float32_t[3] barycenter # Keep track of the center of mass of the cell
4440

4541
# Cell boundaries
46-
DTYPE_t[3] min_bounds # Inferior boundaries of this cell (inclusive)
47-
DTYPE_t[3] max_bounds # Superior boundaries of this cell (exclusive)
42+
float32_t[3] min_bounds # Inferior boundaries of this cell (inclusive)
43+
float32_t[3] max_bounds # Superior boundaries of this cell (exclusive)
4844

4945

5046
cdef class _QuadTree:
@@ -57,40 +53,40 @@ cdef class _QuadTree:
5753
# Parameters of the tree
5854
cdef public int n_dimensions # Number of dimensions in X
5955
cdef public int verbose # Verbosity of the output
60-
cdef SIZE_t n_cells_per_cell # Number of children per node. (2 ** n_dimension)
56+
cdef intp_t n_cells_per_cell # Number of children per node. (2 ** n_dimension)
6157

6258
# Tree inner structure
63-
cdef public SIZE_t max_depth # Max depth of the tree
64-
cdef public SIZE_t cell_count # Counter for node IDs
65-
cdef public SIZE_t capacity # Capacity of tree, in terms of nodes
66-
cdef public SIZE_t n_points # Total number of points
59+
cdef public intp_t max_depth # Max depth of the tree
60+
cdef public intp_t cell_count # Counter for node IDs
61+
cdef public intp_t capacity # Capacity of tree, in terms of nodes
62+
cdef public intp_t n_points # Total number of points
6763
cdef Cell* cells # Array of nodes
6864

6965
# Point insertion methods
70-
cdef int insert_point(self, DTYPE_t[3] point, SIZE_t point_index,
71-
SIZE_t cell_id=*) except -1 nogil
72-
cdef SIZE_t _insert_point_in_new_child(self, DTYPE_t[3] point, Cell* cell,
73-
SIZE_t point_index, SIZE_t size=*
66+
cdef int insert_point(self, float32_t[3] point, intp_t point_index,
67+
intp_t cell_id=*) except -1 nogil
68+
cdef intp_t _insert_point_in_new_child(self, float32_t[3] point, Cell* cell,
69+
intp_t point_index, intp_t size=*
7470
) noexcept nogil
75-
cdef SIZE_t _select_child(self, DTYPE_t[3] point, Cell* cell) noexcept nogil
76-
cdef bint _is_duplicate(self, DTYPE_t[3] point1, DTYPE_t[3] point2) noexcept nogil
71+
cdef intp_t _select_child(self, float32_t[3] point, Cell* cell) noexcept nogil
72+
cdef bint _is_duplicate(self, float32_t[3] point1, float32_t[3] point2) noexcept nogil
7773

7874
# Create a summary of the Tree compare to a query point
79-
cdef long summarize(self, DTYPE_t[3] point, DTYPE_t* results,
80-
float squared_theta=*, SIZE_t cell_id=*, long idx=*
75+
cdef long summarize(self, float32_t[3] point, float32_t* results,
76+
float squared_theta=*, intp_t cell_id=*, long idx=*
8177
) noexcept nogil
8278

8379
# Internal cell initialization methods
84-
cdef void _init_cell(self, Cell* cell, SIZE_t parent, SIZE_t depth) noexcept nogil
85-
cdef void _init_root(self, DTYPE_t[3] min_bounds, DTYPE_t[3] max_bounds
80+
cdef void _init_cell(self, Cell* cell, intp_t parent, intp_t depth) noexcept nogil
81+
cdef void _init_root(self, float32_t[3] min_bounds, float32_t[3] max_bounds
8682
) noexcept nogil
8783

8884
# Private methods
89-
cdef int _check_point_in_cell(self, DTYPE_t[3] point, Cell* cell
85+
cdef int _check_point_in_cell(self, float32_t[3] point, Cell* cell
9086
) except -1 nogil
9187

9288
# Private array manipulation to manage the ``cells`` array
93-
cdef int _resize(self, SIZE_t capacity) except -1 nogil
94-
cdef int _resize_c(self, SIZE_t capacity=*) except -1 nogil
95-
cdef int _get_cell(self, DTYPE_t[3] point, SIZE_t cell_id=*) except -1 nogil
89+
cdef int _resize(self, intp_t capacity) except -1 nogil
90+
cdef int _resize_c(self, intp_t capacity=*) except -1 nogil
91+
cdef int _get_cell(self, float32_t[3] point, intp_t cell_id=*) except -1 nogil
9692
cdef Cell[:] _get_cell_ndarray(self)

sklearn/neighbors/_quad_tree.pyx

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,11 @@ cdef class _QuadTree:
8080
"""Build a tree from an array of points X."""
8181
cdef:
8282
int i
83-
DTYPE_t[3] pt
84-
DTYPE_t[3] min_bounds, max_bounds
83+
float32_t[3] pt
84+
float32_t[3] min_bounds, max_bounds
8585

8686
# validate X and prepare for query
87-
# X = check_array(X, dtype=DTYPE_t, order='C')
87+
# X = check_array(X, dtype=float32_t, order='C')
8888
n_samples = X.shape[0]
8989

9090
capacity = 100
@@ -113,13 +113,13 @@ cdef class _QuadTree:
113113
# Shrink the cells array to reduce memory usage
114114
self._resize(capacity=self.cell_count)
115115

116-
cdef int insert_point(self, DTYPE_t[3] point, SIZE_t point_index,
117-
SIZE_t cell_id=0) except -1 nogil:
116+
cdef int insert_point(self, float32_t[3] point, intp_t point_index,
117+
intp_t cell_id=0) except -1 nogil:
118118
"""Insert a point in the QuadTree."""
119119
cdef int ax
120-
cdef SIZE_t selected_child
120+
cdef intp_t selected_child
121121
cdef Cell* cell = &self.cells[cell_id]
122-
cdef SIZE_t n_point = cell.cumulative_size
122+
cdef intp_t n_point = cell.cumulative_size
123123

124124
if self.verbose > 10:
125125
printf("[QuadTree] Inserting depth %li\n", cell.depth)
@@ -177,16 +177,16 @@ cdef class _QuadTree:
177177
return self.insert_point(point, point_index, cell_id)
178178

179179
# XXX: This operation is not Thread safe
180-
cdef SIZE_t _insert_point_in_new_child(
181-
self, DTYPE_t[3] point, Cell* cell, SIZE_t point_index, SIZE_t size=1
180+
cdef intp_t _insert_point_in_new_child(
181+
self, float32_t[3] point, Cell* cell, intp_t point_index, intp_t size=1
182182
) noexcept nogil:
183183
"""Create a child of cell which will contain point."""
184184

185185
# Local variable definition
186186
cdef:
187-
SIZE_t cell_id, cell_child_id, parent_id
188-
DTYPE_t[3] save_point
189-
DTYPE_t width
187+
intp_t cell_id, cell_child_id, parent_id
188+
float32_t[3] save_point
189+
float32_t width
190190
Cell* child
191191
int i
192192

@@ -247,7 +247,7 @@ cdef class _QuadTree:
247247

248248
return cell_id
249249

250-
cdef bint _is_duplicate(self, DTYPE_t[3] point1, DTYPE_t[3] point2) noexcept nogil:
250+
cdef bint _is_duplicate(self, float32_t[3] point1, float32_t[3] point2) noexcept nogil:
251251
"""Check if the two given points are equals."""
252252
cdef int i
253253
cdef bint res = True
@@ -256,11 +256,11 @@ cdef class _QuadTree:
256256
res &= fabsf(point1[i] - point2[i]) <= EPSILON
257257
return res
258258

259-
cdef SIZE_t _select_child(self, DTYPE_t[3] point, Cell* cell) noexcept nogil:
259+
cdef intp_t _select_child(self, float32_t[3] point, Cell* cell) noexcept nogil:
260260
"""Select the child of cell which contains the given query point."""
261261
cdef:
262262
int i
263-
SIZE_t selected_child = 0
263+
intp_t selected_child = 0
264264

265265
for i in range(self.n_dimensions):
266266
# Select the correct child cell to insert the point by comparing
@@ -270,7 +270,7 @@ cdef class _QuadTree:
270270
selected_child += 1
271271
return cell.children[selected_child]
272272

273-
cdef void _init_cell(self, Cell* cell, SIZE_t parent, SIZE_t depth) noexcept nogil:
273+
cdef void _init_cell(self, Cell* cell, intp_t parent, intp_t depth) noexcept nogil:
274274
"""Initialize a cell structure with some constants."""
275275
cell.parent = parent
276276
cell.is_leaf = True
@@ -280,12 +280,12 @@ cdef class _QuadTree:
280280
for i in range(self.n_cells_per_cell):
281281
cell.children[i] = SIZE_MAX
282282

283-
cdef void _init_root(self, DTYPE_t[3] min_bounds, DTYPE_t[3] max_bounds
283+
cdef void _init_root(self, float32_t[3] min_bounds, float32_t[3] max_bounds
284284
) noexcept nogil:
285285
"""Initialize the root node with the given space boundaries"""
286286
cdef:
287287
int i
288-
DTYPE_t width
288+
float32_t width
289289
Cell* root = &self.cells[0]
290290

291291
self._init_cell(root, -1, 0)
@@ -299,7 +299,7 @@ cdef class _QuadTree:
299299

300300
self.cell_count += 1
301301

302-
cdef int _check_point_in_cell(self, DTYPE_t[3] point, Cell* cell
302+
cdef int _check_point_in_cell(self, float32_t[3] point, Cell* cell
303303
) except -1 nogil:
304304
"""Check that the given point is in the cell boundaries."""
305305

@@ -366,8 +366,8 @@ cdef class _QuadTree:
366366
"in children."
367367
.format(self.n_points, self.cells[0].cumulative_size))
368368

369-
cdef long summarize(self, DTYPE_t[3] point, DTYPE_t* results,
370-
float squared_theta=.5, SIZE_t cell_id=0, long idx=0
369+
cdef long summarize(self, float32_t[3] point, float32_t* results,
370+
float squared_theta=.5, intp_t cell_id=0, long idx=0
371371
) noexcept nogil:
372372
"""Summarize the tree compared to a query point.
373373
@@ -429,7 +429,7 @@ cdef class _QuadTree:
429429
# Otherwise, we go a higher level of resolution and into the leaves.
430430
if cell.is_leaf or (
431431
(cell.squared_max_width / results[idx_d]) < squared_theta):
432-
results[idx_d + 1] = <DTYPE_t> cell.cumulative_size
432+
results[idx_d + 1] = <float32_t> cell.cumulative_size
433433
return idx + self.n_dimensions + 2
434434

435435
else:
@@ -446,7 +446,7 @@ cdef class _QuadTree:
446446
"""return the id of the cell containing the query point or raise
447447
ValueError if the point is not in the tree
448448
"""
449-
cdef DTYPE_t[3] query_pt
449+
cdef float32_t[3] query_pt
450450
cdef int i
451451

452452
assert len(point) == self.n_dimensions, (
@@ -458,14 +458,14 @@ cdef class _QuadTree:
458458

459459
return self._get_cell(query_pt, 0)
460460

461-
cdef int _get_cell(self, DTYPE_t[3] point, SIZE_t cell_id=0
461+
cdef int _get_cell(self, float32_t[3] point, intp_t cell_id=0
462462
) except -1 nogil:
463463
"""guts of get_cell.
464464
465465
Return the id of the cell containing the query point or raise ValueError
466466
if the point is not in the tree"""
467467
cdef:
468-
SIZE_t selected_child
468+
intp_t selected_child
469469
Cell* cell = &self.cells[cell_id]
470470

471471
if cell.is_leaf:
@@ -562,7 +562,7 @@ cdef class _QuadTree:
562562
raise ValueError("Can't initialize array!")
563563
return arr
564564

565-
cdef int _resize(self, SIZE_t capacity) except -1 nogil:
565+
cdef int _resize(self, intp_t capacity) except -1 nogil:
566566
"""Resize all inner arrays to `capacity`, if `capacity` == -1, then
567567
double the size of the inner arrays.
568568
@@ -574,7 +574,7 @@ cdef class _QuadTree:
574574
with gil:
575575
raise MemoryError()
576576

577-
cdef int _resize_c(self, SIZE_t capacity=SIZE_MAX) except -1 nogil:
577+
cdef int _resize_c(self, intp_t capacity=SIZE_MAX) except -1 nogil:
578578
"""Guts of _resize
579579
580580
Returns -1 in case of failure to allocate memory (and raise MemoryError)
@@ -598,10 +598,10 @@ cdef class _QuadTree:
598598
self.capacity = capacity
599599
return 0
600600

601-
def _py_summarize(self, DTYPE_t[:] query_pt, DTYPE_t[:, :] X, float angle):
601+
def _py_summarize(self, float32_t[:] query_pt, float32_t[:, :] X, float angle):
602602
# Used for testing summarize
603603
cdef:
604-
DTYPE_t[:] summary
604+
float32_t[:] summary
605605
int n_samples
606606

607607
n_samples = X.shape[0]

0 commit comments

Comments
 (0)