diff --git a/lib/src/classifier/_mixins/assessable_classifier_mixin.dart b/lib/src/classifier/_mixins/assessable_classifier_mixin.dart index 9978fa59..47ff3cdb 100644 --- a/lib/src/classifier/_mixins/assessable_classifier_mixin.dart +++ b/lib/src/classifier/_mixins/assessable_classifier_mixin.dart @@ -1,11 +1,10 @@ import 'package:ml_algo/src/classifier/classifier.dart'; import 'package:ml_algo/src/di/injector.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; -mixin AssessableClassifierMixin implements Assessable, Classifier { +mixin AssessableClassifierMixin implements Classifier { @override double assess( DataFrame samples, diff --git a/lib/src/classifier/classifier.dart b/lib/src/classifier/classifier.dart index 6b5a87b5..59f95d3f 100644 --- a/lib/src/classifier/classifier.dart +++ b/lib/src/classifier/classifier.dart @@ -3,7 +3,7 @@ import 'package:ml_dataframe/ml_dataframe.dart'; /// An interface for any classifier (linear, non-linear, parametric, /// non-parametric, etc.) -abstract class Classifier extends Predictor { +abstract class Classifier implements Predictor { /// Returns predicted distribution of probabilities for each observation in /// the passed [testFeatures] DataFrame predictProbabilities(DataFrame testFeatures); diff --git a/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart b/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart index 2754d2ec..d34f9166 100644 --- a/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart +++ b/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart @@ -1,9 +1,6 @@ import 'package:ml_algo/src/classifier/classifier.dart'; import 'package:ml_algo/src/classifier/decision_tree_classifier/_init_module.dart'; import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart'; -import 'package:ml_algo/src/common/serializable/serializable.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; @@ -16,12 +13,7 @@ import 'package:ml_linalg/dtype.dart'; /// decision tree learning. Once a decision tree learned, it may be used to /// classify new samples with the same features that were used to learn the /// tree. -abstract class DecisionTreeClassifier - implements - Assessable, - Serializable, - Retrainable, - Classifier { +abstract class DecisionTreeClassifier implements Classifier { /// Parameters: /// diff --git a/lib/src/classifier/knn_classifier/knn_classifier.dart b/lib/src/classifier/knn_classifier/knn_classifier.dart index cd8c5e3d..7e3305d3 100644 --- a/lib/src/classifier/knn_classifier/knn_classifier.dart +++ b/lib/src/classifier/knn_classifier/knn_classifier.dart @@ -1,10 +1,7 @@ import 'package:ml_algo/src/classifier/classifier.dart'; import 'package:ml_algo/src/classifier/knn_classifier/_init_module.dart'; import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart'; -import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/distance.dart'; import 'package:ml_linalg/dtype.dart'; @@ -21,12 +18,7 @@ import 'package:ml_linalg/dtype.dart'; /// imprecise result. Thus the weighted version of KNN algorithm is used in the /// classifier. To get weight of a particular observation one may use a kernel /// function. -abstract class KnnClassifier - implements - Assessable, - Serializable, - Retrainable, - Classifier { +abstract class KnnClassifier implements Classifier { /// Parameters: /// /// [trainData] Labelled observations. Must contain [targetName] column. diff --git a/lib/src/classifier/logistic_regressor/logistic_regressor.dart b/lib/src/classifier/logistic_regressor/logistic_regressor.dart index 25d2f321..824170ab 100644 --- a/lib/src/classifier/logistic_regressor/logistic_regressor.dart +++ b/lib/src/classifier/logistic_regressor/logistic_regressor.dart @@ -1,13 +1,10 @@ import 'package:ml_algo/src/classifier/linear_classifier.dart'; import 'package:ml_algo/src/classifier/logistic_regressor/_init_module.dart'; import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor_factory.dart'; -import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_generator/learning_rate_type.dart'; import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart'; import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart'; import 'package:ml_algo/src/linear_optimizer/regularization_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/vector.dart'; @@ -19,12 +16,7 @@ import 'package:ml_linalg/vector.dart'; /// In other words, the regressor iteratively tries to select coefficients /// that makes combination of passed features and these coefficients most /// likely. -abstract class LogisticRegressor - implements - Assessable, - Serializable, - Retrainable, - LinearClassifier { +abstract class LogisticRegressor implements LinearClassifier { /// Parameters: /// diff --git a/lib/src/classifier/softmax_regressor/softmax_regressor.dart b/lib/src/classifier/softmax_regressor/softmax_regressor.dart index 193f3839..b8b181f2 100644 --- a/lib/src/classifier/softmax_regressor/softmax_regressor.dart +++ b/lib/src/classifier/softmax_regressor/softmax_regressor.dart @@ -1,13 +1,10 @@ import 'package:ml_algo/src/classifier/linear_classifier.dart'; import 'package:ml_algo/src/classifier/softmax_regressor/_init_module.dart'; import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor_factory.dart'; -import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_generator/learning_rate_type.dart'; import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart'; import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart'; import 'package:ml_algo/src/linear_optimizer/regularization_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/matrix.dart'; @@ -26,12 +23,7 @@ import 'package:ml_linalg/matrix.dart'; /// /// Also it is worth to mention that the algorithm is a generalization of /// [Logistic regression](https://en.wikipedia.org/wiki/Logistic_regression)) -abstract class SoftmaxRegressor - implements - Assessable, - Serializable, - Retrainable, - LinearClassifier { +abstract class SoftmaxRegressor implements LinearClassifier { /// Parameters: /// diff --git a/lib/src/common/constants/common_json_keys.dart b/lib/src/common/constants/common_json_keys.dart index 2d3e1e65..a53afd48 100644 --- a/lib/src/common/constants/common_json_keys.dart +++ b/lib/src/common/constants/common_json_keys.dart @@ -1 +1,2 @@ const jsonSchemaVersionJsonKey = '\$V'; +const predictorTypeJsonKey = '\$PT'; diff --git a/lib/src/di/common/init_common_module.dart b/lib/src/di/common/init_common_module.dart index d99bccb1..fed8088b 100644 --- a/lib/src/di/common/init_common_module.dart +++ b/lib/src/di/common/init_common_module.dart @@ -24,6 +24,8 @@ import 'package:ml_algo/src/model_selection/model_assessor/classifier_assessor.d import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart'; import 'package:ml_algo/src/model_selection/model_assessor/regressor_assessor.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; +import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; +import 'package:ml_algo/src/service/worker_manager/worker_manager_impl.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_preprocessing/ml_preprocessing.dart'; @@ -76,5 +78,8 @@ void initCommonModule() { RegressorAssessor( injector.get(), featuresTargetSplit, - )); + )) + + ..registerSingletonIf( + () => WorkerManagerImpl()); } diff --git a/lib/src/metric/metric_type_encoded_values.dart b/lib/src/metric/metric_type_encoded_values.dart new file mode 100644 index 00000000..84c51a97 --- /dev/null +++ b/lib/src/metric/metric_type_encoded_values.dart @@ -0,0 +1,6 @@ +const mapeMetricTypeEncodedValue = 'MAPE'; +const rmseMetricTypeEncodedValue = 'RMSE'; +const rssMetricTypeEncodedValue = 'RSS'; +const accuracyMetricTypeEncodedValue = 'ACR'; +const precisionMetricTypeEncodedValue = 'PRC'; +const recallMetricTypeEncodedValue = 'RC'; diff --git a/lib/src/metric/metric_type_json_converter.dart b/lib/src/metric/metric_type_json_converter.dart new file mode 100644 index 00000000..ac350407 --- /dev/null +++ b/lib/src/metric/metric_type_json_converter.dart @@ -0,0 +1,59 @@ +import 'package:json_annotation/json_annotation.dart'; +import 'package:ml_algo/src/metric/metric_type.dart'; +import 'package:ml_algo/src/metric/metric_type_encoded_values.dart'; + +class MetricTypeJsonConverter implements JsonConverter { + const MetricTypeJsonConverter(); + + @override + MetricType fromJson(String json) { + switch (json) { + case mapeMetricTypeEncodedValue: + return MetricType.mape; + + case rmseMetricTypeEncodedValue: + return MetricType.rmse; + + case rssMetricTypeEncodedValue: + return MetricType.rss; + + case accuracyMetricTypeEncodedValue: + return MetricType.accuracy; + + case precisionMetricTypeEncodedValue: + return MetricType.precision; + + case recallMetricTypeEncodedValue: + return MetricType.recall; + + default: + throw UnsupportedError('Unsupported encoded metric value - $json'); + } + } + + @override + String toJson(MetricType metricType) { + switch (metricType) { + case MetricType.mape: + return mapeMetricTypeEncodedValue; + + case MetricType.rmse: + return rmseMetricTypeEncodedValue; + + case MetricType.rss: + return rssMetricTypeEncodedValue; + + case MetricType.accuracy: + return accuracyMetricTypeEncodedValue; + + case MetricType.precision: + return precisionMetricTypeEncodedValue; + + case MetricType.recall: + return recallMetricTypeEncodedValue; + + default: + throw UnsupportedError('Unsupported metric type - $metricType'); + } + } +} diff --git a/lib/src/model_selection/_init_module.dart b/lib/src/model_selection/_init_module.dart index b49f4589..54faea24 100644 --- a/lib/src/model_selection/_init_module.dart +++ b/lib/src/model_selection/_init_module.dart @@ -1,11 +1,13 @@ +import 'package:ml_algo/src/di/common/init_common_module.dart'; import 'package:ml_algo/src/model_selection/_injector.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory_impl.dart'; +import 'package:ml_algo/src/extensions/injector.dart'; void initModelSelectionModule() { - if (!modelSelectionInjector.exists()) { - modelSelectionInjector - ..registerSingleton( - () => const SplitIndicesProviderFactoryImpl()); - } + initCommonModule(); + + modelSelectionInjector + ..registerSingletonIf( + () => const SplitIndicesProviderFactoryImpl()); } diff --git a/lib/src/model_selection/assessable.dart b/lib/src/model_selection/assessable.dart deleted file mode 100644 index 8811f417..00000000 --- a/lib/src/model_selection/assessable.dart +++ /dev/null @@ -1,10 +0,0 @@ -import 'package:ml_algo/src/metric/metric_type.dart'; -import 'package:ml_dataframe/ml_dataframe.dart'; - -/// An interface for a ML model's performance assessment -abstract class Assessable { - /// Assesses model performance according to provided [metricType] - /// - /// Throws an exception if inappropriate [metricType] provided. - double assess(DataFrame observations, MetricType metricType); -} diff --git a/lib/src/model_selection/cross_validator/_helpers/assess_predictor.dart b/lib/src/model_selection/cross_validator/_helpers/assess_predictor.dart new file mode 100644 index 00000000..e84421a1 --- /dev/null +++ b/lib/src/model_selection/cross_validator/_helpers/assess_predictor.dart @@ -0,0 +1,12 @@ +import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message.dart'; + +num assessPredictor(Map encodedMessage) { + final message = CrossValidatorIsolateMessage + .fromJson(encodedMessage); + final predictor = message + .predictorPrototype + .retrain(message.trainData); + + return predictor + .assess(message.testData, message.metricType); +} diff --git a/lib/src/model_selection/cross_validator/_helpers/decode_predictor.dart b/lib/src/model_selection/cross_validator/_helpers/decode_predictor.dart new file mode 100644 index 00000000..c6d6e762 --- /dev/null +++ b/lib/src/model_selection/cross_validator/_helpers/decode_predictor.dart @@ -0,0 +1,35 @@ +import 'dart:convert'; + +import 'package:ml_algo/ml_algo.dart'; +import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor.dart'; +import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart'; +import 'package:ml_algo/src/predictor/predictor.dart'; + +Predictor decodePredictor( + PredictorType predictorType, + Map json, +) { + switch (predictorType) { + case PredictorType.softmaxRegressor: + return SoftmaxRegressor.fromJson(jsonEncode(json)); + + case PredictorType.logisticRegressor: + return LogisticRegressor.fromJson(jsonEncode(json)); + + case PredictorType.knnClassifier: + return KnnClassifier.fromJson(jsonEncode(json)); + + case PredictorType.decisionTreeClassifier: + return DecisionTreeClassifier.fromJson(jsonEncode(json)); + + case PredictorType.knnRegressor: + return KnnRegressor.fromJson(jsonEncode(json)); + + case PredictorType.linearRegressor: + return LinearRegressor.fromJson(jsonEncode(json)); + + default: + throw UnsupportedError('Unsupported predictor type - ${predictorType}'); + } +} diff --git a/lib/src/model_selection/cross_validator/_helpers/decode_predictor_type.dart b/lib/src/model_selection/cross_validator/_helpers/decode_predictor_type.dart new file mode 100644 index 00000000..12c07fb0 --- /dev/null +++ b/lib/src/model_selection/cross_validator/_helpers/decode_predictor_type.dart @@ -0,0 +1,27 @@ +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type_encoded_values.dart'; + +PredictorType decodePredictorType(String json) { + switch (json) { + case knnRegressorPredictorTypeEncodedValue: + return PredictorType.knnRegressor; + + case linearRegressorPredictorTypeEncodedValue: + return PredictorType.linearRegressor; + + case decisionTreeClassifierPredictorTypeEncodedValue: + return PredictorType.decisionTreeClassifier; + + case knnClassifierPredictorTypeEncodedValue: + return PredictorType.knnClassifier; + + case logisticRegressorPredictorTypeEncodedValue: + return PredictorType.logisticRegressor; + + case softmaxRegressorPredictorTypeEncodedValue: + return PredictorType.softmaxRegressor; + + default: + throw UnsupportedError('Unsupported predictor encoded value - $json'); + } +} diff --git a/lib/src/model_selection/cross_validator/_helpers/encode_predictor_type.dart b/lib/src/model_selection/cross_validator/_helpers/encode_predictor_type.dart new file mode 100644 index 00000000..68da10d2 --- /dev/null +++ b/lib/src/model_selection/cross_validator/_helpers/encode_predictor_type.dart @@ -0,0 +1,27 @@ +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type_encoded_values.dart'; + +String encodePredictorType(PredictorType predictorType) { + switch (predictorType) { + case PredictorType.knnRegressor: + return knnRegressorPredictorTypeEncodedValue; + + case PredictorType.linearRegressor: + return linearRegressorPredictorTypeEncodedValue; + + case PredictorType.decisionTreeClassifier: + return decisionTreeClassifierPredictorTypeEncodedValue; + + case PredictorType.knnClassifier: + return knnClassifierPredictorTypeEncodedValue; + + case PredictorType.logisticRegressor: + return logisticRegressorPredictorTypeEncodedValue; + + case PredictorType.softmaxRegressor: + return softmaxRegressorPredictorTypeEncodedValue; + + default: + throw UnsupportedError('Unsupported predictor type - $predictorType'); + } +} diff --git a/lib/src/model_selection/cross_validator/_helpers/get_predictor_type.dart b/lib/src/model_selection/cross_validator/_helpers/get_predictor_type.dart new file mode 100644 index 00000000..ad9026f0 --- /dev/null +++ b/lib/src/model_selection/cross_validator/_helpers/get_predictor_type.dart @@ -0,0 +1,37 @@ +import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart'; +import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart'; +import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor.dart'; +import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart'; +import 'package:ml_algo/src/predictor/predictor.dart'; +import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart'; +import 'package:ml_algo/src/regressor/linear_regressor/linear_regressor.dart'; + +PredictorType getPredictorType(Predictor predictor) { + if (predictor is LogisticRegressor) { + return PredictorType.logisticRegressor; + } + + if (predictor is SoftmaxRegressor) { + return PredictorType.softmaxRegressor; + } + + if (predictor is DecisionTreeClassifier) { + return PredictorType.decisionTreeClassifier; + } + + if (predictor is KnnClassifier) { + return PredictorType.knnClassifier; + } + + if (predictor is LinearRegressor) { + return PredictorType.linearRegressor; + } + + if (predictor is KnnRegressor) { + return PredictorType.knnRegressor; + } + + throw UnsupportedError( + 'Unsupported predictor type - ${predictor.runtimeType}'); +} diff --git a/lib/src/model_selection/cross_validator/_init_module.dart b/lib/src/model_selection/cross_validator/_init_module.dart new file mode 100644 index 00000000..e7f6db03 --- /dev/null +++ b/lib/src/model_selection/cross_validator/_init_module.dart @@ -0,0 +1,5 @@ +import 'package:ml_algo/src/di/common/init_common_module.dart'; + +void initCrossValidatorModule() { + initCommonModule(); +} diff --git a/lib/src/model_selection/cross_validator/_injector.dart b/lib/src/model_selection/cross_validator/_injector.dart new file mode 100644 index 00000000..1c789f9a --- /dev/null +++ b/lib/src/model_selection/cross_validator/_injector.dart @@ -0,0 +1,3 @@ +import 'package:injector/injector.dart'; + +final crossValidatorInjector = Injector(); diff --git a/lib/src/model_selection/cross_validator/cross_validator.dart b/lib/src/model_selection/cross_validator/cross_validator.dart index 2e6692a0..b8e7113a 100644 --- a/lib/src/model_selection/cross_validator/cross_validator.dart +++ b/lib/src/model_selection/cross_validator/cross_validator.dart @@ -1,15 +1,17 @@ +import 'package:ml_algo/src/di/injector.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/model_selection/_init_module.dart'; import 'package:ml_algo/src/model_selection/_injector.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart'; +import 'package:ml_algo/src/predictor/predictor.dart'; +import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/linalg.dart'; -typedef PredictorFactory = Assessable Function(DataFrame observations); +typedef PredictorFactoryFn = Predictor Function(DataFrame observations); typedef DataPreprocessFn = List Function(DataFrame trainData, DataFrame testData); @@ -39,11 +41,13 @@ abstract class CrossValidator { final dataSplitterFactory = modelSelectionInjector .get(); + final workerManager = injector.get(); final dataSplitter = dataSplitterFactory .createByType(SplitIndicesProviderType.kFold, numberOfFolds: numberOfFolds); return CrossValidatorImpl( samples, + workerManager, dataSplitter, dtype, ); @@ -71,11 +75,13 @@ abstract class CrossValidator { final dataSplitterFactory = modelSelectionInjector .get(); + final workerManager = injector.get(); final dataSplitter = dataSplitterFactory .createByType(SplitIndicesProviderType.lpo, p: p); return CrossValidatorImpl( samples, + workerManager, dataSplitter, dtype, ); @@ -133,7 +139,7 @@ abstract class CrossValidator { /// print(averageScore); /// ```` Future evaluate( - PredictorFactory predictorFactory, + PredictorFactoryFn predictorFactory, MetricType metricType, { DataPreprocessFn onDataSplit, diff --git a/lib/src/model_selection/cross_validator/cross_validator_impl.dart b/lib/src/model_selection/cross_validator/cross_validator_impl.dart index 23cc83b4..e3a375c6 100644 --- a/lib/src/model_selection/cross_validator/cross_validator_impl.dart +++ b/lib/src/model_selection/cross_validator/cross_validator_impl.dart @@ -1,8 +1,12 @@ import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_exception.dart'; import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/_helpers/assess_predictor.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/_helpers/get_predictor_type.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; +import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/matrix.dart'; @@ -12,40 +16,49 @@ import 'package:quiver/iterables.dart'; class CrossValidatorImpl implements CrossValidator { CrossValidatorImpl( this.samples, + this._workerManager, this._splitter, this.dtype, ); final DataFrame samples; - final DType dtype; + final WorkerManager _workerManager; final SplitIndicesProvider _splitter; + final DType dtype; @override Future evaluate( - PredictorFactory predictorFactory, + PredictorFactoryFn predictorFactory, MetricType metricType, { DataPreprocessFn onDataSplit, } - ) { + ) async { + await _workerManager.init(); + final samplesAsMatrix = samples.toMatrix(dtype); final sourceColumnsNum = samplesAsMatrix.columnsNum; final discreteColumns = enumerate(samples.series) .where((indexedSeries) => indexedSeries.value.isDiscrete) .map((indexedSeries) => indexedSeries.index); - final allIndicesGroups = _splitter.getIndices(samplesAsMatrix.rowsNum); - final scores = allIndicesGroups + final allIndicesGroups = _splitter + .getIndices(samplesAsMatrix.rowsNum); + final scoreFutures = allIndicesGroups .map((testRowsIndices) { final split = _makeSplit(testRowsIndices, discreteColumns); final trainDataFrame = split[0]; final testDataFrame = split[1]; - final splits = onDataSplit != null + final transformedData = onDataSplit != null ? onDataSplit(trainDataFrame, testDataFrame) : [trainDataFrame, testDataFrame]; - final transformedTrainData = splits[0]; - final transformedTestData = splits[1]; - final transformedTrainDataColumnsNum = transformedTrainData.header.length; - final transformedTestDataColumnsNum = transformedTestData.header.length; + final transformedTrainData = transformedData[0]; + final transformedTestData = transformedData[1]; + final transformedTrainDataColumnsNum = transformedTrainData + .header + .length; + final transformedTestDataColumnsNum = transformedTestData + .header + .length; if (transformedTrainDataColumnsNum != sourceColumnsNum) { throw InvalidTrainDataColumnsNumberException(sourceColumnsNum, @@ -57,12 +70,18 @@ class CrossValidatorImpl implements CrossValidator { transformedTestDataColumnsNum); } - return predictorFactory(transformedTrainData) - .assess(transformedTestData, metricType); + return _assessPredictor( + predictorFactory, + transformedTrainData, + transformedTestData, + metricType, + ); }) .toList(); + final scores = await Future.wait(scoreFutures); - return Future.value(Vector.fromList(scores, dtype: dtype)); + return Future + .value(Vector.fromList(scores, dtype: dtype)); } List _makeSplit(Iterable testRowsIndices, @@ -97,4 +116,28 @@ class CrossValidatorImpl implements CrossValidator { ), ]; } + + Future _assessPredictor( + PredictorFactoryFn predictorFactoryFn, + DataFrame trainData, + DataFrame testData, + MetricType metricType, + ) async { + final samplesPrototype = samples.sampleFromRows([0]); + final predictorPrototype = predictorFactoryFn(samplesPrototype); + final predictorType = getPredictorType(predictorPrototype); + + return _workerManager + .executor + .execute( + arg1: CrossValidatorIsolateMessage( + predictorPrototype, + trainData, + testData, + predictorType, + metricType, + ).toJson(), + fun1: assessPredictor, + ); + } } diff --git a/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart b/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart new file mode 100644 index 00000000..5492ea97 --- /dev/null +++ b/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart @@ -0,0 +1,50 @@ +import 'package:ml_algo/src/metric/metric_type.dart'; +import 'package:ml_algo/src/metric/metric_type_json_converter.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/_helpers/decode_predictor.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/_helpers/decode_predictor_type.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/_helpers/encode_predictor_type.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart'; +import 'package:ml_algo/src/predictor/predictor.dart'; +import 'package:ml_dataframe/ml_dataframe.dart'; + +class CrossValidatorIsolateMessage { + final Predictor predictorPrototype; + final DataFrame trainData; + final DataFrame testData; + final PredictorType predictorType; + final MetricType metricType; + + CrossValidatorIsolateMessage( + this.predictorPrototype, + this.trainData, + this.testData, + this.predictorType, + this.metricType, + ); + + static CrossValidatorIsolateMessage fromJson(Map json) { + final predictorType = decodePredictorType( + json[predictorTypeJsonKey] as String); + + return CrossValidatorIsolateMessage( + decodePredictor( + predictorType, + json[predictorPrototypeJsonKey] as Map, + ), + DataFrame.fromJson(json[trainDataJsonKey] as Map), + DataFrame.fromJson(json[testDataJsonKey] as Map), + predictorType, + MetricTypeJsonConverter() + .fromJson(json[metricTypeJsonKey] as String), + ); + } + + Map toJson() => { + predictorPrototypeJsonKey: predictorPrototype.toJson(), + trainDataJsonKey: trainData.toJson(), + testDataJsonKey: testData.toJson(), + predictorTypeJsonKey: encodePredictorType(predictorType), + metricTypeJsonKey: MetricTypeJsonConverter().toJson(metricType), + }; +} diff --git a/lib/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart b/lib/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart new file mode 100644 index 00000000..8cdb41f7 --- /dev/null +++ b/lib/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart @@ -0,0 +1,6 @@ +const predictorPrototypeJsonKey = 'P'; +const trainDataJsonKey = 'T'; +const testDataJsonKey = 'TE'; +const testRowsIndicesJsonKey = 'R'; +const predictorTypeJsonKey = 'PT'; +const metricTypeJsonKey = 'M'; diff --git a/lib/src/model_selection/cross_validator/predictor_type.dart b/lib/src/model_selection/cross_validator/predictor_type.dart new file mode 100644 index 00000000..fed3f116 --- /dev/null +++ b/lib/src/model_selection/cross_validator/predictor_type.dart @@ -0,0 +1,8 @@ +enum PredictorType { + logisticRegressor, + softmaxRegressor, + decisionTreeClassifier, + knnClassifier, + linearRegressor, + knnRegressor, +} diff --git a/lib/src/model_selection/cross_validator/predictor_type_encoded_values.dart b/lib/src/model_selection/cross_validator/predictor_type_encoded_values.dart new file mode 100644 index 00000000..4f58bc36 --- /dev/null +++ b/lib/src/model_selection/cross_validator/predictor_type_encoded_values.dart @@ -0,0 +1,6 @@ +const logisticRegressorPredictorTypeEncodedValue = 'LOGR'; +const softmaxRegressorPredictorTypeEncodedValue = 'SR'; +const decisionTreeClassifierPredictorTypeEncodedValue = 'DTC'; +const knnClassifierPredictorTypeEncodedValue = 'KNNC'; +const knnRegressorPredictorTypeEncodedValue = 'KNNR'; +const linearRegressorPredictorTypeEncodedValue = 'LINR'; diff --git a/lib/src/predictor/predictor.dart b/lib/src/predictor/predictor.dart index ac7cca88..9e5793ef 100644 --- a/lib/src/predictor/predictor.dart +++ b/lib/src/predictor/predictor.dart @@ -1,15 +1,26 @@ +import 'package:ml_algo/src/common/serializable/serializable.dart'; +import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; /// A common interface for all types of classifiers and regressors -abstract class Predictor { +abstract class Predictor implements Serializable { /// A collection of target column names of a dataset used to learn the ML /// model Iterable get targetNames; + /// A type for all the numeric values using by the [Predictor] + DType get dtype; + + /// Assesses model performance according to provided [metricType] + /// + /// Throws an exception if inappropriate [metricType] provided. + double assess(DataFrame observations, MetricType metricType); + /// Returns prediction, based on the model learned parameters DataFrame predict(DataFrame testFeatures); - /// A type for all the numeric values using by the [Predictor] - DType get dtype; + /// Re-runs the process on new training [data]. The features, model algorithm, + /// and hyperparameters remain the same. + Predictor retrain(DataFrame data); } diff --git a/lib/src/predictor/retrainable.dart b/lib/src/predictor/retrainable.dart deleted file mode 100644 index 25cebc2f..00000000 --- a/lib/src/predictor/retrainable.dart +++ /dev/null @@ -1,8 +0,0 @@ -import 'package:ml_algo/src/predictor/predictor.dart'; -import 'package:ml_dataframe/ml_dataframe.dart'; - -abstract class Retrainable { - /// Re-runs the process on new training [data]. The features, model algorithm, - /// and hyperparameters remain the same. - Predictor retrain(DataFrame data); -} diff --git a/lib/src/regressor/_mixins/assessable_regressor_mixin.dart b/lib/src/regressor/_mixins/assessable_regressor_mixin.dart index 8ea48296..fe60c71e 100644 --- a/lib/src/regressor/_mixins/assessable_regressor_mixin.dart +++ b/lib/src/regressor/_mixins/assessable_regressor_mixin.dart @@ -1,14 +1,10 @@ import 'package:ml_algo/src/di/injector.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; -mixin AssessableRegressorMixin implements - Assessable, - Predictor { - +mixin AssessableRegressorMixin implements Predictor { @override double assess( DataFrame samples, diff --git a/lib/src/regressor/knn_regressor/knn_regressor.dart b/lib/src/regressor/knn_regressor/knn_regressor.dart index 03bda804..85c2d3c3 100644 --- a/lib/src/regressor/knn_regressor/knn_regressor.dart +++ b/lib/src/regressor/knn_regressor/knn_regressor.dart @@ -1,8 +1,5 @@ -import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_algo/src/regressor/knn_regressor/_init_module.dart'; import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; @@ -22,12 +19,7 @@ import 'package:ml_linalg/dtype.dart'; /// To get a more precise result, one may use weighted average of found labels - /// the farther a found observation from the target one, the lower the weight of /// the observation is. To obtain these weights one may use a kernel function. -abstract class KnnRegressor - implements - Assessable, - Serializable, - Retrainable, - Predictor { +abstract class KnnRegressor implements Predictor { /// Parameters: /// /// [fittingData] Labelled observations, among which will be searched [k] diff --git a/lib/src/regressor/linear_regressor/linear_regressor.dart b/lib/src/regressor/linear_regressor/linear_regressor.dart index 183579f3..c7f03a2e 100644 --- a/lib/src/regressor/linear_regressor/linear_regressor.dart +++ b/lib/src/regressor/linear_regressor/linear_regressor.dart @@ -1,11 +1,8 @@ -import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_generator/learning_rate_type.dart'; import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart'; import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart'; import 'package:ml_algo/src/linear_optimizer/regularization_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_algo/src/regressor/linear_regressor/_init_module.dart'; import 'package:ml_algo/src/regressor/linear_regressor/linear_regressor_factory.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; @@ -22,12 +19,7 @@ import 'package:ml_linalg/vector.dart'; /// regressor should predict, and since all the `x` values are known (since they /// are the input for the algorithm), the regressor should find the best /// coefficients (weights) for each `x`-es to make a best prediction of `y` term. -abstract class LinearRegressor - implements - Assessable, - Serializable, - Retrainable, - Predictor { +abstract class LinearRegressor implements Predictor { /// Parameters: /// /// [fittingData] A [DataFrame] with observations that is used by the diff --git a/lib/src/service/worker_manager/worker_manager.dart b/lib/src/service/worker_manager/worker_manager.dart new file mode 100644 index 00000000..b3114017 --- /dev/null +++ b/lib/src/service/worker_manager/worker_manager.dart @@ -0,0 +1,6 @@ +import 'package:worker_manager/worker_manager.dart'; + +abstract class WorkerManager { + Future init(); + Executor get executor; +} diff --git a/lib/src/service/worker_manager/worker_manager_impl.dart b/lib/src/service/worker_manager/worker_manager_impl.dart new file mode 100644 index 00000000..9d1ea648 --- /dev/null +++ b/lib/src/service/worker_manager/worker_manager_impl.dart @@ -0,0 +1,24 @@ +import 'dart:async'; +import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; +import 'package:worker_manager/src/executor.dart'; + +class WorkerManagerImpl implements WorkerManager { + Completer _initCompleter; + + @override + Executor get executor => Executor(); + + @override + Future init() async { + final isWarmedUp = (await _initCompleter?.future) ?? false; + + if (isWarmedUp) { + return; + } + + _initCompleter ??= Completer(); + + await Executor().warmUp(); + _initCompleter.complete(true); + } +} diff --git a/pubspec.yaml b/pubspec.yaml index e89067a4..053ba114 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -14,6 +14,7 @@ dependencies: ml_linalg: ^12.17.8 ml_preprocessing: ^5.2.1 quiver: ^2.0.2 + worker_manager: ^3.1.8 xrange: ^0.0.8 dev_dependencies: diff --git a/test/mocks.dart b/test/mocks.dart index 5513691a..0575d073 100644 --- a/test/mocks.dart +++ b/test/mocks.dart @@ -9,6 +9,7 @@ import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart' import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor_factory.dart'; import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator.dart'; import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator_factory.dart'; +import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/cost_function/cost_function.dart'; import 'package:ml_algo/src/cost_function/cost_function_factory.dart'; import 'package:ml_algo/src/knn_kernel/kernel.dart'; @@ -28,7 +29,6 @@ import 'package:ml_algo/src/math/randomizer/randomizer.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/metric/metric.dart'; import 'package:ml_algo/src/metric/metric_factory.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/model_selection/model_assessor/classifier_assessor.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; @@ -37,6 +37,7 @@ import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart'; import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory.dart'; import 'package:ml_algo/src/regressor/linear_regressor/linear_regressor.dart'; import 'package:ml_algo/src/regressor/linear_regressor/linear_regressor_factory.dart'; +import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; import 'package:ml_algo/src/tree_trainer/decision_tree_trainer.dart'; import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector.dart'; import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector_factory.dart'; @@ -116,7 +117,8 @@ class DataSplitterMock extends Mock implements SplitIndicesProvider {} class DataSplitterFactoryMock extends Mock implements SplitIndicesProviderFactory {} -class AssessableMock extends Mock implements Assessable {} +class SerializablePredictorMock extends Mock + implements Serializable, Predictor {} class TreeSplitAssessorMock extends Mock implements TreeSplitAssessor {} @@ -205,6 +207,8 @@ class DecisionTreeClassifierMock extends Mock class LogisticRegressorFactoryMock extends Mock implements LogisticRegressorFactory {} +class WorkerManagerMock extends Mock implements WorkerManager {} + LearningRateGeneratorFactoryMock createLearningRateGeneratorFactoryMock( LearningRateGenerator generator) { final factory = LearningRateGeneratorFactoryMock(); diff --git a/test/model_selection/cross_validator/cross_validator_impl_test.dart b/test/model_selection/cross_validator/cross_validator_impl_test.dart index 73dc13cf..11a9b3d5 100644 --- a/test/model_selection/cross_validator/cross_validator_impl_test.dart +++ b/test/model_selection/cross_validator/cross_validator_impl_test.dart @@ -18,6 +18,8 @@ SplitIndicesProvider createSplitter(Iterable> indices) { void main() { group('CrossValidatorImpl', () { + final workerManagerMock = WorkerManagerMock(); + test('should evaluate performance of a predictor on given test ' 'splits', () async { final allObservations = DataFrame(>[ @@ -33,9 +35,13 @@ void main() { ], header: ['first', 'second', 'third', 'target'], headerExists: false); final metric = MetricType.mape; final splitter = createSplitter([[0,2,4],[6, 8]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(allObservations, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + allObservations, + workerManagerMock, + splitter, + DType.float32, + ); final score = 20.0; when(predictor.assess(any, any)).thenReturn(score); @@ -77,9 +83,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [0], [0]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(allObservations, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + allObservations, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -137,9 +147,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [0], [0]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(allObservations, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + allObservations, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -216,9 +230,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [2], [4]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(allObservations, - splitter, DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + allObservations, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -318,9 +336,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [2], [4]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(originalData, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + originalData, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -362,9 +384,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [2], [4]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(originalData, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + originalData, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -405,9 +431,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [2], [4]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(originalData, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + originalData, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -448,9 +478,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [2], [4]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(originalData, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + originalData, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1);