Skip to content

Commit e2570e7

Browse files
authored
KNN models: serialization/deserialization added (#166)
1 parent 212918d commit e2570e7

40 files changed

+905
-170
lines changed

.github/workflows/ci_pipeline.yml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,3 @@ jobs:
3030

3131
- name: Run e2e tests
3232
run: dart pub run test e2e -p vm
33-
34-
- name: Code coverage
35-
run: dart pub run test_coverage
36-
37-
- name: Coveralls GitHub Action
38-
uses: coverallsapp/github-action@v1.1.2
39-
with:
40-
github-token: ${{ secrets.GITHUB_TOKEN }}

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@ doc/api/
99
/.dart_tool/
1010

1111
pubspec.lock
12+
13+
test/.test_coverage.dart

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Changelog
22

3+
## 15.4.0
4+
- `KnnClassifier`:
5+
- serialization/deserialization functionality added with possibility to save the model into a json file
6+
- `KnnRegressor`:
7+
- serialization/deserialization functionality added with possibility to save the model into a json file
8+
39
## 15.3.6
410
- `ml_dataframe`: version 0.3.0 supported
511
- `README.md`: build badge corrected

lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ abstract class DecisionTreeClassifier implements
9393
/// // here you can use previously fitted restored classifier to make
9494
/// // some prediction, e.g. via `DecisionTreeClassifier.predict(...)`;
9595
/// ````
96-
factory DecisionTreeClassifier.fromJson(String json) =>
97-
createDecisionTreeClassifierFromJson(json);
96+
factory DecisionTreeClassifier.fromJson(String json) {
97+
initDecisionTreeModule();
98+
99+
return createDecisionTreeClassifierFromJson(json);
100+
}
98101
}

