- [KDTree](https://github.com/gyrdym/ml_algo/blob/master/lib/src/retrieval/kd_tree/kd_tree.dart) An algorithm for efficient data retrieval.

For more information on the library's API, please visit the [API reference](https://pub.dev/documentation/ml_algo/latest/ml_algo/ml_algo-library.html)
````

</details>

### Decision tree-based classification
Let's try to classify data from the well-known [Iris](https://www.kaggle.com/datasets/uciml/iris) dataset using a non-linear algorithm - [decision trees](https://en.wikipedia.org/wiki/Decision_tree).
### KDTree-based data retrieval

Let's take a look at another field of machine learning - data retrieval. This field is represented by a family of algorithms, one of which, `KDTree`, is exposed by the library.

`KDTree` is an algorithm that partitions the whole search space into a binary tree, which makes data retrieval efficient.
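To make the partitioning idea concrete, here is a minimal, self-contained sketch of how a k-d tree is typically built: points are split at the median along one coordinate axis, with the axis alternating at each level of depth. This is only an illustration of the general technique, not ml_algo's actual `KDTree` implementation, whose internals may differ:

```dart
// Toy k-d partitioning sketch - NOT ml_algo's implementation.
class Node {
  final List<double> point;
  final Node? left;
  final Node? right;

  Node(this.point, this.left, this.right);
}

Node? build(List<List<double>> points, int depth) {
  if (points.isEmpty) return null;

  // Cycle through the axes as we descend the tree.
  final axis = depth % points.first.length;

  // Split at the median along the current axis.
  points.sort((a, b) => a[axis].compareTo(b[axis]));
  final median = points.length ~/ 2;

  return Node(
    points[median],
    build(points.sublist(0, median), depth + 1),
    build(points.sublist(median + 1), depth + 1),
  );
}

void main() {
  final root = build([
    [6.7, 3.1], [5.1, 3.5], [6.3, 3.3], [5.8, 2.7], [6.5, 3.0],
  ], 0);

  print(root?.point); // prints [6.3, 3.3] - the median along the first axis
}
```

Because every split halves the remaining points, a nearest-neighbour query can discard entire subtrees whose bounding region lies farther away than the best match found so far, which is what makes the retrieval efficient.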
Let's retrieve some data points through a kd-tree built on the [Iris](https://www.kaggle.com/datasets/uciml/iris) dataset.

First, we need to prepare the data by loading the dataset. For this purpose, we may use the [loadIrisDataset](https://pub.dev/documentation/ml_dataframe/latest/ml_dataframe/loadIrisDataset.html) function from [ml_dataframe](https://pub.dev/packages/ml_dataframe). The function returns a DataFrame instance prefilled with the Iris data:
```dart
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';

void main() async {
  final originalData = await loadIrisDataset();
}
```
Since the dataset contains an `Id` column, which carries no useful information, and a `Species` column, which contains text data, we need to drop these columns:
```dart
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';

void main() async {
  final originalData = await loadIrisDataset();
  final data = originalData.dropSeries(names: ['Id', 'Species']);
}
```

Next, we can build the tree:

```dart
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';

void main() async {
  final originalData = await loadIrisDataset();
  final data = originalData.dropSeries(names: ['Id', 'Species']);
  final tree = KDTree(data);
}
```
And query the nearest neighbours for an arbitrary point. Let's say we want to find the 5 nearest neighbours of the point `[6.5, 3.01, 4.5, 1.5]`:
```dart
import 'package:ml_algo/ml_algo.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/vector.dart';

void main() async {
  final originalData = await loadIrisDataset();
  final data = originalData.dropSeries(names: ['Id', 'Species']);
  final tree = KDTree(data);
  final neighbourCount = 5;
  final point = Vector.fromList([6.5, 3.01, 4.5, 1.5]);
  final neighbours = tree.query(point, neighbourCount);

  print(neighbours);
}
```
The nearest point has index 75 in the original data. Let's check the record at that index:
```dart
import 'package:ml_dataframe/ml_dataframe.dart';

void main() async {
  final originalData = await loadIrisDataset();

  print(originalData.rows.elementAt(75));
}
```
It prints the following:

```
(76, 6.6, 3.0, 4.4, 1.4, Iris-versicolor)
```
Remember, we dropped the `Id` and `Species` columns, which are the very first and the very last elements in the output. The remaining elements, `6.6, 3.0, 4.4, 1.4`, look quite similar to our target point, `6.5, 3.01, 4.5, 1.5`, so the query result makes sense.
If you want to use `KDTree` outside the ml_algo ecosystem, meaning you don't want to use the [ml_linalg](https://pub.dev/packages/ml_linalg) and [ml_dataframe](https://pub.dev/packages/ml_dataframe) packages in your application, you may import only the `KDTree` library and use the `fromIterable` constructor and the `queryIterable` method to perform the query:
```dart
import 'package:ml_algo/kd_tree.dart';

void main() async {
  final tree = KDTree.fromIterable([
    // some data here
  ]);
  final neighbourCount = 5;
  final neighbours = tree.queryIterable([/* some point here */], neighbourCount);

  print(neighbours);
}
```
## Models retraining
Someday, our previously shining model may degrade in terms of prediction accuracy; in that case, we can retrain it.