Commit b1cb036

KDTree exported as a separate library; an example added to README (#224)
1 parent 51c16c8 commit b1cb036

5 files changed

+158
-4
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## 16.11.1
4+
- `KDTree` example added to README
5+
- `kd_tree` exported as a separate library
6+
37
## 16.11.0
48
- `ml_preprocessing` version upgraded to 7.0.2
59
- `ml_dataframe` version upgraded to 1.0.0

README.md

Lines changed: 118 additions & 3 deletions
@@ -20,6 +20,7 @@ The library is a part of the ecosystem:
2020
- [Logistic regression](#logistic-regression)
2121
- [Linear regression](#linear-regression)
2222
- [Decision tree-based classification](#decision-tree-based-classification)
23+
- [KDTree-based data retrieval](#kdtree-based-data-retrieval)
2324
- [Models retraining](#models-retraining)
2425
- [Notes on gradient-based optimisation algorithms](#a-couple-of-words-about-linear-models-which-use-gradient-optimisation-methods)
2526

@@ -31,7 +32,7 @@ The main purpose of the library is to give native Dart implementation of machine
3132
interested both in Dart language and data science. This library aims at Dart VM and Flutter, it's impossible to use
3233
it in web applications.
3334

34-
## The library's content
35+
## The library content
3536

3637
- #### Model selection
3738
- [CrossValidator](https://github.com/gyrdym/ml_algo/blob/master/lib/src/model_selection/cross_validator/cross_validator.dart).
@@ -70,7 +71,8 @@ it in web applications.
7071
training data. It may catch non-linear patterns of the data.
7172

7273
- #### Clustering and retrieval algorithms
73-
- [KDTree](https://github.com/gyrdym/ml_algo/blob/master/lib/src/retrieval/kd_tree/kd_tree.dart)
74+
- [KDTree](https://github.com/gyrdym/ml_algo/blob/master/lib/src/retrieval/kd_tree/kd_tree.dart) An algorithm for
75+
efficient data retrieval.
7476

7577
For more information on the library's API, please visit the [API reference](https://pub.dev/documentation/ml_algo/latest/ml_algo/ml_algo-library.html)
7678

@@ -580,7 +582,7 @@ void main() async {
580582
````
581583
</details>
582584

583-
## Decision tree-based classification
585+
### Decision tree-based classification
584586

585587
Let's try to classify data from a well-known [Iris](https://www.kaggle.com/datasets/uciml/iris) dataset using a non-linear algorithm - [decision trees](https://en.wikipedia.org/wiki/Decision_tree)
586588

@@ -649,6 +651,119 @@ resulting SVG image:
649651
<img height="600" src="https://raw.github.com/gyrdym/ml_algo/master/e2e/decision_tree_classifier/iris_tree.svg?sanitize=true">
650652
</p>
651653

654+
### KDTree-based data retrieval
655+
656+
Let's take a look at another field of machine learning: data retrieval. The field is represented by a family of algorithms,
657+
one of which, `KDTree`, is exposed by the library.
658+
659+
`KDTree` is an algorithm that divides the whole search space into partitions in the form of a binary tree, which makes
660+
data retrieval efficient.
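
To make the partitioning idea concrete, here is a minimal pure-Dart sketch of a kd-tree with a single nearest-neighbour search. It is an illustration under stated assumptions, not the `ml_algo` implementation: the names `Node`, `build`, `distance` and `nearest` are hypothetical, the split is a plain median split that cycles through the axes, and the metric is assumed to be Euclidean.

```dart
import 'dart:math' as math;

/// A node of a hypothetical, minimal kd-tree (not the ml_algo implementation).
class Node {
  Node(this.point, this.index, this.axis, this.left, this.right);

  final List<double> point;
  final int index; // position of the point in the original list
  final int axis; // dimension used to split the space at this node
  final Node? left;
  final Node? right;
}

/// Recursively splits [indices] at the median along one axis per tree level.
Node? build(List<List<double>> points, List<int> indices, [int depth = 0]) {
  if (indices.isEmpty) return null;

  final axis = depth % points.first.length;
  final sorted = [...indices]
    ..sort((a, b) => points[a][axis].compareTo(points[b][axis]));
  final mid = sorted.length ~/ 2;

  return Node(
    points[sorted[mid]],
    sorted[mid],
    axis,
    build(points, sorted.sublist(0, mid), depth + 1),
    build(points, sorted.sublist(mid + 1), depth + 1),
  );
}

/// Euclidean distance between two points of equal dimensionality.
double distance(List<double> a, List<double> b) {
  var sum = 0.0;
  for (var i = 0; i < a.length; i++) {
    sum += (a[i] - b[i]) * (a[i] - b[i]);
  }
  return math.sqrt(sum);
}

/// Returns the original index of the point closest to [target].
int nearest(Node root, List<double> target) {
  var bestIndex = root.index;
  var bestDistance = distance(root.point, target);

  void visit(Node? node) {
    if (node == null) return;

    final d = distance(node.point, target);
    if (d < bestDistance) {
      bestDistance = d;
      bestIndex = node.index;
    }

    final delta = target[node.axis] - node.point[node.axis];
    final near = delta < 0 ? node.left : node.right;
    final far = delta < 0 ? node.right : node.left;

    visit(near);
    // The far partition can only hold a closer point if the splitting
    // plane is nearer to the target than the best distance found so far.
    if (delta.abs() < bestDistance) visit(far);
  }

  visit(root);
  return bestIndex;
}

void main() {
  final points = [
    [2.0, 3.0],
    [5.0, 4.0],
    [9.0, 6.0],
    [4.0, 7.0],
    [8.0, 1.0],
    [7.0, 2.0],
  ];
  final root = build(points, List.generate(points.length, (i) => i))!;

  print(nearest(root, [9.0, 2.0])); // prints 4, i.e. the point [8.0, 1.0]
}
```

The efficiency comes from the pruning step: whole partitions are skipped whenever the splitting plane lies farther from the target than the best candidate found so far.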
661+
662+
Let's retrieve some data points through a kd-tree built on the [Iris](https://www.kaggle.com/datasets/uciml/iris) dataset.
663+
664+
First, we need to prepare the data by loading the dataset. For this purpose, we may use the
665+
[loadIrisDataset](https://pub.dev/documentation/ml_dataframe/latest/ml_dataframe/loadIrisDataset.html) function from [ml_dataframe](https://pub.dev/packages/ml_dataframe). The function returns a `DataFrame` instance prefilled with the Iris data:
666+
667+
```dart
668+
import 'package:ml_algo/ml_algo.dart';
669+
import 'package:ml_dataframe/ml_dataframe.dart';
670+
671+
void main() async {
672+
  final originalData = await loadIrisDataset();
673+
}
674+
```
675+
676+
Since the dataset contains an `Id` column that is irrelevant to the search and a `Species` column that contains text data, we need to
677+
drop these columns:
678+
679+
```dart
680+
import 'package:ml_algo/ml_algo.dart';
681+
import 'package:ml_dataframe/ml_dataframe.dart';
682+
683+
void main() async {
684+
  final originalData = await loadIrisDataset();
685+
  final data = originalData.dropSeries(names: ['Id', 'Species']);
686+
}
687+
```
688+
689+
Next, we can build the tree:
690+
691+
```dart
692+
import 'package:ml_algo/ml_algo.dart';
693+
import 'package:ml_dataframe/ml_dataframe.dart';
694+
695+
void main() async {
696+
  final originalData = await loadIrisDataset();
697+
  final data = originalData.dropSeries(names: ['Id', 'Species']);
698+
  final tree = KDTree(data);
699+
}
700+
```
701+
702+
Now we can query the nearest neighbours of an arbitrary point. Let's say we want to find the 5 nearest neighbours of the point `[6.5, 3.01, 4.5, 1.5]`:
703+
704+
```dart
705+
import 'package:ml_algo/ml_algo.dart';
706+
import 'package:ml_dataframe/ml_dataframe.dart';
707+
import 'package:ml_linalg/vector.dart';
708+
709+
void main() async {
710+
  final originalData = await loadIrisDataset();
711+
  final data = originalData.dropSeries(names: ['Id', 'Species']);
712+
  final tree = KDTree(data);
713+
  final neighbourCount = 5;
714+
  final point = Vector.fromList([6.5, 3.01, 4.5, 1.5]);
715+
  final neighbours = tree.query(point, neighbourCount);
716+
717+
  print(neighbours);
718+
}
719+
```
720+
721+
The last instruction prints the following:
722+
723+
```
724+
((Index: 75, Distance: 0.17349341930302867), (Index: 51, Distance: 0.21470911402365767), (Index: 65, Distance: 0.26095956499211426), (Index: 86, Distance: 0.29681616124778537), (Index: 56, Distance: 0.4172527193942372))
725+
```
726+
727+
The nearest point has index 75 in the original data. Let's check the record at that index:
728+
729+
```dart
730+
import 'package:ml_dataframe/ml_dataframe.dart';
731+
732+
void main() async {
733+
  final originalData = await loadIrisDataset();
734+
735+
  print(originalData.rows.elementAt(75));
736+
}
737+
```
738+
739+
It prints the following:
740+
741+
```
742+
(76, 6.6, 3.0, 4.4, 1.4, Iris-versicolor)
743+
```
744+
745+
Remember, we dropped the `Id` and `Species` columns, which are the very first and the very last elements in the output. The
746+
remaining elements, `6.6, 3.0, 4.4, 1.4`, look quite similar to our target point, `6.5, 3.01, 4.5, 1.5`, so the query result makes
747+
sense.
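
The similarity can also be checked numerically. The sketch below assumes the reported `Distance` is the Euclidean distance (the README does not name the metric, but the numbers are consistent with it) and uses a hypothetical helper, `euclideanDistance`, written in plain Dart without any `ml_algo` dependency:

```dart
import 'dart:math' as math;

/// Euclidean distance between two points of equal dimensionality.
/// A hypothetical helper for illustration, not part of ml_algo.
double euclideanDistance(List<double> a, List<double> b) {
  var sum = 0.0;
  for (var i = 0; i < a.length; i++) {
    final diff = a[i] - b[i];
    sum += diff * diff;
  }
  return math.sqrt(sum);
}

void main() {
  final target = [6.5, 3.01, 4.5, 1.5];
  // The record at index 75 without the `Id` and `Species` values.
  final record = [6.6, 3.0, 4.4, 1.4];

  // Squared differences: 0.01 + 0.0001 + 0.01 + 0.01 = 0.0301
  print(euclideanDistance(target, record)); // approximately 0.1735
}
```

The square root of 0.0301 is roughly 0.1735, which agrees with the distance reported for index 75 to about six decimal places.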
748+
749+
If you want to use `KDTree` outside the ml_algo ecosystem, meaning you don't want to use the [ml_linalg](https://pub.dev/packages/ml_linalg) and [ml_dataframe](https://pub.dev/packages/ml_dataframe)
750+
packages in your application, you may import only the `kd_tree` library and use the `fromIterable` constructor and the `queryIterable`
751+
method to perform the query:
752+
753+
```dart
754+
import 'package:ml_algo/kd_tree.dart';
755+
756+
void main() async {
757+
  final tree = KDTree.fromIterable([
758+
    // some data here
759+
  ]);
760+
  final neighbourCount = 5;
761+
  final neighbours = tree.queryIterable([/* some point here */], neighbourCount);
762+
763+
  print(neighbours);
764+
}
765+
```
766+
652767
## Models retraining
653768

654769
Someday our previously shining model can degrade in terms of prediction accuracy - in this case, we can retrain it.

e2e/kd_tree/kd_tree_test.dart

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
1+
import 'package:ml_algo/ml_algo.dart';
2+
import 'package:ml_dataframe/ml_dataframe.dart';
3+
import 'package:ml_linalg/dtype.dart';
4+
import 'package:ml_linalg/vector.dart';
5+
import 'package:test/test.dart';
6+
7+
void main() async {
8+
  group('KDTree', () {
9+
    test('should return correct list of neighbours, dtype=DType.float32',
10+
        () async {
11+
      final originalData = await loadIrisDataset();
12+
      final data = originalData.dropSeries(names: ['Id', 'Species']);
13+
      final tree = KDTree(data);
14+
      final neighbours = tree.query(Vector.fromList([6.5, 3.01, 4.5, 1.5]), 5);
15+
16+
      expect(neighbours, hasLength(5));
17+
      expect(neighbours.toString(),
18+
          '((Index: 75, Distance: 0.17349341930302867), (Index: 51, Distance: 0.21470911402365767), (Index: 65, Distance: 0.26095956499211426), (Index: 86, Distance: 0.29681616124778537), (Index: 56, Distance: 0.4172527193942372))');
19+
    });
20+
21+
    test('should return correct list of neighbours, dtype=DType.float64',
22+
        () async {
23+
      final originalData = await loadIrisDataset();
24+
      final data = originalData.dropSeries(names: ['Id', 'Species']);
25+
      final tree = KDTree(data, dtype: DType.float64);
26+
      final neighbours = tree.query(
27+
          Vector.fromList([6.5, 3.01, 4.5, 1.5], dtype: DType.float64), 5);
28+
29+
      expect(neighbours, hasLength(5));
30+
      expect(neighbours.toString(),
31+
          '((Index: 75, Distance: 0.17349351572897434), (Index: 51, Distance: 0.21470910553583905), (Index: 65, Distance: 0.2609597670139979), (Index: 86, Distance: 0.29681644159311693), (Index: 56, Distance: 0.41725292090050153))');
32+
    });
33+
  });
34+
}
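
The two tests above expect the same neighbour indices but distances that differ from roughly the seventh significant digit on. A plausible explanation is the precision of the stored values. The sketch below is a rough illustration in plain Dart, assuming a Euclidean metric; it only rounds the inputs through `Float32List` from `dart:typed_data` and does not reproduce the arithmetic that `ml_linalg` performs, so its float32 result is not expected to match the test string digit for digit. The helpers `toFloat32` and `euclideanDistance` are hypothetical.

```dart
import 'dart:math' as math;
import 'dart:typed_data';

/// Rounds every value to the nearest 32-bit float, then widens it back
/// to a double. A hypothetical helper for illustration only.
List<double> toFloat32(List<double> values) =>
    Float32List.fromList(values).toList();

/// Euclidean distance between two points of equal dimensionality.
double euclideanDistance(List<double> a, List<double> b) {
  var sum = 0.0;
  for (var i = 0; i < a.length; i++) {
    final diff = a[i] - b[i];
    sum += diff * diff;
  }
  return math.sqrt(sum);
}

void main() {
  final target = [6.5, 3.01, 4.5, 1.5];
  final record = [6.6, 3.0, 4.4, 1.4]; // the Iris record at index 75

  final asFloat64 = euclideanDistance(target, record);
  final asFloat32 = euclideanDistance(toFloat32(target), toFloat32(record));

  print(asFloat64);
  print(asFloat32);
  print((asFloat64 - asFloat32).abs()); // small but non-zero
}
```

Values such as `3.01`, `6.6`, `4.4` and `1.4` have no exact 32-bit representation, so the rounded inputs shift the distance slightly, while the ordering of the neighbours in the two tests stays the same.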

lib/kd_tree.dart

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
1+
export 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';

pubspec.yaml

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ dependencies:
1010
collection: ^1.16.0
1111
injector: ^2.0.0
1212
json_annotation: ^4.0.0
13-
ml_dataframe: ^1.0.0
13+
ml_dataframe: ^1.4.2
1414
ml_linalg: ^13.7.0
1515
ml_preprocessing: ^7.0.2
1616
quiver: ^3.0.0