lib/src/classifier/knn_classifier/_helpers/create_knn_classifier.dart

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ KnnClassifier createKnnClassifier(
1616
int k,
1717
KernelType kernelType,
1818
Distance distance,
19+
String columnPrefix,
1920
DType dtype,
2021
) {
2122
validateTrainData(trainData, [targetName]);
@@ -25,8 +26,10 @@ KnnClassifier createKnnClassifier(
2526
).toList();
2627
final featuresSplit = splits[0];
2728
final targetSplit = splits[1];
28-
final trainFeatures = featuresSplit.toMatrix(dtype);
29-
final trainLabels = targetSplit.toMatrix(dtype);
29+
final trainFeatures = featuresSplit
30+
.toMatrix(dtype);
31+
final trainLabels = targetSplit
32+
.toMatrix(dtype);
3033
final classLabels = targetSplit[targetName].isDiscrete
3134
? targetSplit[targetName]
3235
.discreteValues
@@ -37,27 +40,27 @@ KnnClassifier createKnnClassifier(
3740
.getColumn(0)
3841
.unique()
3942
.toList(growable: false);
40-
final kernelFactory = knnClassifierInjector
41-
.get<KernelFactory>();
42-
final kernel = kernelFactory
43+
final kernel = knnClassifierInjector
44+
.get<KernelFactory>()
4345
.createByType(kernelType);
44-
final solverFactory = knnClassifierInjector
45-
.get<KnnSolverFactory>();
46-
final solver = solverFactory.create(
47-
trainFeatures,
48-
trainLabels,
49-
k,
50-
distance,
51-
true,
52-
);
53-
final knnClassifierFactory = knnClassifierInjector
54-
.get<KnnClassifierFactory>();
46+
final solver = knnClassifierInjector
47+
.get<KnnSolverFactory>()
48+
.create(
49+
trainFeatures,
50+
trainLabels,
51+
k,
52+
distance,
53+
true,
54+
);
5555

56-
return knnClassifierFactory.create(
57-
targetName,
58-
classLabels,
59-
kernel,
60-
solver,
61-
dtype,
62-
);
56+
return knnClassifierInjector
57+
.get<KnnClassifierFactory>()
58+
.create(
59+
targetName,
60+
classLabels,
61+
kernel,
62+
solver,
63+
columnPrefix,
64+
dtype,
65+
);
6366
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import 'dart:convert';
2+
3+
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart';
4+
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_impl.dart';
5+
6+
KnnClassifier createKnnClassifierFromJson(String json) {
7+
if (json.isEmpty) {
8+
throw Exception('Provided JSON object is empty, please provide a proper '
9+
'JSON object');
10+
}
11+
12+
final decodedJson = jsonDecode(json) as Map<String, dynamic>;
13+
14+
return KnnClassifierImpl.fromJson(decodedJson);
15+
}

lib/src/classifier/knn_classifier/knn_classifier.dart

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import 'package:ml_algo/src/classifier/classifier.dart';
22
import 'package:ml_algo/src/classifier/knn_classifier/_helpers/create_knn_classifier.dart';
3+
import 'package:ml_algo/src/classifier/knn_classifier/_helpers/create_knn_classifier_from_json.dart';
34
import 'package:ml_algo/src/classifier/knn_classifier/_init_module.dart';
5+
import 'package:ml_algo/src/common/serializable/serializable.dart';
46
import 'package:ml_algo/src/knn_kernel/kernel_type.dart';
57
import 'package:ml_algo/src/model_selection/assessable.dart';
68
import 'package:ml_dataframe/ml_dataframe.dart';
@@ -19,7 +21,7 @@ import 'package:ml_linalg/dtype.dart';
1921
/// imprecise result. Thus the weighted version of KNN algorithm is used in the
2022
/// classifier. To get weight of a particular observation one may use a kernel
2123
/// function.
22-
abstract class KnnClassifier implements Assessable, Classifier {
24+
abstract class KnnClassifier implements Assessable, Classifier, Serializable {
2325
/// Parameters:
2426
///
2527
/// [trainData] Labelled observations. Must contain [targetName] column.
@@ -45,6 +47,7 @@ abstract class KnnClassifier implements Assessable, Classifier {
4547
{
4648
KernelType kernel = KernelType.gaussian,
4749
Distance distance = Distance.euclidean,
50+
String classLabelPrefix = 'Class label',
4851
DType dtype = DType.float32,
4952
}
5053
) {
@@ -56,7 +59,47 @@ abstract class KnnClassifier implements Assessable, Classifier {
5659
k,
5760
kernel,
5861
distance,
62+
classLabelPrefix,
5963
dtype,
6064
);
6165
}
66+
67+
/// Restores previously fitted classifier instance from the given [json]
68+
///
69+
/// ````dart
70+
/// import 'dart:io';
71+
/// import 'package:ml_dataframe/ml_dataframe.dart';
72+
///
73+
/// final data = <Iterable>[
74+
/// ['feature 1', 'feature 2', 'feature 3', 'outcome']
75+
/// [ 5.0, 7.0, 6.0, 1.0],
76+
/// [ 1.0, 2.0, 3.0, 0.0],
77+
/// [ 10.0, 12.0, 31.0, 0.0],
78+
/// [ 9.0, 8.0, 5.0, 0.0],
79+
/// [ 4.0, 0.0, 1.0, 1.0],
80+
/// ];
81+
/// final targetName = 'outcome';
82+
/// final samples = DataFrame(data, headerExists: true);
83+
/// final classifier = KnnClassifier(
84+
/// samples,
85+
/// targetName,
86+
/// 3,
87+
/// );
88+
///
89+
/// final pathToFile = './classifier.json';
90+
///
91+
/// await classifier.saveAsJson(pathToFile);
92+
///
93+
/// final file = File(pathToFile);
94+
/// final json = await file.readAsString();
95+
/// final restoredClassifier = KnnClassifier.fromJson(json);
96+
///
97+
/// // here you can use previously fitted restored classifier to make
98+
/// // some prediction, e.g. via `KnnClassifier.predict(...)`;
99+
/// ````
100+
factory KnnClassifier.fromJson(String json) {
101+
initKnnClassifierModule();
102+
103+
return createKnnClassifierFromJson(json);
104+
}
62105
}

lib/src/classifier/knn_classifier/knn_classifier_factory.dart

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ abstract class KnnClassifierFactory {
99
List<num> classLabels,
1010
Kernel kernel,
1111
KnnSolver solver,
12+
String columnPrefix,
1213
DType dtype,
1314
);
1415
}

lib/src/classifier/knn_classifier/knn_classifier_factory_impl.dart

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ class KnnClassifierFactoryImpl implements KnnClassifierFactory {
1414
List<num> classLabels,
1515
Kernel kernel,
1616
KnnSolver solver,
17+
String columnPrefix,
1718
DType dtype,
18-
) => KnnClassifierImpl(targetName, classLabels, kernel, solver, dtype);
19+
) => KnnClassifierImpl(
20+
targetName,
21+
classLabels,
22+
kernel,
23+
solver,
24+
columnPrefix,
25+
dtype,
26+
);
1927
}

lib/src/classifier/knn_classifier/knn_classifier_impl.dart

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,75 @@
1+
import 'package:json_annotation/json_annotation.dart';
12
import 'package:ml_algo/src/classifier/_mixins/assessable_classifier_mixin.dart';
23
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart';
4+
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_json_keys.dart';
5+
import 'package:ml_algo/src/common/serializable/serializable_mixin.dart';
36
import 'package:ml_algo/src/helpers/validate_class_label_list.dart';
47
import 'package:ml_algo/src/helpers/validate_test_features.dart';
58
import 'package:ml_algo/src/knn_kernel/kernel.dart';
9+
import 'package:ml_algo/src/knn_kernel/kernel_json_converter.dart';
610
import 'package:ml_algo/src/knn_solver/knn_solver.dart';
11+
import 'package:ml_algo/src/knn_solver/knn_solver_json_converter.dart';
712
import 'package:ml_algo/src/knn_solver/neigbour.dart';
813
import 'package:ml_dataframe/ml_dataframe.dart';
914
import 'package:ml_linalg/dtype.dart';
15+
import 'package:ml_linalg/dtype_to_json.dart';
16+
import 'package:ml_linalg/from_dtype_json.dart';
1017
import 'package:ml_linalg/matrix.dart';
1118
import 'package:ml_linalg/vector.dart';
1219

