Skip to content

Commit 761cb21

Browse files
committed
KDTree: Supported 'cosine', 'manhattan' and 'hamming' distance (#216)
1 parent cf3dff5 commit 761cb21

File tree

8 files changed

+177
-41
lines changed

8 files changed

+177
-41
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ doc/api/
 11  11   pubspec.lock
 12  12
 13  13   test/.test_coverage.dart
     14+
     15+  .DS_Store
     16+  */**/.DS_Store

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
 1  1   # Changelog
 2  2
    3+  ## 16.9.0
    4+  - KDTree:
    5+    - Supported `cosine`, `manhattan` and `hamming` distance
 3  7   ## 16.8.0
 4  8   - DecisionTreeClassifier:
 5  9     - Added Gini index assessor type

benchmark/kd_tree/kd_tree_building.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
 1-      // 0.5 sec (MacBook Air mid 2017)
    1+   // 0.14 sec (MacBook Air mid 2017)
 2  2   import 'package:benchmark_harness/benchmark_harness.dart';
 3  3   import 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
 4  4   import 'package:ml_dataframe/ml_dataframe.dart';

benchmark/kd_tree/kd_tree_querying.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Future main() async {
 30  30   final points = Matrix.random(20000, 10, seed: 1, min: -5000, max: 5000);
 31  31
 32  32   trainData = DataFrame.fromMatrix(points);
 33-      tree = KDTree(trainData, leafSize: 1);
    33+   tree = KDTree(trainData);
 34  34   point = Vector.randomFilled(trainData.rows.first.length,
 35  35       seed: 10, min: -5000, max: 5000);
 36  36

lib/src/retrieval/kd_tree/kd_tree.dart

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart';
55
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_neighbour.dart';
66
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_split_strategy.dart';
77
import 'package:ml_dataframe/ml_dataframe.dart';
8+
import 'package:ml_linalg/distance.dart';
89
import 'package:ml_linalg/dtype.dart';
910
import 'package:ml_linalg/matrix.dart';
1011
import 'package:ml_linalg/vector.dart';
@@ -24,28 +25,28 @@ abstract class KDTree implements Serializable {
2425
/// The bigger the number, the less effective search is. If [leafSize] is
2526
/// equal to the number of [points], a regular KNN-search will take place.
2627
///
27-
/// Extremely small [leafSize] leads to ineffective memory usage since in
28-
/// this case a lot of kd-tree nodes will be allocated
28+
/// Small [leafSize] leads to an increasing amount of time for tree building,
29+
/// but querying will be very fast
2930
///
3031
/// [dtype] A data type which will be used to convert raw data from [points]
3132
/// into internal numerical representation
3233
///
3334
/// [splitStrategy] Describes how to choose a split dimension. Default value
34-
/// is [KDTreeSplitStrategy.largestVariance]
35+
/// is [KDTreeSplitStrategy.inOrder]
36+
///
37+
/// if [splitStrategy] is [KDTreeSplitStrategy.inOrder], dimension for data
38+
/// splits will be chosen one by one in order, in this case, tree building is
39+
/// very fast
3540
///
3641
/// if [splitStrategy] is [KDTreeSplitStrategy.largestVariance], dimension with
3742
/// the widest column (in terms of variance) will be chosen to split the data
3843
///
39-
/// if [splitStrategy] is [KDTreeSplitStrategy.inOrder], dimension for data
40-
/// splits will be chosen one by one in order
41-
///
42-
/// [KDTreeSplitStrategy.largestVariance] provides more accurate KNN-search,
44+
/// [KDTreeSplitStrategy.largestVariance] results in more balanced tree,
4345
/// but this strategy takes much more time to build the tree than [KDTreeSplitStrategy.inOrder]
4446
factory KDTree(DataFrame points,
4547
{int leafSize = 1,
4648
DType dtype = DType.float32,
47-
KDTreeSplitStrategy splitStrategy =
48-
KDTreeSplitStrategy.largestVariance}) =>
49+
KDTreeSplitStrategy splitStrategy = KDTreeSplitStrategy.inOrder}) =>
4950
createKDTree(points, leafSize, dtype, splitStrategy);
5051

5152
/// [pointsSrc] Data points which will be used to build the tree.
@@ -115,7 +116,8 @@ abstract class KDTree implements Serializable {
115116
/// final kdTree = KDTree(data);
116117
/// final neighbours = kdTree.query(Vector.fromList([1, 2, 3, 4]), 2);
117118
///
118-
/// print(neighbours.index); // let's say, it outputs `3` which means that the nearest neighbour is kdTree.points[3]
119+
/// print(neighbours[0].index); // let's say, it outputs `3` which means that the nearest neighbour is kdTree.points[3]
119120
/// ```
120-
Iterable<KDTreeNeighbour> query(Vector point, int k);
121+
Iterable<KDTreeNeighbour> query(Vector point, int k,
122+
[Distance distance = Distance.euclidean]);
121123
}

lib/src/retrieval/kd_tree/kd_tree_impl.dart

Lines changed: 68 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
66
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_json_keys.dart';
77
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_neighbour.dart';
88
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_node.dart';
9+
import 'package:ml_linalg/distance.dart';
910
import 'package:ml_linalg/dtype.dart';
1011
import 'package:ml_linalg/matrix.dart';
1112
import 'package:ml_linalg/vector.dart';
@@ -45,61 +46,62 @@ class KDTreeImpl with SerializableMixin implements KDTree {
4546
int searchIterationCount = 0;
4647

4748
@override
48-
Iterable<KDTreeNeighbour> query(Vector point, int k) {
49+
Iterable<KDTreeNeighbour> query(Vector point, int k,
50+
[Distance distanceType = Distance.euclidean]) {
4951
if (point.length != points.columnsNum) {
5052
throw InvalidQueryPointLength(point.length, points.columnsNum);
5153
}
5254

5355
searchIterationCount = 0;
5456

55-
final neighbours = HeapPriorityQueue<KDTreeNeighbour>((a, b) =>
56-
(point.distanceTo(points[b.index]) - point.distanceTo(points[a.index]))
57-
.toInt());
57+
final neighbours = _createQueue(point, distanceType);
5858

59-
_findKNNRecursively(root, point, k, neighbours);
59+
_findKNNRecursively(root, point, k, neighbours, distanceType);
6060

6161
return neighbours.toList().reversed;
6262
}
6363

6464
void _findKNNRecursively(KDTreeNode? node, Vector point, int k,
65-
HeapPriorityQueue<KDTreeNeighbour> neighbours) {
65+
HeapPriorityQueue<KDTreeNeighbour> neighbours, Distance distanceType) {
6666
if (node == null) {
6767
return;
6868
}
6969

7070
if (node.isLeaf) {
71-
_knnSearch(point, node.pointIndices, neighbours, k);
71+
_knnSearch(point, node.pointIndices, neighbours, k, distanceType);
7272

7373
return;
7474
}
7575

7676
final nodePoint = points[node.pointIndices[0]];
77-
final isNodeTooFar = neighbours.length > 0 &&
78-
(point[node.splitIndex] - nodePoint[node.splitIndex]).abs() >
79-
neighbours.first.distance;
8077
final isQueueFilled = neighbours.length == k;
8178

82-
if (isQueueFilled && isNodeTooFar) {
79+
if (isQueueFilled && _isNodeToFar(node, point, neighbours, distanceType)) {
8380
return;
8481
}
8582

86-
_knnSearch(point, node.pointIndices, neighbours, k);
83+
_knnSearch(point, node.pointIndices, neighbours, k, distanceType);
8784

8885
if (point[node.splitIndex] < nodePoint[node.splitIndex]) {
89-
_findKNNRecursively(node.left, point, k, neighbours);
90-
_findKNNRecursively(node.right, point, k, neighbours);
86+
_findKNNRecursively(node.left, point, k, neighbours, distanceType);
87+
_findKNNRecursively(node.right, point, k, neighbours, distanceType);
9188
} else {
92-
_findKNNRecursively(node.right, point, k, neighbours);
93-
_findKNNRecursively(node.left, point, k, neighbours);
89+
_findKNNRecursively(node.right, point, k, neighbours, distanceType);
90+
_findKNNRecursively(node.left, point, k, neighbours, distanceType);
9491
}
9592
}
9693

97-
void _knnSearch(Vector point, List<int> pointIndices,
98-
HeapPriorityQueue<KDTreeNeighbour> neighbours, int k) {
94+
void _knnSearch(
95+
Vector point,
96+
List<int> pointIndices,
97+
HeapPriorityQueue<KDTreeNeighbour> neighbours,
98+
int k,
99+
Distance distanceType) {
99100
pointIndices.forEach((candidateIdx) {
100101
searchIterationCount++;
101102
final candidate = points[candidateIdx];
102-
final candidateDistance = candidate.distanceTo(point);
103+
final candidateDistance =
104+
candidate.distanceTo(point, distance: distanceType);
103105
final lastNeighbourDistance =
104106
neighbours.length > 0 ? neighbours.first.distance : candidateDistance;
105107
final isGoodCandidate = candidateDistance < lastNeighbourDistance;
@@ -114,4 +116,51 @@ class KDTreeImpl with SerializableMixin implements KDTree {
114116
}
115117
});
116118
}
119+
120+
bool _isNodeToFar(KDTreeNode node, Vector point,
121+
HeapPriorityQueue<KDTreeNeighbour> neighbours, Distance distanceType) {
122+
if (neighbours.length == 0) {
123+
return false;
124+
}
125+
126+
final nodePoint = points[node.pointIndices[0]];
127+
128+
switch (distanceType) {
129+
case Distance.euclidean:
130+
case Distance.manhattan:
131+
return (point[node.splitIndex] - nodePoint[node.splitIndex]).abs() >
132+
neighbours.first.distance;
133+
134+
case Distance.hamming:
135+
case Distance.cosine:
136+
final other = nodePoint.set(node.splitIndex, point[node.splitIndex]);
137+
138+
return other.distanceTo(nodePoint, distance: distanceType) >
139+
neighbours.first.distance;
140+
141+
default:
142+
throw UnsupportedError(
143+
'Distance type $distanceType is not supported yet');
144+
}
145+
}
146+
147+
HeapPriorityQueue<KDTreeNeighbour> _createQueue(
148+
Vector point, Distance distanceType) {
149+
return HeapPriorityQueue<KDTreeNeighbour>((a, b) {
150+
final distanceA =
151+
point.distanceTo(points[a.index], distance: distanceType);
152+
final distanceB =
153+
point.distanceTo(points[b.index], distance: distanceType);
154+
155+
if (distanceA < distanceB) {
156+
return 1;
157+
}
158+
159+
if (distanceA > distanceB) {
160+
return -1;
161+
}
162+
163+
return 0;
164+
});
165+
}
117166
}

pubspec.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
 1  1   name: ml_algo
 2  2   description: Machine learning algorithms, Machine learning models performance evaluation functionality
 3-      version: 16.8.0
    3+   version: 16.9.0
 4  4   homepage: https://github.com/gyrdym/ml_algo
 5  5
 6  6   environment:

test/retrieval/kd_tree/kd_tree_test.dart

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
33
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_impl.dart';
44
import 'package:ml_algo/src/retrieval/kd_tree/kd_tree_split_strategy.dart';
55
import 'package:ml_dataframe/ml_dataframe.dart';
6+
import 'package:ml_linalg/distance.dart';
67
import 'package:ml_linalg/dtype.dart';
78
import 'package:ml_linalg/vector.dart';
89
import 'package:test/test.dart';
@@ -61,10 +62,10 @@ void main() {
6162
});
6263

6364
test(
64-
'should find the closest neighbours for [2.79, -9.15, 6.56, -18.59, 13.53], leafSize=3, splitStrategy=KDTreeSplitStrategy.inOrder',
65+
'should find the closest neighbours for [2.79, -9.15, 6.56, -18.59, 13.53], leafSize=3, splitStrategy=KDTreeSplitStrategy.largestVariance',
6566
() {
6667
final kdTree = KDTree(DataFrame(data, headerExists: false),
67-
leafSize: 3, splitStrategy: KDTreeSplitStrategy.inOrder);
68+
leafSize: 3, splitStrategy: KDTreeSplitStrategy.largestVariance);
6869
final sample = Vector.fromList([2.79, -9.15, 6.56, -18.59, 13.53]);
6970
final result = kdTree.query(sample, 3).toList();
7071

@@ -84,7 +85,7 @@ void main() {
8485
expect(result[0].index, 4);
8586
expect(result[1].index, 12);
8687
expect(result[2].index, 3);
87-
expect(result[3].index, 1);
88+
expect(result[3].index, 18);
8889
expect(result, hasLength(4));
8990
});
9091

@@ -98,7 +99,7 @@ void main() {
9899
expect(result[0].index, 4);
99100
expect(result[1].index, 12);
100101
expect(result[2].index, 3);
101-
expect(result[3].index, 1);
102+
expect(result[3].index, 18);
102103
expect(result, hasLength(4));
103104
}, skip: false);
104105

@@ -112,8 +113,24 @@ void main() {
112113
expect(result[0].index, 19);
113114
expect(result[1].index, 11);
114115
expect(result[2].index, 6);
115-
expect(result[4].index, 9);
116-
expect(result[3].index, 18);
116+
expect(result[3].index, 9);
117+
expect(result[4].index, 18);
118+
expect(result[5].index, 2);
119+
expect(result[6].index, 14);
120+
expect(result[7].index, 10);
121+
expect(result[8].index, 15);
122+
expect(result[9].index, 5);
123+
expect(result[10].index, 7);
124+
expect(result[10].index, 7);
125+
expect(result[11].index, 13);
126+
expect(result[12].index, 3);
127+
expect(result[13].index, 12);
128+
expect(result[14].index, 17);
129+
expect(result[15].index, 1);
130+
expect(result[16].index, 0);
131+
expect(result[17].index, 16);
132+
expect(result[18].index, 4);
133+
expect(result[19].index, 8);
117134
expect(result, hasLength(20));
118135
});
119136

@@ -143,7 +160,7 @@ void main() {
143160

144161
kdTree.query(sample, 1).toList();
145162

146-
expect((kdTree as KDTreeImpl).searchIterationCount, 14);
163+
expect((kdTree as KDTreeImpl).searchIterationCount, 4);
147164
});
148165

