Skip to content

Commit a049091

Browse files
authored
KDTree: added 'queryIterable' method (#217)
1 parent 761cb21 commit a049091

File tree

5 files changed

+76
-1
lines changed

5 files changed

+76
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## 16.10.0
4+
- KDTree:
5+
- Added `queryIterable` method
6+
37
## 16.9.0
48
- KDTree:
59
- Supported `cosine`, `manhattan` and `hamming` distance

lib/src/retrieval/kd_tree/kd_tree.dart

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,31 @@ abstract class KDTree implements Serializable {
120120
/// ```
121121
Iterable<KDTreeNeighbour> query(Vector point, int k,
122122
[Distance distance = Distance.euclidean]);
123+
124+
/// Returns [k] nearest neighbours for [point], [point] is [Iterable] unlike
125+
/// [query] method where [point] is [Vector] from ml_linalg library
126+
///
127+
/// One can use the method in the scenario apart from ml_algo ecosystem,
128+
/// since the method doesn't require dependencies such as ml_linalg
129+
///
130+
/// The neighbour is represented by an index and the distance between [point]
131+
/// and the neighbour itself. The index is a zero-based index of a point in
132+
/// the source [points] matrix. Example:
133+
///
134+
/// ```dart
135+
/// import 'package:ml_dataframe/ml_dataframe.dart';
136+
/// import 'package:ml_linalg/vector.dart';
137+
///
138+
/// final data = [
139+
/// [21, 34, 22, 11],
140+
/// [11, 33, 44, 55],
141+
/// ...,
142+
/// ];
143+
/// final kdTree = KDTree.fromIterable(data);
144+
/// final neighbours = kdTree.queryIterable([1, 2, 3, 4], 2);
145+
///
146+
/// print(neighbours[0].index); // let's say, it outputs `3` which means that the nearest neighbour is kdTree.points[3]
147+
/// ```
148+
Iterable<KDTreeNeighbour> queryIterable(Iterable<num> point, int k,
149+
[Distance distance = Distance.euclidean]);
123150
}

lib/src/retrieval/kd_tree/kd_tree_impl.dart

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ class KDTreeImpl with SerializableMixin implements KDTree {
6161
return neighbours.toList().reversed;
6262
}
6363

64+
@override
65+
Iterable<KDTreeNeighbour> queryIterable(Iterable<num> point, int k,
66+
[Distance distanceType = Distance.euclidean]) =>
67+
query(Vector.fromList(point.toList(), dtype: dtype), k);
68+
6469
void _findKNNRecursively(KDTreeNode? node, Vector point, int k,
6570
HeapPriorityQueue<KDTreeNeighbour> neighbours, Distance distanceType) {
6671
if (node == null) {

pubspec.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: ml_algo
22
description: Machine learning algorithms, Machine learning models performance evaluation functionality
3-
version: 16.9.0
3+
version: 16.10.0
44
homepage: https://github.com/gyrdym/ml_algo
55

66
environment:

test/retrieval/kd_tree/kd_tree_test.dart

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,19 @@ void main() {
6161
expect(result, hasLength(3));
6262
});
6363

64+
test(
65+
'should find the closest neighbours for Iterable [2.79, -9.15, 6.56, -18.59, 13.53], leafSize=1',
66+
() {
67+
final kdTree = KDTree(DataFrame(data, headerExists: false), leafSize: 1);
68+
final sample = [2.79, -9.15, 6.56, -18.59, 13.53];
69+
final result = kdTree.queryIterable(sample, 3).toList();
70+
71+
expect(result[0].index, 12);
72+
expect(result[1].index, 4);
73+
expect(result[2].index, 18);
74+
expect(result, hasLength(3));
75+
});
76+
6477
test(
6578
'should find the closest neighbours for [2.79, -9.15, 6.56, -18.59, 13.53], leafSize=3, splitStrategy=KDTreeSplitStrategy.largestVariance',
6679
() {
@@ -235,6 +248,32 @@ void main() {
235248
expect((kdTree as KDTreeImpl).searchIterationCount, 13);
236249
});
237250

251+
test(
252+
'should find the closest neighbours for [2.79, -9.15, 6.56, -18.59, 13.53], leafSize=3, manhatten distance',
253+
() {
254+
final kdTree = KDTree(DataFrame(data, headerExists: false), leafSize: 3);
255+
final sample = Vector.fromList([2.79, -9.15, 6.56, -18.59, 13.53]);
256+
final result = kdTree.query(sample, 3, Distance.manhattan).toList();
257+
258+
expect(result[0].index, 12);
259+
expect(result[1].index, 4);
260+
expect(result[2].index, 13);
261+
expect(result, hasLength(3));
262+
});
263+
264+
test(
265+
'should find the closest neighbours for [2.79, -9.15, 6.56, -18.59, 13.53], leafSize=1, manhatten distance',
266+
() {
267+
final kdTree = KDTree(DataFrame(data, headerExists: false), leafSize: 1);
268+
final sample = Vector.fromList([2.79, -9.15, 6.56, -18.59, 13.53]);
269+
final result = kdTree.query(sample, 3, Distance.manhattan).toList();
270+
271+
expect(result[0].index, 12);
272+
expect(result[1].index, 4);
273+
expect(result[2].index, 13);
274+
expect(result, hasLength(3));
275+
});
276+
238277
test('should throw an exception if the query point is of invalid length',
239278
() {
240279
final kdTree = KDTree(DataFrame(data, headerExists: false), leafSize: 3);

0 commit comments

Comments
 (0)