20+
part 'knn_classifier_impl.g.dart';
21+
22+
@JsonSerializable()
23+
@KnnSolverJsonConverter()
24+
@KernelJsonConverter()
1325
class KnnClassifierImpl
1426
with
15-
AssessableClassifierMixin
27+
AssessableClassifierMixin,
28+
SerializableMixin
1629
implements
1730
KnnClassifier {
1831
KnnClassifierImpl(
19-
this._targetColumnName,
20-
this._classLabels,
21-
this._kernel,
22-
this._solver,
32+
this.targetColumnName,
33+
this.classLabels,
34+
this.kernel,
35+
this.solver,
36+
this.classLabelPrefix,
2337
this.dtype,
2438
) {
25-
validateClassLabelList(_classLabels);
39+
validateClassLabelList(classLabels);
2640
}
2741

28-
final String _targetColumnName;
42+
factory KnnClassifierImpl.fromJson(Map<String, dynamic> json) =>
43+
_$KnnClassifierImplFromJson(json);
44+
45+
@override
46+
Map<String, dynamic> toJson() => _$KnnClassifierImplToJson(this);
47+
48+
@JsonKey(name: knnClassifierTargetColumnNameJsonKey)
49+
final String targetColumnName;
2950

3051
@override
52+
@JsonKey(
53+
name: knnClassifierDTypeJsonKey,
54+
toJson: dTypeToJson,
55+
fromJson: fromDTypeJson,
56+
)
3157
final DType dtype;
3258

3359
@override
34-
Iterable<String> get targetNames => [_targetColumnName];
60+
Iterable<String> get targetNames => [targetColumnName];
3561

36-
final List<num> _classLabels;
37-
final Kernel _kernel;
38-
final KnnSolver _solver;
39-
final String _columnPrefix = 'Class label';
62+
@JsonKey(name: knnClassifierClassLabelsJsonKey)
63+
final List<num> classLabels;
64+
65+
@JsonKey(name: knnClassifierKernelJsonKey)
66+
final Kernel kernel;
67+
68+
@JsonKey(name: knnClassifierSolverJsonKey)
69+
final KnnSolver solver;
70+
71+
@JsonKey(name: knnClassifierClassLabelPrefixJsonKey)
72+
final String classLabelPrefix;
4073

4174
@override
4275
final num positiveLabel = null;
@@ -83,10 +116,12 @@ class KnnClassifierImpl
83116
DataFrame predictProbabilities(DataFrame features) {
84117
final labelsToProbabilities = _getLabelToProbabilityMapping(features);
85118
final probabilityMatrix = _getProbabilityMatrix(labelsToProbabilities);
86-
87119
final header = labelsToProbabilities
88120
.keys
89-
.map((label) => '${_columnPrefix} ${label.toString()}');
121+
.map((label) =>
122+
[classLabelPrefix.trim(), label.toString().trim()]
123+
.where((element) => element.isNotEmpty)
124+
.join(' '));
90125

91126
return DataFrame.fromMatrix(probabilityMatrix, header: header);
92127
}
@@ -114,8 +149,8 @@ class KnnClassifierImpl
114149
/// where each row is a classes probability distribution for the appropriate
115150
/// feature record from the test feature matrix
116151
Map<num, List<num>> _getLabelToProbabilityMapping(DataFrame features) {
117-
final kNeighbourGroups = _solver.findKNeighbours(features.toMatrix(dtype));
118-
final classLabelsAsSet = Set<num>.from(_classLabels);
152+
final kNeighbourGroups = solver.findKNeighbours(features.toMatrix(dtype));
153+
final classLabelsAsSet = Set<num>.from(classLabels);
119154

120155
return kNeighbourGroups.fold<Map<num, List<num>>>(
121156
{}, (allLabelsToProbabilities, kNeighbours) {
@@ -142,7 +177,7 @@ class KnnClassifierImpl
142177
// if labels are equiprobable, make the first neighbour's label
143178
// probability equal to 1 and probabilities of the rest neighbour labels -
144179
// equal to 0
145-
_classLabels.forEach((label) {
180+
classLabels.forEach((label) {
146181
final probability = areLabelsEquiprobable
147182
? label == kNeighbours.first.label.first
148183
? 1
@@ -174,7 +209,7 @@ class KnnClassifierImpl
174209
Map<num, num> labelToWeightMapping,
175210
Neighbour<Vector> neighbour,
176211
) {
177-
final weight = _kernel.getWeightByDistance(neighbour.distance);
212+
final weight = kernel.getWeightByDistance(neighbour.distance);
178213
return labelToWeightMapping
179214
..update(
180215
neighbour.label.first,

0 commit comments

Comments
 (0)