149166
test(
@@ -154,7 +171,68 @@ void main() {
154171

155172
kdTree.query(sample, 1).toList();
156173

157-
expect((kdTree as KDTreeImpl).searchIterationCount, 7);
174+
expect((kdTree as KDTreeImpl).searchIterationCount, 6);
175+
});
176+
177+
test(
178+
'should find the closest neighbours for [12, 23, 22, 11, -20], k=1, leafSize=1, cosine distance',
179+
() {
180+
final kdTree = KDTree(DataFrame(data, headerExists: false), leafSize: 1);
181+
final sample = Vector.fromList([12, 23, 22, 11, -20]);
182+
final result = kdTree.query(sample, 1, Distance.cosine).toList();
183+
184+
expect(result, hasLength(1));
185+
expect(result[0].index, 17);
186+
});
187+
188+
test(
189+
'should find the closest neighbours for [12, 23, 22, 11, -20], k=2, leafSize=1, cosine distance',
190+
() {
191+
final kdTree = KDTree(DataFrame(data, headerExists: false), leafSize: 1);
192+
final sample = Vector.fromList([12, 23, 22, 11, -20]);
193+
final result = kdTree.query(sample, 2, Distance.cosine).toList();
194+
195+
expect(result, hasLength(2));
196+
expect(result[0].index, 17);
197+
expect(result[1].index, 8);
198+
});
199+
200+
test(
201+
'should find the closest neighbours for [12, 23, 22, 11, -20], k=3, leafSize=1, cosine distance',
202+
() {
203+
final kdTree = KDTree(DataFrame(data, headerExists: false), leafSize: 1);
204+
final sample = Vector.fromList([12, 23, 22, 11, -20]);
205+
final result = kdTree.query(sample, 3, Distance.cosine).toList();
206+
207+
expect(result, hasLength(3));
208+
expect(result[0].index, 17);
209+
expect(result[1].index, 8);
210+
expect(result[2].index, 0);
211+
});
212+
213+
test(
214+
'should find the closest neighbours for [12, 23, 22, 11, -20], k=3, leafSize=3, cosine distance',
215+
() {
216+
final kdTree = KDTree(DataFrame(data, headerExists: false), leafSize: 3);
217+
final sample = Vector.fromList([12, 23, 22, 11, -20]);
218+
final result = kdTree.query(sample, 3, Distance.cosine).toList();
219+
220+
expect(result, hasLength(3));
221+
expect(result[0].index, 17);
222+
expect(result[1].index, 8);
223+
expect(result[2].index, 0);
224+
});
225+
226+
test(
227+
'should find the closest neighbours for [12, 23, 22, 11, -20], k=3, leafSize=3, cosine distance for conceivable amount of iterations',
228+
() {
229+
final kdTree = KDTree(DataFrame(data, headerExists: false),
230+
splitStrategy: KDTreeSplitStrategy.largestVariance, leafSize: 1);
231+
final sample = Vector.fromList([12, 23, 22, 11, -20]);
232+
233+
kdTree.query(sample, 3, Distance.cosine).toList();
234+
235+
expect((kdTree as KDTreeImpl).searchIterationCount, 13);
158236
});
159237

160238
test('should throw an exception if the query point is of invalid length',

Comments (0)