Skip to content

CrossValidator: isolates used #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions lib/src/classifier/_mixins/assessable_classifier_mixin.dart
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import 'package:ml_algo/src/classifier/classifier.dart';
import 'package:ml_algo/src/di/injector.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart';
import 'package:ml_dataframe/ml_dataframe.dart';

mixin AssessableClassifierMixin implements Assessable, Classifier {
mixin AssessableClassifierMixin implements Classifier {
@override
double assess(
DataFrame samples,
Expand Down
2 changes: 1 addition & 1 deletion lib/src/classifier/classifier.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import 'package:ml_dataframe/ml_dataframe.dart';

/// An interface for any classifier (linear, non-linear, parametric,
/// non-parametric, etc.)
abstract class Classifier extends Predictor {
abstract class Classifier implements Predictor {
/// Returns predicted distribution of probabilities for each observation in
/// the passed [testFeatures]
DataFrame predictProbabilities(DataFrame testFeatures);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import 'package:ml_algo/src/classifier/classifier.dart';
import 'package:ml_algo/src/classifier/decision_tree_classifier/_init_module.dart';
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart';
import 'package:ml_algo/src/common/serializable/serializable.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/predictor/retrainable.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';

Expand All @@ -16,12 +13,7 @@ import 'package:ml_linalg/dtype.dart';
/// decision tree learning. Once a decision tree learned, it may be used to
/// classify new samples with the same features that were used to learn the
/// tree.
abstract class DecisionTreeClassifier
implements
Assessable,
Serializable,
Retrainable,
Classifier {
abstract class DecisionTreeClassifier implements Classifier {

/// Parameters:
///
Expand Down
10 changes: 1 addition & 9 deletions lib/src/classifier/knn_classifier/knn_classifier.dart
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import 'package:ml_algo/src/classifier/classifier.dart';
import 'package:ml_algo/src/classifier/knn_classifier/_init_module.dart';
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart';
import 'package:ml_algo/src/common/serializable/serializable.dart';
import 'package:ml_algo/src/knn_kernel/kernel_type.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/predictor/retrainable.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/distance.dart';
import 'package:ml_linalg/dtype.dart';
Expand All @@ -21,12 +18,7 @@ import 'package:ml_linalg/dtype.dart';
/// imprecise result. Thus the weighted version of KNN algorithm is used in the
/// classifier. To get weight of a particular observation one may use a kernel
/// function.
abstract class KnnClassifier
implements
Assessable,
Serializable,
Retrainable,
Classifier {
abstract class KnnClassifier implements Classifier {
/// Parameters:
///
/// [trainData] Labelled observations. Must contain [targetName] column.
Expand Down
10 changes: 1 addition & 9 deletions lib/src/classifier/logistic_regressor/logistic_regressor.dart
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import 'package:ml_algo/src/classifier/linear_classifier.dart';
import 'package:ml_algo/src/classifier/logistic_regressor/_init_module.dart';
import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor_factory.dart';
import 'package:ml_algo/src/common/serializable/serializable.dart';
import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_generator/learning_rate_type.dart';
import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart';
import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart';
import 'package:ml_algo/src/linear_optimizer/regularization_type.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/predictor/retrainable.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/vector.dart';
Expand All @@ -19,12 +16,7 @@ import 'package:ml_linalg/vector.dart';
/// In other words, the regressor iteratively tries to select coefficients
/// that makes combination of passed features and these coefficients most
/// likely.
abstract class LogisticRegressor
implements
Assessable,
Serializable,
Retrainable,
LinearClassifier {
abstract class LogisticRegressor implements LinearClassifier {

/// Parameters:
///
Expand Down
10 changes: 1 addition & 9 deletions lib/src/classifier/softmax_regressor/softmax_regressor.dart
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import 'package:ml_algo/src/classifier/linear_classifier.dart';
import 'package:ml_algo/src/classifier/softmax_regressor/_init_module.dart';
import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor_factory.dart';
import 'package:ml_algo/src/common/serializable/serializable.dart';
import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_generator/learning_rate_type.dart';
import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart';
import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart';
import 'package:ml_algo/src/linear_optimizer/regularization_type.dart';
import 'package:ml_algo/src/model_selection/assessable.dart';
import 'package:ml_algo/src/predictor/retrainable.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_linalg/dtype.dart';
import 'package:ml_linalg/matrix.dart';
Expand All @@ -26,12 +23,7 @@ import 'package:ml_linalg/matrix.dart';
///
/// Also it is worth to mention that the algorithm is a generalization of
/// [Logistic regression](https://en.wikipedia.org/wiki/Logistic_regression))
abstract class SoftmaxRegressor
implements
Assessable,
Serializable,
Retrainable,
LinearClassifier {
abstract class SoftmaxRegressor implements LinearClassifier {

/// Parameters:
///
Expand Down
1 change: 1 addition & 0 deletions lib/src/common/constants/common_json_keys.dart
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
const jsonSchemaVersionJsonKey = '\$V';
const predictorTypeJsonKey = '\$PT';
7 changes: 6 additions & 1 deletion lib/src/di/common/init_common_module.dart
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import 'package:ml_algo/src/model_selection/model_assessor/classifier_assessor.d
import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart';
import 'package:ml_algo/src/model_selection/model_assessor/regressor_assessor.dart';
import 'package:ml_algo/src/predictor/predictor.dart';
import 'package:ml_algo/src/service/worker_manager/worker_manager.dart';
import 'package:ml_algo/src/service/worker_manager/worker_manager_impl.dart';
import 'package:ml_dataframe/ml_dataframe.dart';
import 'package:ml_preprocessing/ml_preprocessing.dart';

Expand Down Expand Up @@ -76,5 +78,8 @@ void initCommonModule() {
RegressorAssessor(
injector.get<MetricFactory>(),
featuresTargetSplit,
));
))

..registerSingletonIf<WorkerManager>(
() => WorkerManagerImpl());
}
6 changes: 6 additions & 0 deletions lib/src/metric/metric_type_encoded_values.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
const mapeMetricTypeEncodedValue = 'MAPE';
const rmseMetricTypeEncodedValue = 'RMSE';
const rssMetricTypeEncodedValue = 'RSS';
const accuracyMetricTypeEncodedValue = 'ACR';
const precisionMetricTypeEncodedValue = 'PRC';
const recallMetricTypeEncodedValue = 'RC';
59 changes: 59 additions & 0 deletions lib/src/metric/metric_type_json_converter.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import 'package:json_annotation/json_annotation.dart';
import 'package:ml_algo/src/metric/metric_type.dart';
import 'package:ml_algo/src/metric/metric_type_encoded_values.dart';

class MetricTypeJsonConverter implements JsonConverter<MetricType, String> {
const MetricTypeJsonConverter();

@override
MetricType fromJson(String json) {
switch (json) {
case mapeMetricTypeEncodedValue:
return MetricType.mape;

case rmseMetricTypeEncodedValue:
return MetricType.rmse;

case rssMetricTypeEncodedValue:
return MetricType.rss;

case accuracyMetricTypeEncodedValue:
return MetricType.accuracy;

case precisionMetricTypeEncodedValue:
return MetricType.precision;

case recallMetricTypeEncodedValue:
return MetricType.recall;

default:
throw UnsupportedError('Unsupported encoded metric value - $json');
}
}

@override
String toJson(MetricType metricType) {
switch (metricType) {
case MetricType.mape:
return mapeMetricTypeEncodedValue;

case MetricType.rmse:
return rmseMetricTypeEncodedValue;

case MetricType.rss:
return rssMetricTypeEncodedValue;

case MetricType.accuracy:
return accuracyMetricTypeEncodedValue;

case MetricType.precision:
return precisionMetricTypeEncodedValue;

case MetricType.recall:
return recallMetricTypeEncodedValue;

default:
throw UnsupportedError('Unsupported metric type - $metricType');
}
}
}
12 changes: 7 additions & 5 deletions lib/src/model_selection/_init_module.dart
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import 'package:ml_algo/src/di/common/init_common_module.dart';
import 'package:ml_algo/src/model_selection/_injector.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart';
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory_impl.dart';
import 'package:ml_algo/src/extensions/injector.dart';

void initModelSelectionModule() {
if (!modelSelectionInjector.exists<SplitIndicesProviderFactory>()) {
modelSelectionInjector
..registerSingleton<SplitIndicesProviderFactory>(
() => const SplitIndicesProviderFactoryImpl());
}
initCommonModule();

modelSelectionInjector
..registerSingletonIf<SplitIndicesProviderFactory>(
() => const SplitIndicesProviderFactoryImpl());
}
10 changes: 0 additions & 10 deletions lib/src/model_selection/assessable.dart

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message.dart';

num assessPredictor(Map<String, dynamic> encodedMessage) {
final message = CrossValidatorIsolateMessage
.fromJson(encodedMessage);
final predictor = message
.predictorPrototype
.retrain(message.trainData);

return predictor
.assess(message.testData, message.metricType);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import 'dart:convert';

import 'package:ml_algo/ml_algo.dart';
import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor.dart';
import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart';
import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart';
import 'package:ml_algo/src/predictor/predictor.dart';

Predictor decodePredictor(
PredictorType predictorType,
Map<String, dynamic> json,
) {
switch (predictorType) {
case PredictorType.softmaxRegressor:
return SoftmaxRegressor.fromJson(jsonEncode(json));

case PredictorType.logisticRegressor:
return LogisticRegressor.fromJson(jsonEncode(json));

case PredictorType.knnClassifier:
return KnnClassifier.fromJson(jsonEncode(json));

case PredictorType.decisionTreeClassifier:
return DecisionTreeClassifier.fromJson(jsonEncode(json));

case PredictorType.knnRegressor:
return KnnRegressor.fromJson(jsonEncode(json));

case PredictorType.linearRegressor:
return LinearRegressor.fromJson(jsonEncode(json));

default:
throw UnsupportedError('Unsupported predictor type - ${predictorType}');
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/predictor_type_encoded_values.dart';

PredictorType decodePredictorType(String json) {
switch (json) {
case knnRegressorPredictorTypeEncodedValue:
return PredictorType.knnRegressor;

case linearRegressorPredictorTypeEncodedValue:
return PredictorType.linearRegressor;

case decisionTreeClassifierPredictorTypeEncodedValue:
return PredictorType.decisionTreeClassifier;

case knnClassifierPredictorTypeEncodedValue:
return PredictorType.knnClassifier;

case logisticRegressorPredictorTypeEncodedValue:
return PredictorType.logisticRegressor;

case softmaxRegressorPredictorTypeEncodedValue:
return PredictorType.softmaxRegressor;

default:
throw UnsupportedError('Unsupported predictor encoded value - $json');
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart';
import 'package:ml_algo/src/model_selection/cross_validator/predictor_type_encoded_values.dart';

String encodePredictorType(PredictorType predictorType) {
switch (predictorType) {
case PredictorType.knnRegressor:
return knnRegressorPredictorTypeEncodedValue;

case PredictorType.linearRegressor:
return linearRegressorPredictorTypeEncodedValue;

case PredictorType.decisionTreeClassifier:
return decisionTreeClassifierPredictorTypeEncodedValue;

case PredictorType.knnClassifier:
return knnClassifierPredictorTypeEncodedValue;

case PredictorType.logisticRegressor:
return logisticRegressorPredictorTypeEncodedValue;

case PredictorType.softmaxRegressor:
return softmaxRegressorPredictorTypeEncodedValue;

default:
throw UnsupportedError('Unsupported predictor type - $predictorType');
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart';
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart';
import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor.dart';
import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart';
import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart';
import 'package:ml_algo/src/predictor/predictor.dart';
import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart';
import 'package:ml_algo/src/regressor/linear_regressor/linear_regressor.dart';

PredictorType getPredictorType(Predictor predictor) {
if (predictor is LogisticRegressor) {
return PredictorType.logisticRegressor;
}

if (predictor is SoftmaxRegressor) {
return PredictorType.softmaxRegressor;
}

if (predictor is DecisionTreeClassifier) {
return PredictorType.decisionTreeClassifier;
}

if (predictor is KnnClassifier) {
return PredictorType.knnClassifier;
}

if (predictor is LinearRegressor) {
return PredictorType.linearRegressor;
}

if (predictor is KnnRegressor) {
return PredictorType.knnRegressor;
}

throw UnsupportedError(
'Unsupported predictor type - ${predictor.runtimeType}');
}
5 changes: 5 additions & 0 deletions lib/src/model_selection/cross_validator/_init_module.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import 'package:ml_algo/src/di/common/init_common_module.dart';

void initCrossValidatorModule() {
initCommonModule();
}
3 changes: 3 additions & 0 deletions lib/src/model_selection/cross_validator/_injector.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import 'package:injector/injector.dart';

final crossValidatorInjector = Injector();
Loading