Skip to content

Commit b8126e7

Browse files
authored
Hyperparameters added to models interfaces (#168)
1 parent d90cff0 commit b8126e7

File tree

90 files changed

+2606
-188
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+2606
-188
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## 15.5.0
4+
- `KnnClassifier`, `DecisionTreeClassifier`, `LogisticRegressor`, `SoftmaxRegressor`, `KnnRegressor`, `LinearRegressor`
5+
- hyperparameters added to the interfaces
6+
37
## 15.4.1
48
- `DTypeJsonConverter` added
59
- `MatrixJsonConverter` added
File renamed without changes.
File renamed without changes.
File renamed without changes.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import 'dart:io';
2+
3+
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart';
4+
import 'package:ml_linalg/dtype.dart';
5+
import 'package:test/test.dart';
6+
7+
void main() {
8+
group('DecisionTreeClassifier', () {
9+
test('should deserialize v0 schema version', () async {
10+
final file = File('e2e/decision_tree_classifier/decision_tree_classifier_v0.json');
11+
final encodedData = await file.readAsString();
12+
final classifier = DecisionTreeClassifier.fromJson(encodedData);
13+
14+
expect(classifier.dtype, DType.float32);
15+
expect(classifier.targetNames, ['Species']);
16+
expect(classifier.positiveLabel, isNull);
17+
expect(classifier.negativeLabel, isNull);
18+
expect(classifier.minSamplesCount, isNull);
19+
expect(classifier.maxDepth, isNull);
20+
expect(classifier.minError, isNull);
21+
});
22+
});
23+
}

e2e/decision_tree_classifier_test.dart renamed to e2e/decision_tree_classifier/decision_tree_classifier_test.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import 'package:ml_preprocessing/ml_preprocessing.dart';
66
import 'package:test/test.dart';
77

88
Future<Vector> evaluateClassifier(MetricType metric, DType dtype) async {
9-
final samples = (await fromCsv('e2e/datasets/iris.csv'))
9+
final samples = (await fromCsv('e2e/_datasets/iris.csv'))
1010
.shuffle()
1111
.dropSeries(seriesNames: ['Id']);
1212
final pipeline = Pipeline(samples, [
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"DT":"F32","T":"Species","R":{"CN":[{"LB":{"V":2.0,"P":1.0000000000000004},"PT":"LT","SV":2.449999988079071,"SI":2,"LV":1},{"CN":[{"LB":{"V":0.0,"P":0.9230769230769238},"PT":"LT","SV":1.7000000476837158,"SI":3,"LV":2},{"LB":{"V":1.0,"P":0.9583333333333339},"PT":"GET","SV":1.7000000476837158,"SI":3,"LV":2}],"PT":"GET","SV":2.449999988079071,"SI":2,"LV":1}],"LV":0}}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import 'dart:io';
2+
3+
import 'package:ml_algo/ml_algo.dart';
4+
import 'package:ml_algo/src/knn_kernel/kernel_type.dart';
5+
import 'package:ml_linalg/distance.dart';
6+
import 'package:ml_linalg/dtype.dart';
7+
import 'package:test/test.dart';
8+
9+
void main() {
10+
group('KnnClassifier', () {
11+
test('should deserialize v0 schema version', () async {
12+
final file = File('e2e/knn_classifier/knn_classifier_v0.json');
13+
final encodedData = await file.readAsString();
14+
final classifier = KnnClassifier.fromJson(encodedData);
15+
16+
expect(classifier.distanceType, Distance.euclidean);
17+
expect(classifier.kernelType, KernelType.gaussian);
18+
expect(classifier.k, 5);
19+
expect(classifier.negativeLabel, isNull);
20+
expect(classifier.positiveLabel, isNull);
21+
expect(classifier.dtype, DType.float32);
22+
expect(classifier.targetNames, ['Species']);
23+
});
24+
});
25+
}

e2e/knn_classifier_test.dart renamed to e2e/knn_classifier/knn_classifier_test.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import 'package:ml_preprocessing/ml_preprocessing.dart';
66
import 'package:test/test.dart';
77

88
Future<Vector> evaluateKnnClassifier(MetricType metric, DType dtype) async {
9-
final samples = (await fromCsv('e2e/datasets/iris.csv'))
9+
final samples = (await fromCsv('e2e/_datasets/iris.csv'))
1010
.shuffle()
1111
.dropSeries(seriesNames: ['Id']);
1212
final targetName = 'Species';

0 commit comments

Comments
 (0)