Skip to content

Commit f474d43

Browse files
authored
RSS metric added (#156)
1 parent 219faa5 commit f474d43

File tree

8 files changed

+109
-32
lines changed

8 files changed

+109
-32
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Changelog
22

3+
## 15.3.0
4+
- RSS metric added
5+
36
## 15.2.4
47
- Documentation for classification metrics improved
58

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import 'package:ml_algo/src/common/exception/matrix_column_exception.dart';
2+
import 'package:ml_linalg/matrix.dart';
3+
4+
void validateMatrixColumns(Iterable<Matrix> matrices) {
5+
final firstInvalidMatrix = matrices
6+
.firstWhere((matrix) => matrix.columnsNum != 1, orElse: () => null);
7+
8+
if (firstInvalidMatrix == null) {
9+
return;
10+
}
11+
12+
throw MatrixColumnException(firstInvalidMatrix);
13+
}

lib/src/metric/metric_type.dart

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
/// Metrics for measuring the quality of the prediction.
22
enum MetricType {
3-
///
4-
///
53
/// Mean percentage absolute error (MAPE), a regression metric. The formula
64
/// is:
75
///
86
///
97
/// ![{\mbox{Score}}={\frac{1}{n}}\sum_{{t=1}}^{n}\left|{\frac{y_{t}-\widehat{y}_{t}}{y_{t}}}\right|](https://latex.codecogs.com/gif.latex?%7B%5Cmbox%7BScore%7D%7D%3D%7B%5Cfrac%7B1%7D%7Bn%7D%7D%5Csum_%7B%7Bt%3D1%7D%7D%5E%7Bn%7D%5Cleft%7C%7B%5Cfrac%7By_%7Bt%7D-%5Cwidehat%7By%7D_%7Bt%7D%7D%7By_%7Bt%7D%7D%7D%5Cright%7C)
108
///
119
///
12-
/// where y - original value, y with hat - predicted one
10+
/// where `y` - original value, `y` with hat - predicted one
1311
///
1412
///
1513
/// The less the score produced by the metric, the better the prediction's
@@ -19,9 +17,7 @@ enum MetricType {
1917
/// can produce scores which are greater than 1.
2018
mape,
2119

22-
///
23-
///
24-
/// Root mean squared error (RMSE), a regression metric. The formula is:
20+
/// Root mean squared error (RMSE), a regression metric. The formula is
2521
///
2622
///
2723
/// ![{\mbox{Score}}=\sqrt{\frac{1}{n}\sum_{{t=1}}^{n}({\widehat{y}_{t} - y_{t}})^2}](https://latex.codecogs.com/gif.latex?%7B%5Cmbox%7BScore%7D%7D%3D%5Csqrt%7B%5Cfrac%7B1%7D%7Bn%7D%5Csum_%7B%7Bt%3D1%7D%7D%5E%7Bn%7D%28%7B%5Cwidehat%7By%7D_%7Bt%7D%20-%20y_%7Bt%7D%7D%29%5E2%7D)
@@ -35,8 +31,15 @@ enum MetricType {
3531
/// scores within the range [0, +Infinity]
3632
rmse,
3733

34+
/// Residual sum of squares (RSS), a regression metric. The formula is
35+
///
36+
/// ![{\mbox{Score}}=\sum_{t=1}^{n}(y_{t} - \widehat{y}_{t})^{2}](https://latex.codecogs.com/gif.latex?%7B%5Cmbox%7BScore%7D%7D%3D%5Csum_%7Bt%3D1%7D%5E%7Bn%7D%28y_%7Bt%7D%20-%20%5Cwidehat%7By%7D_%7Bt%7D%29%5E%7B2%7D)
3837
///
38+
/// where `n` is a total amount of labels, `y` is an original value, `y` with
39+
/// hat - predicted one
3940
///
41+
rss,
42+
4043
/// A classification metric. The formula is
4144
///
4245
///
@@ -51,17 +54,15 @@ enum MetricType {
5154
/// quality is. The metric produces scores within the range [0, 1]
5255
accuracy,
5356

54-
///
55-
///
5657
/// A classification metric. The formula for a single-class problem is
5758
///
5859
///
5960
/// ![{\mbox{Score}}=\frac{TP}{TP + FP}](https://latex.codecogs.com/gif.latex?%7B%5Cmbox%7BScore%7D%7D%3D%5Cfrac%7BTP%7D%7BTP%20&plus;%20FP%7D)
6061
///
6162
///
62-
/// where TP is a number of correctly predicted positive labels (true positive),
63-
/// FP - a number of incorrectly predicted positive labels (false positive). In
64-
/// other words, TP + FP is a number of all the labels predicted to be positive
63+
/// where `TP` is a number of correctly predicted positive labels (true positive),
64+
/// `FP` - a number of incorrectly predicted positive labels (false positive). In
65+
/// other words, `TP + FP` is a number of all the labels predicted to be positive
6566
///
6667
/// The formula for a multi-class problem is
6768
///
@@ -76,17 +77,15 @@ enum MetricType {
7677
/// range [0, 1]
7778
precision,
7879

79-
///
80-
///
8180
/// A classification metric. The formula for a single-class problem is
8281
///
8382
///
8483
/// ![{\mbox{Score}}=\frac{TP}{TP + FN}](https://latex.codecogs.com/gif.latex?%7B%5Cmbox%7BScore%7D%7D%3D%5Cfrac%7BTP%7D%7BTP%20&plus;%20FN%7D)
8584
///
8685
///
87-
/// where TP is a number of correctly predicted positive labels (true positive),
88-
/// FN - a number of incorrectly predicted negative labels (false negative). In
89-
/// other words, TP + FN is a total amount of positive labels for a class in
86+
/// where `TP` is a number of correctly predicted positive labels (true positive),
87+
/// `FN` - a number of incorrectly predicted negative labels (false negative). In
88+
/// other words, `TP + FN` is a total amount of positive labels for a class in
9089
/// the given data
9190
///
9291
/// The formula for a multi-class problem is

lib/src/metric/regression/mape.dart

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import 'package:ml_algo/src/common/exception/matrix_column_exception.dart';
1+
import 'package:ml_algo/src/helpers/validate_matrix_columns.dart';
22
import 'package:ml_algo/src/metric/metric.dart';
33
import 'package:ml_linalg/linalg.dart';
44

@@ -7,16 +7,12 @@ class MapeMetric implements Metric {
77

88
@override
99
double getScore(Matrix predictedLabels, Matrix originalLabels) {
10-
if (predictedLabels.columnsNum != 1) {
11-
throw MatrixColumnException(predictedLabels);
12-
}
10+
validateMatrixColumns([predictedLabels, originalLabels]);
1311

14-
if (originalLabels.columnsNum != 1) {
15-
throw MatrixColumnException(originalLabels);
16-
}
17-
18-
final predicted = predictedLabels.getColumn(0);
19-
final original = originalLabels.getColumn(0);
12+
final predicted = predictedLabels
13+
.toVector();
14+
final original = originalLabels
15+
.toVector();
2016

2117
return ((original - predicted) / original)
2218
.abs()

lib/src/metric/regression/rmse.dart

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import 'dart:math' as math;
22

3+
import 'package:ml_algo/src/helpers/validate_matrix_columns.dart';
34
import 'package:ml_algo/src/metric/metric.dart';
45
import 'package:ml_linalg/linalg.dart';
56

@@ -8,16 +9,14 @@ class RmseMetric implements Metric {
89

910
@override
1011
double getScore(Matrix predictedLabels, Matrix origLabels) {
11-
if (predictedLabels.columnsNum != 1 || origLabels.columnsNum != 1) {
12-
throw Exception('Both predicted labels and original labels have to be '
13-
'a matrix-column');
14-
}
12+
validateMatrixColumns([predictedLabels, origLabels]);
1513

1614
final predicted = predictedLabels
1715
.getColumn(0);
1816
final original = origLabels
1917
.getColumn(0);
2018

21-
return math.sqrt(((predicted - original).pow(2)).mean());
19+
return math
20+
.sqrt(((predicted - original).pow(2)).mean());
2221
}
2322
}

lib/src/metric/regression/rss.dart

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import 'package:ml_algo/src/helpers/validate_matrix_columns.dart';
2+
import 'package:ml_algo/src/metric/metric.dart';
3+
import 'package:ml_linalg/matrix.dart';
4+
5+
class RssMetric implements Metric {
6+
const RssMetric();
7+
8+
@override
9+
double getScore(Matrix predictedLabels, Matrix origLabels) {
10+
validateMatrixColumns([predictedLabels, origLabels]);
11+
12+
final predicted = predictedLabels
13+
.toVector();
14+
final original = origLabels
15+
.toVector();
16+
17+
return (predicted - original)
18+
.pow(2)
19+
.sum();
20+
}
21+
}

pubspec.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: ml_algo
22
description: Machine learning algorithms, Machine learning models performance evaluation functionality
3-
version: 15.2.4
3+
version: 15.3.0
44
homepage: https://github.com/gyrdym/ml_algo
55

66
environment:

test/metric/regression/rss_test.dart

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import 'package:ml_algo/src/common/exception/matrix_column_exception.dart';
2+
import 'package:ml_algo/src/metric/regression/rss.dart';
3+
import 'package:ml_linalg/matrix.dart';
4+
import 'package:test/test.dart';
5+
6+
void main() {
7+
group('RssMetric', () {
8+
const metric = RssMetric();
9+
final predictedLabels = Matrix.column([12, 18, 12, 90, 78]);
10+
final originalLabels = Matrix.column([10, 20 , 30, 60, 70]);
11+
12+
test('should throw an error if predicted labels matrix\'s columns count '
13+
'is empty', () {
14+
final actual = () => metric.getScore(Matrix.empty(), originalLabels);
15+
16+
expect(actual, throwsA(isA<MatrixColumnException>()));
17+
});
18+
19+
test('should throw an error if predicted labels matrix\'s columns count '
20+
'is greater than 1', () {
21+
final actual = () => metric.getScore(Matrix.row([1, 2]), originalLabels);
22+
23+
expect(actual, throwsA(isA<MatrixColumnException>()));
24+
});
25+
26+
test('should throw an error if original labels matrix\'s columns count '
27+
'is empty', () {
28+
final actual = () => metric.getScore(predictedLabels, Matrix.empty());
29+
30+
expect(actual, throwsA(isA<MatrixColumnException>()));
31+
});
32+
33+
test('should throw an error if original labels matrix\'s columns count '
34+
'is greater than 1', () {
35+
final actual = () => metric.getScore(predictedLabels, Matrix.row([1, 2]));
36+
37+
expect(actual, throwsA(isA<MatrixColumnException>()));
38+
});
39+
40+
test('should count score', () {
41+
final actual = metric.getScore(predictedLabels, originalLabels);
42+
43+
expect(actual, 1296);
44+
});
45+
});
46+
}

0 commit comments

Comments
 (0)