Commit a9c744f

Models retraining functionality added (#169)
1 parent b8126e7 commit a9c744f

87 files changed: +2378 -1321 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 # Changelog
 
+## 15.6.0
+- Models retraining functionality added
+
 ## 15.5.0
 - `KnnClassifier`, `DecisionTreeClassifier`, `LogisticRegressor`, `SoftmaxRegressor`, `KnnRegressor`, `LinearRegressor`
 - hyperparameters added to the interfaces

README.md

Lines changed: 30 additions & 2 deletions
@@ -210,8 +210,8 @@ After that we can simply read the model from the file and make predictions:
 import 'dart:io';
 
 final file = File(fileName);
-final encodedData = await file.readAsString();
-final classifier = LogisticRegressor.fromJson(encodedData);
+final encodedModel = await file.readAsString();
+final classifier = LogisticRegressor.fromJson(encodedModel);
 final unlabelledData = await fromCsv('some_unlabelled_data.csv');
 final prediction = classifier.predict(unlabelledData);
 
@@ -226,6 +226,15 @@ print(prediction.rows); // [
 // ]
 ```
 
+Please note that all the hyperparameters that we used to generate the model are persisted as the model's read-only
+fields, and we can access them at any time:
+
+```dart
+print(classifier.iterationsLimit);
+print(classifier.probabilityThreshold);
+// and so on
+```
+
 All the code above all together:
 
 ````dart
@@ -265,6 +274,25 @@ void main() async {
 }
 ````
 
+Someday our previously shining model may degrade in prediction accuracy; in that case we can retrain it.
+Retraining simply means re-running the same learning algorithm that was used to generate our current model,
+keeping the same hyperparameters but using a new data set with the same features:
+
+```dart
+import 'dart:io';
+
+final encodedModel = await file.readAsString();
+final classifier = LogisticRegressor.fromJson(encodedModel);
+
+// ...
+// here we do something and realize that our classifier performance is not so good
+// ...
+
+final newData = await fromCsv('path/to/dataset/with/new/data/to/retrain/the/classifier');
+final retrainedClassifier = classifier.retrain(newData);
+
+```
+
 The workflow with other predictors (SoftmaxRegressor, DecisionTreeClassifier and so on) is quite similar to the described
 above for LogisticRegressor, feel free to experiment with other models.
 
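
The retraining contract described in the README hunk above can be illustrated with a toy model that does not depend on ml_algo. This is only a minimal sketch: `Retrainable<T>`, `ShrunkMeanModel` and the `shrinkage` hyperparameter are hypothetical names invented for the illustration, and the "learning algorithm" is a deliberately trivial shrunk mean. The point it shows is that `retrain` re-runs the same fitting routine on new data while reusing the hyperparameters stored on the existing model.

```dart
// Hypothetical stand-in for a retraining contract; this is NOT ml_algo code.
abstract class Retrainable<T> {
  /// Re-runs the original learning algorithm on [data].
  T retrain(List<double> data);
}

/// Toy model: a mean pulled towards zero by the `shrinkage` hyperparameter.
class ShrunkMeanModel implements Retrainable<ShrunkMeanModel> {
  ShrunkMeanModel._(this.shrinkage, this.mean);

  /// Fits the model on [data]; [shrinkage] is the only hyperparameter.
  factory ShrunkMeanModel.fit(List<double> data, {double shrinkage = 1.0}) {
    if (data.isEmpty) {
      throw ArgumentError('Training data must not be empty');
    }
    final sum = data.reduce((a, b) => a + b);
    return ShrunkMeanModel._(shrinkage, sum / (data.length + shrinkage));
  }

  /// Hyperparameter, kept as a read-only field so it can be reused later.
  final double shrinkage;

  /// The learned parameter.
  final double mean;

  /// Same algorithm, same hyperparameters, new data; returns a new instance.
  @override
  ShrunkMeanModel retrain(List<double> data) =>
      ShrunkMeanModel.fit(data, shrinkage: shrinkage);
}

void main() {
  final model = ShrunkMeanModel.fit([1.0, 2.0, 3.0], shrinkage: 1.0);
  final retrained = model.retrain([10.0, 20.0, 30.0]);

  print(model.mean);
  print(retrained.mean);
  print(retrained.shrinkage == model.shrinkage);
}
```

Returning a fresh instance instead of mutating the existing one mirrors the `final retrainedClassifier = classifier.retrain(newData);` usage shown in the README hunk: the original model stays usable while the retrained one is evaluated.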

lib/src/classifier/decision_tree_classifier/_helper/create_decision_tree_classifier.dart

Lines changed: 0 additions & 39 deletions
This file was deleted.

lib/src/classifier/decision_tree_classifier/_helper/create_decision_tree_classifier_from_json.dart

Lines changed: 0 additions & 15 deletions
This file was deleted.

lib/src/classifier/decision_tree_classifier/_init_module.dart

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
+import 'package:injector/injector.dart';
 import 'package:ml_algo/src/classifier/decision_tree_classifier/_injector.dart';
 import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart';
 import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory_impl.dart';
@@ -22,10 +23,10 @@ import 'package:ml_algo/src/tree_trainer/splitter/splitter_factory_impl.dart';
 import 'package:ml_algo/src/tree_trainer/tree_trainer_factory.dart';
 import 'package:ml_algo/src/tree_trainer/tree_trainer_factory_impl.dart';
 
-void initDecisionTreeModule() {
+Injector initDecisionTreeModule() {
   initCommonModule();
 
-  decisionTreeInjector
+  return decisionTreeInjector
     ..registerSingletonIf<DistributionCalculatorFactory>(
         () => const DistributionCalculatorFactoryImpl())
 

lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart

Lines changed: 18 additions & 15 deletions
@@ -1,9 +1,9 @@
 import 'package:ml_algo/src/classifier/classifier.dart';
-import 'package:ml_algo/src/classifier/decision_tree_classifier/_helper/create_decision_tree_classifier.dart';
-import 'package:ml_algo/src/classifier/decision_tree_classifier/_helper/create_decision_tree_classifier_from_json.dart';
 import 'package:ml_algo/src/classifier/decision_tree_classifier/_init_module.dart';
+import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart';
 import 'package:ml_algo/src/common/serializable/serializable.dart';
 import 'package:ml_algo/src/model_selection/assessable.dart';
+import 'package:ml_algo/src/predictor/retrainable.dart';
 import 'package:ml_dataframe/ml_dataframe.dart';
 import 'package:ml_linalg/dtype.dart';
 
@@ -16,8 +16,13 @@ import 'package:ml_linalg/dtype.dart';
 /// decision tree learning. Once a decision tree learned, it may be used to
 /// classify new samples with the same features that were used to learn the
 /// tree.
-abstract class DecisionTreeClassifier implements
-    Classifier, Assessable, Serializable {
+abstract class DecisionTreeClassifier
+    implements
+        Assessable,
+        Serializable,
+        Retrainable,
+        Classifier {
+
   /// Parameters:
   ///
   /// [trainData] A [DataFrame] with observations that will be used by the
@@ -45,18 +50,17 @@ abstract class DecisionTreeClassifier implements
     int minSamplesCount,
     int maxDepth,
     DType dtype = DType.float32,
-  }) {
-    initDecisionTreeModule();
-
-    return createDecisionTreeClassifier(
+  }) =>
+      initDecisionTreeModule()
+          .get<DecisionTreeClassifierFactory>()
+          .create(
       trainData,
-      targetName,
       minError,
       minSamplesCount,
       maxDepth,
+      targetName,
       dtype,
     );
-  }
 
   /// Restores previously fitted classifier instance from the given [json]
   ///
@@ -93,11 +97,10 @@ abstract class DecisionTreeClassifier implements
   /// // here you can use previously fitted restored classifier to make
   /// // some prediction, e.g. via `DecisionTreeClassifier.predict(...)`;
   /// ````
-  factory DecisionTreeClassifier.fromJson(String json) {
-    initDecisionTreeModule();
-
-    return createDecisionTreeClassifierFromJson(json);
-  }
+  factory DecisionTreeClassifier.fromJson(String json) =>
+      initDecisionTreeModule()
+          .get<DecisionTreeClassifierFactory>()
+          .fromJson(json);
 
   /// A minimal error on a single decision tree node. It is used as a
   /// stop criteria to avoid farther decision's tree node splitting: if the
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+const decisionTreeClassifierJsonSchemaVersion = 1;
Lines changed: 4 additions & 2 deletions
@@ -1,14 +1,16 @@
 import 'package:ml_algo/ml_algo.dart';
-import 'package:ml_algo/src/tree_trainer/tree_node/tree_node.dart';
+import 'package:ml_dataframe/ml_dataframe.dart';
 import 'package:ml_linalg/dtype.dart';
 
 abstract class DecisionTreeClassifierFactory {
   DecisionTreeClassifier create(
+    DataFrame trainData,
     num minError,
     int minSamplesCount,
     int maxDepth,
-    TreeNode root,
     String targetName,
     DType dtype,
   );
+
+  DecisionTreeClassifier fromJson(String json);
 }
Lines changed: 47 additions & 10 deletions
@@ -1,27 +1,64 @@
+import 'dart:convert';
+
 import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart';
 import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart';
 import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_impl.dart';
-import 'package:ml_algo/src/tree_trainer/tree_node/tree_node.dart';
+import 'package:ml_algo/src/helpers/validate_train_data.dart';
+import 'package:ml_algo/src/helpers/validate_tree_solver_max_depth.dart';
+import 'package:ml_algo/src/helpers/validate_tree_solver_min_error.dart';
+import 'package:ml_algo/src/helpers/validate_tree_solver_min_samples_count.dart';
+import 'package:ml_algo/src/tree_trainer/_helpers/create_decision_tree_trainer.dart';
+import 'package:ml_dataframe/ml_dataframe.dart';
 import 'package:ml_linalg/dtype.dart';
 
 class DecisionTreeClassifierFactoryImpl implements
     DecisionTreeClassifierFactory {
+
   const DecisionTreeClassifierFactoryImpl();
 
   @override
   DecisionTreeClassifier create(
+    DataFrame trainData,
     num minError,
     int minSamplesCount,
     int maxDepth,
-    TreeNode root,
     String targetName,
     DType dtype,
-  ) => DecisionTreeClassifierImpl(
-    minError,
-    minSamplesCount,
-    maxDepth,
-    root,
-    targetName,
-    dtype,
-  );
+  ) {
+    validateTrainData(trainData, [targetName]);
+    validateTreeSolverMinError(minError);
+    validateTreeSolversMinSamplesCount(minSamplesCount);
+    validateTreeSolverMaxDepth(maxDepth);
+
+    final trainer = createDecisionTreeTrainer(
+      trainData,
+      targetName,
+      minError,
+      minSamplesCount,
+      maxDepth,
+    );
+    final treeRootNode = trainer
+        .train(trainData.toMatrix(dtype));
+
+    return DecisionTreeClassifierImpl(
+      minError,
+      minSamplesCount,
+      maxDepth,
+      treeRootNode,
+      targetName,
+      dtype,
+    );
+  }
+
+  @override
+  DecisionTreeClassifier fromJson(String json) {
+    if (json.isEmpty) {
+      throw Exception('Provided JSON object is empty, please provide a proper '
+          'JSON object');
+    }
+
+    final decodedJson = jsonDecode(json) as Map<String, dynamic>;
+
+    return DecisionTreeClassifierImpl.fromJson(decodedJson);
+  }
 }

lib/src/classifier/decision_tree_classifier/decision_tree_classifier_impl.dart

Lines changed: 32 additions & 0 deletions
@@ -1,7 +1,12 @@
 import 'package:json_annotation/json_annotation.dart';
 import 'package:ml_algo/src/classifier/_mixins/assessable_classifier_mixin.dart';
+import 'package:ml_algo/src/classifier/decision_tree_classifier/_injector.dart';
 import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart';
+import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_constants.dart';
+import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart';
 import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_json_keys.dart';
+import 'package:ml_algo/src/common/constants/common_json_keys.dart';
+import 'package:ml_algo/src/common/exception/outdated_json_schema_exception.dart';
 import 'package:ml_algo/src/common/json_converter/dtype_json_converter.dart';
 import 'package:ml_algo/src/common/serializable/serializable_mixin.dart';
 import 'package:ml_algo/src/tree_trainer/leaf_label/leaf_label.dart';
@@ -31,6 +36,9 @@ class DecisionTreeClassifierImpl
     this.treeRootNode,
     this.targetColumnName,
     this.dtype,
+    {
+      this.schemaVersion = decisionTreeClassifierJsonSchemaVersion,
+    }
   );
 
   factory DecisionTreeClassifierImpl.fromJson(Map<String, dynamic> json) =>
@@ -76,6 +84,12 @@ class DecisionTreeClassifierImpl
   @JsonKey(includeIfNull: false)
   final num negativeLabel = null;
 
+  @override
+  @JsonKey(name: jsonSchemaVersionJsonKey)
+  final schemaVersion;
+
+  final _outdatedSchemaVersions = [null];
+
   @override
   DataFrame predict(DataFrame features) {
     final predictedLabels = features
@@ -134,4 +148,22 @@ class DecisionTreeClassifierImpl
 
     throw Exception('Given sample does not conform any splitting condition');
   }
+
+  @override
+  DecisionTreeClassifier retrain(DataFrame data) {
+    if (_outdatedSchemaVersions.contains(schemaVersion)) {
+      throw OutdatedJsonSchemaException();
+    }
+
+    return decisionTreeInjector
+        .get<DecisionTreeClassifierFactory>()
+        .create(
+          data,
+          minError,
+          minSamplesCount,
+          maxDepth,
+          targetColumnName,
+          dtype,
+        );
+  }
 }
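
The schema-version guard in `retrain` above can be sketched in isolation. The example below is a hedged illustration, not the library's implementation: `RestoredModel`, `OutdatedSchemaError`, `currentSchemaVersion` and the JSON keys `maxDepth` and `schemaVersion` are hypothetical names chosen for the sketch. It assumes, as the diff suggests with `_outdatedSchemaVersions = [null]`, that a model serialized before the version field existed decodes with a `null` version and must therefore refuse to retrain.

```dart
import 'dart:convert';

// Hypothetical constant; stands in for a per-model schema version.
const currentSchemaVersion = 1;

// Hypothetical exception type for the sketch.
class OutdatedSchemaError implements Exception {
  @override
  String toString() =>
      'OutdatedSchemaError: the model was saved with an outdated JSON schema';
}

class RestoredModel {
  // A freshly trained model gets the current version by default.
  RestoredModel(this.maxDepth, {this.schemaVersion = currentSchemaVersion});

  // Old JSON has no 'schemaVersion' key, so the lookup yields null,
  // and passing null explicitly overrides the constructor default.
  factory RestoredModel.fromJson(String json) {
    final decoded = jsonDecode(json) as Map<String, dynamic>;
    return RestoredModel(
      decoded['maxDepth'] as int,
      schemaVersion: decoded['schemaVersion'] as int?,
    );
  }

  final int maxDepth;
  final int? schemaVersion;

  // Versions that lack the data needed for retraining.
  static const _outdatedSchemaVersions = <int?>[null];

  bool get canRetrain => !_outdatedSchemaVersions.contains(schemaVersion);

  // Placeholder for "re-run training with the stored hyperparameters".
  RestoredModel retrain() {
    if (!canRetrain) {
      throw OutdatedSchemaError();
    }
    return RestoredModel(maxDepth);
  }
}

void main() {
  final fresh = RestoredModel.fromJson('{"maxDepth": 3, "schemaVersion": 1}');
  final legacy = RestoredModel.fromJson('{"maxDepth": 3}');

  print(fresh.canRetrain);
  print(legacy.canRetrain);

  try {
    legacy.retrain();
  } on OutdatedSchemaError catch (e) {
    print(e);
  }
}
```

Keeping the outdated versions in a list, rather than comparing against a single value, leaves room to mark further versions as non-retrainable later without changing the guard itself.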
