From 7e5409a63cb10ecb0d9abd596f55c51f54c4efea Mon Sep 17 00:00:00 2001 From: Ilya Gyrdymov Date: Mon, 14 Dec 2020 22:12:23 +0200 Subject: [PATCH 1/5] CrossValidator: isolates used --- lib/src/di/common/init_common_module.dart | 7 +- lib/src/model_selection/_init_module.dart | 12 +-- .../_helpers/assess_predictor.dart | 13 +++ .../from_serializable_predictor_Json.dart | 5 ++ .../cross_validator/_init_module.dart | 5 ++ .../cross_validator/_injector.dart | 3 + .../cross_validator/cross_validator.dart | 12 ++- .../cross_validator/cross_validator_impl.dart | 66 +++++++++++--- .../cross_validator_isolate_message.dart | 39 +++++++++ .../cross_validator_isolate_message.g.dart | 86 +++++++++++++++++++ ...s_validator_isolate_message_json_keys.dart | 5 ++ .../serializable_predictor.dart | 10 +++ .../worker_manager/worker_manager.dart | 6 ++ .../worker_manager/worker_manager_impl.dart | 19 ++++ pubspec.yaml | 1 + test/mocks.dart | 6 ++ .../cross_validator_impl_test.dart | 82 ++++++++++++------ 17 files changed, 331 insertions(+), 46 deletions(-) create mode 100644 lib/src/model_selection/cross_validator/_helpers/assess_predictor.dart create mode 100644 lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart create mode 100644 lib/src/model_selection/cross_validator/_init_module.dart create mode 100644 lib/src/model_selection/cross_validator/_injector.dart create mode 100644 lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart create mode 100644 lib/src/model_selection/cross_validator/cross_validator_isolate_message.g.dart create mode 100644 lib/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart create mode 100644 lib/src/model_selection/serializable_predictor.dart create mode 100644 lib/src/service/worker_manager/worker_manager.dart create mode 100644 lib/src/service/worker_manager/worker_manager_impl.dart diff --git a/lib/src/di/common/init_common_module.dart b/lib/src/di/common/init_common_module.dart index d99bccb1..fed8088b 100644 --- a/lib/src/di/common/init_common_module.dart +++ b/lib/src/di/common/init_common_module.dart @@ -24,6 +24,8 @@ import 'package:ml_algo/src/model_selection/model_assessor/classifier_assessor.d import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart'; import 'package:ml_algo/src/model_selection/model_assessor/regressor_assessor.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; +import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; +import 'package:ml_algo/src/service/worker_manager/worker_manager_impl.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_preprocessing/ml_preprocessing.dart'; @@ -76,5 +78,8 @@ void initCommonModule() { RegressorAssessor( injector.get(), featuresTargetSplit, - )); + )) + + ..registerSingletonIf( + () => WorkerManagerImpl()); } diff --git a/lib/src/model_selection/_init_module.dart b/lib/src/model_selection/_init_module.dart index b49f4589..54faea24 100644 --- a/lib/src/model_selection/_init_module.dart +++ b/lib/src/model_selection/_init_module.dart @@ -1,11 +1,13 @@ +import 'package:ml_algo/src/di/common/init_common_module.dart'; import 'package:ml_algo/src/model_selection/_injector.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory_impl.dart'; +import 'package:ml_algo/src/extensions/injector.dart'; void initModelSelectionModule() { - if (!modelSelectionInjector.exists()) { - modelSelectionInjector - ..registerSingleton( - () => const SplitIndicesProviderFactoryImpl()); - } + initCommonModule(); + + modelSelectionInjector + ..registerSingletonIf( + () => const SplitIndicesProviderFactoryImpl()); } diff --git a/lib/src/model_selection/cross_validator/_helpers/assess_predictor.dart b/lib/src/model_selection/cross_validator/_helpers/assess_predictor.dart new file mode 100644 index 00000000..71e9e33c --- /dev/null +++ b/lib/src/model_selection/cross_validator/_helpers/assess_predictor.dart @@ -0,0 +1,13 @@ +import 'package:ml_algo/src/model_selection/assessable.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message.dart'; + +num assessPredictor(Map encodedMessage) { + final message = CrossValidatorIsolateMessage + .fromJson(encodedMessage); + final predictor = message + .predictorPrototype + .retrain(message.trainData); + + return (predictor as Assessable) + .assess(message.testData, message.metricType); +} diff --git a/lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart b/lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart new file mode 100644 index 00000000..65ea2fab --- /dev/null +++ b/lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart @@ -0,0 +1,5 @@ +import 'package:ml_algo/src/model_selection/serializable_predictor.dart'; + +SerializablePredictor fromSerializablePredictorJson(Map json) { + return null; +} diff --git a/lib/src/model_selection/cross_validator/_init_module.dart b/lib/src/model_selection/cross_validator/_init_module.dart new file mode 100644 index 00000000..e7f6db03 --- /dev/null +++ b/lib/src/model_selection/cross_validator/_init_module.dart @@ -0,0 +1,5 @@ +import 'package:ml_algo/src/di/common/init_common_module.dart'; + +void initCrossValidatorModule() { + initCommonModule(); +} diff --git a/lib/src/model_selection/cross_validator/_injector.dart b/lib/src/model_selection/cross_validator/_injector.dart new file mode 100644 index 00000000..1c789f9a --- /dev/null +++ b/lib/src/model_selection/cross_validator/_injector.dart @@ -0,0 +1,3 @@ +import 'package:injector/injector.dart'; + +final crossValidatorInjector = Injector(); diff --git a/lib/src/model_selection/cross_validator/cross_validator.dart b/lib/src/model_selection/cross_validator/cross_validator.dart index 2e6692a0..642b9b27 100644 --- a/lib/src/model_selection/cross_validator/cross_validator.dart +++ b/lib/src/model_selection/cross_validator/cross_validator.dart @@ -1,15 +1,17 @@ +import 'package:ml_algo/src/di/injector.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/model_selection/_init_module.dart'; import 'package:ml_algo/src/model_selection/_injector.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart'; +import 'package:ml_algo/src/model_selection/serializable_predictor.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart'; +import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/linalg.dart'; -typedef PredictorFactory = Assessable Function(DataFrame observations); +typedef PredictorFactoryFn = SerializablePredictor Function(DataFrame observations); typedef DataPreprocessFn = List Function(DataFrame trainData, DataFrame testData); @@ -39,11 +41,13 @@ abstract class CrossValidator { final dataSplitterFactory = modelSelectionInjector .get(); + final workerManager = injector.get(); final dataSplitter = dataSplitterFactory .createByType(SplitIndicesProviderType.kFold, numberOfFolds: numberOfFolds); return CrossValidatorImpl( samples, + workerManager, dataSplitter, dtype, ); @@ -71,11 +75,13 @@ abstract class CrossValidator { final dataSplitterFactory = modelSelectionInjector .get(); + final workerManager = injector.get(); final dataSplitter = dataSplitterFactory .createByType(SplitIndicesProviderType.lpo, p: p); return CrossValidatorImpl( samples, + workerManager, dataSplitter, dtype, ); @@ -133,7 +139,7 @@ abstract class CrossValidator { /// print(averageScore); /// ```` Future evaluate( - PredictorFactory predictorFactory, + PredictorFactoryFn predictorFactory, MetricType metricType, { DataPreprocessFn onDataSplit, diff --git a/lib/src/model_selection/cross_validator/cross_validator_impl.dart b/lib/src/model_selection/cross_validator/cross_validator_impl.dart index 23cc83b4..9f7cbf43 100644 --- a/lib/src/model_selection/cross_validator/cross_validator_impl.dart +++ b/lib/src/model_selection/cross_validator/cross_validator_impl.dart @@ -1,8 +1,11 @@ import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_exception.dart'; import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/_helpers/assess_predictor.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; +import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/matrix.dart'; @@ -12,40 +15,47 @@ import 'package:quiver/iterables.dart'; class CrossValidatorImpl implements CrossValidator { CrossValidatorImpl( this.samples, + this._workerManager, this._splitter, this.dtype, ); final DataFrame samples; - final DType dtype; + final WorkerManager _workerManager; final SplitIndicesProvider _splitter; + final DType dtype; @override Future evaluate( - PredictorFactory predictorFactory, + PredictorFactoryFn predictorFactory, MetricType metricType, { DataPreprocessFn onDataSplit, } - ) { + ) async { final samplesAsMatrix = samples.toMatrix(dtype); final sourceColumnsNum = samplesAsMatrix.columnsNum; final discreteColumns = enumerate(samples.series) .where((indexedSeries) => indexedSeries.value.isDiscrete) .map((indexedSeries) => indexedSeries.index); - final allIndicesGroups = _splitter.getIndices(samplesAsMatrix.rowsNum); - final scores = allIndicesGroups + final allIndicesGroups = _splitter + .getIndices(samplesAsMatrix.rowsNum); + final scoreFutures = allIndicesGroups .map((testRowsIndices) { final split = _makeSplit(testRowsIndices, discreteColumns); final trainDataFrame = split[0]; final testDataFrame = split[1]; - final splits = onDataSplit != null + final transformedData = onDataSplit != null ? onDataSplit(trainDataFrame, testDataFrame) : [trainDataFrame, testDataFrame]; - final transformedTrainData = splits[0]; - final transformedTestData = splits[1]; - final transformedTrainDataColumnsNum = transformedTrainData.header.length; - final transformedTestDataColumnsNum = transformedTestData.header.length; + final transformedTrainData = transformedData[0]; + final transformedTestData = transformedData[1]; + final transformedTrainDataColumnsNum = transformedTrainData + .header + .length; + final transformedTestDataColumnsNum = transformedTestData + .header + .length; if (transformedTrainDataColumnsNum != sourceColumnsNum) { throw InvalidTrainDataColumnsNumberException(sourceColumnsNum, @@ -57,12 +67,18 @@ class CrossValidatorImpl implements CrossValidator { transformedTestDataColumnsNum); } - return predictorFactory(transformedTrainData) - .assess(transformedTestData, metricType); + return _assessPredictor( + predictorFactory, + transformedTrainData, + transformedTestData, + metricType, + ); }) .toList(); + final scores = await Future.wait(scoreFutures); - return Future.value(Vector.fromList(scores, dtype: dtype)); + return Future + .value(Vector.fromList(scores, dtype: dtype)); } List _makeSplit(Iterable testRowsIndices, @@ -97,4 +113,28 @@ class CrossValidatorImpl implements CrossValidator { ), ]; } + + Future _assessPredictor( + PredictorFactoryFn predictorFactoryFn, + DataFrame trainData, + DataFrame testData, + MetricType metricType, + ) async { + await _workerManager.init(); + + final samplesPrototype = samples.sampleFromRows([0]); + final predictorPrototype = predictorFactoryFn(samplesPrototype); + + return _workerManager + .executor + .execute( + arg1: CrossValidatorIsolateMessage( + predictorPrototype, + trainData, + testData, + metricType, + ).toJson(), + fun1: assessPredictor, + ); + } } diff --git a/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart b/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart new file mode 100644 index 00000000..8299c163 --- /dev/null +++ b/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart @@ -0,0 +1,39 @@ +import 'package:json_annotation/json_annotation.dart'; +import 'package:ml_algo/src/metric/metric_type.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart'; +import 'package:ml_algo/src/model_selection/serializable_predictor.dart'; +import 'package:ml_dataframe/ml_dataframe.dart'; + +part 'cross_validator_isolate_message.g.dart'; + +@JsonSerializable() +class CrossValidatorIsolateMessage { + @JsonKey( + name: predictorPrototypeJsonKey, + fromJson: fromSerializablePredictorJson, + ) + final SerializablePredictor predictorPrototype; + + @JsonKey(name: trainDataJsonKey) + final DataFrame trainData; + + @JsonKey(name: testDataJsonKey) + final DataFrame testData; + + @JsonKey(name: metricTypeJsonKey) + final MetricType metricType; + + CrossValidatorIsolateMessage( + this.predictorPrototype, + this.trainData, + this.testData, + this.metricType, + ); + + factory CrossValidatorIsolateMessage.fromJson(Map json) => + _$CrossValidatorIsolateMessageFromJson(json); + + Map toJson() => + _$CrossValidatorIsolateMessageToJson(this); +} diff --git a/lib/src/model_selection/cross_validator/cross_validator_isolate_message.g.dart b/lib/src/model_selection/cross_validator/cross_validator_isolate_message.g.dart new file mode 100644 index 00000000..6d6d423f --- /dev/null +++ b/lib/src/model_selection/cross_validator/cross_validator_isolate_message.g.dart @@ -0,0 +1,86 @@ +// GENERATED CODE - DO NOT MODIFY BY HAND + +part of 'cross_validator_isolate_message.dart'; + +// ************************************************************************** +// JsonSerializableGenerator +// ************************************************************************** + +CrossValidatorIsolateMessage _$CrossValidatorIsolateMessageFromJson( + Map json) { + return $checkedNew('CrossValidatorIsolateMessage', json, () { + $checkKeys(json, allowedKeys: const ['P', 'T', 'TE', 'M']); + final val = CrossValidatorIsolateMessage( + $checkedConvert(json, 'P', + (v) => fromSerializablePredictorJson(v as Map)), + $checkedConvert( + json, + 'T', + (v) => + v == null ? null : DataFrame.fromJson(v as Map)), + $checkedConvert( + json, + 'TE', + (v) => + v == null ? null : DataFrame.fromJson(v as Map)), + $checkedConvert( + json, 'M', (v) => _$enumDecodeNullable(_$MetricTypeEnumMap, v)), + ); + return val; + }, fieldKeyMap: const { + 'predictorPrototype': 'P', + 'trainData': 'T', + 'testData': 'TE', + 'metricType': 'M' + }); +} + +Map _$CrossValidatorIsolateMessageToJson( + CrossValidatorIsolateMessage instance) => + { + 'P': instance.predictorPrototype, + 'T': instance.trainData, + 'TE': instance.testData, + 'M': _$MetricTypeEnumMap[instance.metricType], + }; + +T _$enumDecode( + Map enumValues, + dynamic source, { + T unknownValue, +}) { + if (source == null) { + throw ArgumentError('A value must be provided. Supported values: ' + '${enumValues.values.join(', ')}'); + } + + final value = enumValues.entries + .singleWhere((e) => e.value == source, orElse: () => null) + ?.key; + + if (value == null && unknownValue == null) { + throw ArgumentError('`$source` is not one of the supported values: ' + '${enumValues.values.join(', ')}'); + } + return value ?? unknownValue; +} + +T _$enumDecodeNullable( + Map enumValues, + dynamic source, { + T unknownValue, +}) { + if (source == null) { + return null; + } + return _$enumDecode(enumValues, source, unknownValue: unknownValue); +} + +const _$MetricTypeEnumMap = { + MetricType.mape: 'mape', + MetricType.rmse: 'rmse', + MetricType.rss: 'rss', + MetricType.accuracy: 'accuracy', + MetricType.precision: 'precision', + MetricType.recall: 'recall', +}; diff --git a/lib/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart b/lib/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart new file mode 100644 index 00000000..878a3e53 --- /dev/null +++ b/lib/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart @@ -0,0 +1,5 @@ +const predictorPrototypeJsonKey = 'P'; +const trainDataJsonKey = 'T'; +const testDataJsonKey = 'TE'; +const testRowsIndicesJsonKey = 'R'; +const metricTypeJsonKey = 'M'; diff --git a/lib/src/model_selection/serializable_predictor.dart b/lib/src/model_selection/serializable_predictor.dart new file mode 100644 index 00000000..ed96d021 --- /dev/null +++ b/lib/src/model_selection/serializable_predictor.dart @@ -0,0 +1,10 @@ +import 'package:ml_algo/src/common/serializable/serializable.dart'; +import 'package:ml_algo/src/model_selection/assessable.dart'; +import 'package:ml_algo/src/predictor/predictor.dart'; +import 'package:ml_algo/src/predictor/retrainable.dart'; + +abstract class SerializablePredictor implements + Assessable, + Serializable, + Retrainable, + Predictor {} diff --git a/lib/src/service/worker_manager/worker_manager.dart b/lib/src/service/worker_manager/worker_manager.dart new file mode 100644 index 00000000..b3114017 --- /dev/null +++ b/lib/src/service/worker_manager/worker_manager.dart @@ -0,0 +1,6 @@ +import 'package:worker_manager/worker_manager.dart'; + +abstract class WorkerManager { + Future init(); + Executor get executor; +} diff --git a/lib/src/service/worker_manager/worker_manager_impl.dart b/lib/src/service/worker_manager/worker_manager_impl.dart new file mode 100644 index 00000000..7527c4ae --- /dev/null +++ b/lib/src/service/worker_manager/worker_manager_impl.dart @@ -0,0 +1,19 @@ +import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; +import 'package:worker_manager/src/executor.dart'; + +class WorkerManagerImpl implements WorkerManager { + bool _isInitialized = false; + + @override + Executor get executor => Executor(); + + @override + Future init() async { + if (_isInitialized) { + return; + } + + await Executor().warmUp(); + _isInitialized = true; + } +} diff --git a/pubspec.yaml b/pubspec.yaml index e89067a4..053ba114 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -14,6 +14,7 @@ dependencies: ml_linalg: ^12.17.8 ml_preprocessing: ^5.2.1 quiver: ^2.0.2 + worker_manager: ^3.1.8 xrange: ^0.0.8 dev_dependencies: diff --git a/test/mocks.dart b/test/mocks.dart index 5513691a..faed8a6e 100644 --- a/test/mocks.dart +++ b/test/mocks.dart @@ -30,6 +30,7 @@ import 'package:ml_algo/src/metric/metric.dart'; import 'package:ml_algo/src/metric/metric_factory.dart'; import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/model_selection/model_assessor/classifier_assessor.dart'; +import 'package:ml_algo/src/model_selection/serializable_predictor.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; @@ -37,6 +38,7 @@ import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart'; import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory.dart'; import 'package:ml_algo/src/regressor/linear_regressor/linear_regressor.dart'; import 'package:ml_algo/src/regressor/linear_regressor/linear_regressor_factory.dart'; +import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; import 'package:ml_algo/src/tree_trainer/decision_tree_trainer.dart'; import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector.dart'; import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector_factory.dart'; @@ -118,6 +120,8 @@ class DataSplitterFactoryMock extends Mock implements SplitIndicesProviderFactor class AssessableMock extends Mock implements Assessable {} +class SerializablePredictorMock extends Mock implements SerializablePredictor {} + class TreeSplitAssessorMock extends Mock implements TreeSplitAssessor {} class TreeSplitAssessorFactoryMock extends Mock implements @@ -205,6 +209,8 @@ class DecisionTreeClassifierMock extends Mock class LogisticRegressorFactoryMock extends Mock implements LogisticRegressorFactory {} +class WorkerManagerMock extends Mock implements WorkerManager {} + LearningRateGeneratorFactoryMock createLearningRateGeneratorFactoryMock( LearningRateGenerator generator) { final factory = LearningRateGeneratorFactoryMock(); diff --git a/test/model_selection/cross_validator/cross_validator_impl_test.dart b/test/model_selection/cross_validator/cross_validator_impl_test.dart index 73dc13cf..11a9b3d5 100644 --- a/test/model_selection/cross_validator/cross_validator_impl_test.dart +++ b/test/model_selection/cross_validator/cross_validator_impl_test.dart @@ -18,6 +18,8 @@ SplitIndicesProvider createSplitter(Iterable> indices) { void main() { group('CrossValidatorImpl', () { + final workerManagerMock = WorkerManagerMock(); + test('should evaluate performance of a predictor on given test ' 'splits', () async { final allObservations = DataFrame(>[ @@ -33,9 +35,13 @@ void main() { ], header: ['first', 'second', 'third', 'target'], headerExists: false); final metric = MetricType.mape; final splitter = createSplitter([[0,2,4],[6, 8]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(allObservations, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + allObservations, + workerManagerMock, + splitter, + DType.float32, + ); final score = 20.0; when(predictor.assess(any, any)).thenReturn(score); @@ -77,9 +83,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [0], [0]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(allObservations, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + allObservations, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -137,9 +147,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [0], [0]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(allObservations, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + allObservations, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -216,9 +230,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [2], [4]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(allObservations, - splitter, DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + allObservations, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -318,9 +336,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [2], [4]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(originalData, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + originalData, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -362,9 +384,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [2], [4]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(originalData, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + originalData, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -405,9 +431,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [2], [4]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(originalData, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + originalData, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); @@ -448,9 +478,13 @@ void main() { final metric = MetricType.mape; final splitter = createSplitter([[0], [2], [4]]); - final predictor = AssessableMock(); - final validator = CrossValidatorImpl(originalData, splitter, - DType.float32); + final predictor = SerializablePredictorMock(); + final validator = CrossValidatorImpl( + originalData, + workerManagerMock, + splitter, + DType.float32, + ); when(predictor.assess(any, any)).thenReturn(1); From f82cffe49144cab4cbaceb63de4074f63488f9d0 Mon Sep 17 00:00:00 2001 From: Ilya Gyrdymov Date: Sat, 26 Dec 2020 20:21:22 +0200 Subject: [PATCH 2/5] CrossValidator: isolates used --- .../_mixins/assessable_classifier_mixin.dart | 3 +-- .../decision_tree_classifier.dart | 9 +-------- .../classifier/knn_classifier/knn_classifier.dart | 9 +-------- .../logistic_regressor/logistic_regressor.dart | 9 +-------- .../softmax_regressor/softmax_regressor.dart | 9 +-------- lib/src/model_selection/assessable.dart | 10 ---------- .../cross_validator/_helpers/assess_predictor.dart | 3 +-- .../_helpers/from_serializable_predictor_Json.dart | 4 ++-- .../cross_validator/cross_validator.dart | 4 ++-- .../cross_validator_isolate_message.dart | 4 ++-- .../model_selection/serializable_predictor.dart | 10 ---------- lib/src/predictor/predictor.dart | 14 ++++++++++++-- lib/src/predictor/retrainable.dart | 8 -------- .../_mixins/assessable_regressor_mixin.dart | 6 +----- lib/src/regressor/knn_regressor/knn_regressor.dart | 9 +-------- .../linear_regressor/linear_regressor.dart | 9 +-------- test/mocks.dart | 8 +++----- 17 files changed, 30 insertions(+), 98 deletions(-) delete mode 100644 lib/src/model_selection/assessable.dart delete mode 100644 lib/src/model_selection/serializable_predictor.dart delete mode 100644 lib/src/predictor/retrainable.dart diff --git a/lib/src/classifier/_mixins/assessable_classifier_mixin.dart b/lib/src/classifier/_mixins/assessable_classifier_mixin.dart index 9978fa59..47ff3cdb 100644 --- a/lib/src/classifier/_mixins/assessable_classifier_mixin.dart +++ b/lib/src/classifier/_mixins/assessable_classifier_mixin.dart @@ -1,11 +1,10 @@ import 'package:ml_algo/src/classifier/classifier.dart'; import 'package:ml_algo/src/di/injector.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; -mixin AssessableClassifierMixin implements Assessable, Classifier { +mixin AssessableClassifierMixin implements Classifier { @override double assess( DataFrame samples, diff --git a/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart b/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart index 2754d2ec..228c07e3 100644 --- a/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart +++ b/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart @@ -2,8 +2,6 @@ import 'package:ml_algo/src/classifier/classifier.dart'; import 'package:ml_algo/src/classifier/decision_tree_classifier/_init_module.dart'; import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart'; import 'package:ml_algo/src/common/serializable/serializable.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; @@ -16,12 +14,7 @@ import 'package:ml_linalg/dtype.dart'; /// decision tree learning. Once a decision tree learned, it may be used to /// classify new samples with the same features that were used to learn the /// tree. -abstract class DecisionTreeClassifier - implements - Assessable, - Serializable, - Retrainable, - Classifier { +abstract class DecisionTreeClassifier implements Serializable, Classifier { /// Parameters: /// diff --git a/lib/src/classifier/knn_classifier/knn_classifier.dart b/lib/src/classifier/knn_classifier/knn_classifier.dart index cd8c5e3d..c39c5fb6 100644 --- a/lib/src/classifier/knn_classifier/knn_classifier.dart +++ b/lib/src/classifier/knn_classifier/knn_classifier.dart @@ -3,8 +3,6 @@ import 'package:ml_algo/src/classifier/knn_classifier/_init_module.dart'; import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart'; import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/distance.dart'; import 'package:ml_linalg/dtype.dart'; @@ -21,12 +19,7 @@ import 'package:ml_linalg/dtype.dart'; /// imprecise result. Thus the weighted version of KNN algorithm is used in the /// classifier. To get weight of a particular observation one may use a kernel /// function. -abstract class KnnClassifier - implements - Assessable, - Serializable, - Retrainable, - Classifier { +abstract class KnnClassifier implements Serializable, Classifier { /// Parameters: /// /// [trainData] Labelled observations. Must contain [targetName] column. diff --git a/lib/src/classifier/logistic_regressor/logistic_regressor.dart b/lib/src/classifier/logistic_regressor/logistic_regressor.dart index 25d2f321..808e0593 100644 --- a/lib/src/classifier/logistic_regressor/logistic_regressor.dart +++ b/lib/src/classifier/logistic_regressor/logistic_regressor.dart @@ -6,8 +6,6 @@ import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_ge import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart'; import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart'; import 'package:ml_algo/src/linear_optimizer/regularization_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/vector.dart'; @@ -19,12 +17,7 @@ import 'package:ml_linalg/vector.dart'; /// In other words, the regressor iteratively tries to select coefficients /// that makes combination of passed features and these coefficients most /// likely. -abstract class LogisticRegressor - implements - Assessable, - Serializable, - Retrainable, - LinearClassifier { +abstract class LogisticRegressor implements Serializable, LinearClassifier { /// Parameters: /// diff --git a/lib/src/classifier/softmax_regressor/softmax_regressor.dart b/lib/src/classifier/softmax_regressor/softmax_regressor.dart index 193f3839..3441b5f1 100644 --- a/lib/src/classifier/softmax_regressor/softmax_regressor.dart +++ b/lib/src/classifier/softmax_regressor/softmax_regressor.dart @@ -6,8 +6,6 @@ import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_ge import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart'; import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart'; import 'package:ml_algo/src/linear_optimizer/regularization_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/matrix.dart'; @@ -26,12 +24,7 @@ import 'package:ml_linalg/matrix.dart'; /// /// Also it is worth to mention that the algorithm is a generalization of /// [Logistic regression](https://en.wikipedia.org/wiki/Logistic_regression)) -abstract class SoftmaxRegressor - implements - Assessable, - Serializable, - Retrainable, - LinearClassifier { +abstract class SoftmaxRegressor implements Serializable, LinearClassifier { /// Parameters: /// diff --git a/lib/src/model_selection/assessable.dart b/lib/src/model_selection/assessable.dart deleted file mode 100644 index 8811f417..00000000 --- a/lib/src/model_selection/assessable.dart +++ /dev/null @@ -1,10 +0,0 @@ -import 'package:ml_algo/src/metric/metric_type.dart'; -import 'package:ml_dataframe/ml_dataframe.dart'; - -/// An interface for a ML model's performance assessment -abstract class Assessable { - /// Assesses model performance according to provided [metricType] - /// - /// Throws an exception if inappropriate [metricType] provided. - double assess(DataFrame observations, MetricType metricType); -} diff --git a/lib/src/model_selection/cross_validator/_helpers/assess_predictor.dart b/lib/src/model_selection/cross_validator/_helpers/assess_predictor.dart index 71e9e33c..e84421a1 100644 --- a/lib/src/model_selection/cross_validator/_helpers/assess_predictor.dart +++ b/lib/src/model_selection/cross_validator/_helpers/assess_predictor.dart @@ -1,4 +1,3 @@ -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message.dart'; num assessPredictor(Map encodedMessage) { @@ -8,6 +7,6 @@ num assessPredictor(Map encodedMessage) { .predictorPrototype .retrain(message.trainData); - return (predictor as Assessable) + return predictor .assess(message.testData, message.metricType); } diff --git a/lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart b/lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart index 65ea2fab..69457146 100644 --- a/lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart +++ b/lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart @@ -1,5 +1,5 @@ -import 'package:ml_algo/src/model_selection/serializable_predictor.dart'; +import 'package:ml_algo/src/predictor/predictor.dart'; -SerializablePredictor fromSerializablePredictorJson(Map json) { +Predictor fromSerializablePredictorJson(Map json) { return null; } diff --git a/lib/src/model_selection/cross_validator/cross_validator.dart b/lib/src/model_selection/cross_validator/cross_validator.dart index 642b9b27..b8e7113a 100644 --- a/lib/src/model_selection/cross_validator/cross_validator.dart +++ b/lib/src/model_selection/cross_validator/cross_validator.dart @@ -3,15 +3,15 @@ import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/model_selection/_init_module.dart'; import 'package:ml_algo/src/model_selection/_injector.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart'; -import 'package:ml_algo/src/model_selection/serializable_predictor.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart'; +import 'package:ml_algo/src/predictor/predictor.dart'; import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; import 'package:ml_linalg/linalg.dart'; -typedef PredictorFactoryFn = SerializablePredictor Function(DataFrame observations); +typedef PredictorFactoryFn = Predictor Function(DataFrame observations); typedef DataPreprocessFn = List Function(DataFrame trainData, DataFrame testData); diff --git a/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart b/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart index 8299c163..f08911a5 100644 --- a/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart +++ b/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart @@ -2,7 +2,7 @@ import 'package:json_annotation/json_annotation.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart'; -import 'package:ml_algo/src/model_selection/serializable_predictor.dart'; +import 'package:ml_algo/src/predictor/predictor.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; part 'cross_validator_isolate_message.g.dart'; @@ -13,7 +13,7 @@ class CrossValidatorIsolateMessage { name: predictorPrototypeJsonKey, fromJson: fromSerializablePredictorJson, ) - final SerializablePredictor predictorPrototype; + final Predictor predictorPrototype; @JsonKey(name: trainDataJsonKey) final DataFrame trainData; diff --git a/lib/src/model_selection/serializable_predictor.dart b/lib/src/model_selection/serializable_predictor.dart deleted file mode 100644 index ed96d021..00000000 --- a/lib/src/model_selection/serializable_predictor.dart +++ /dev/null @@ -1,10 +0,0 @@ -import 'package:ml_algo/src/common/serializable/serializable.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; -import 'package:ml_algo/src/predictor/predictor.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; - -abstract class SerializablePredictor implements - Assessable, - Serializable, - Retrainable, - Predictor {} diff --git a/lib/src/predictor/predictor.dart b/lib/src/predictor/predictor.dart index ac7cca88..df157eee 100644 --- a/lib/src/predictor/predictor.dart +++ b/lib/src/predictor/predictor.dart @@ -1,3 +1,4 @@ +import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; @@ -7,9 +8,18 @@ abstract class Predictor { /// model Iterable get targetNames; + /// A type for all the numeric values using by the [Predictor] + DType get dtype; + + /// Assesses model performance according to provided [metricType] + /// + /// Throws an exception if inappropriate [metricType] provided. + double assess(DataFrame observations, MetricType metricType); + /// Returns prediction, based on the model learned parameters DataFrame predict(DataFrame testFeatures); - /// A type for all the numeric values using by the [Predictor] - DType get dtype; + /// Re-runs the process on new training [data]. The features, model algorithm, + /// and hyperparameters remain the same. + Predictor retrain(DataFrame data); } diff --git a/lib/src/predictor/retrainable.dart b/lib/src/predictor/retrainable.dart deleted file mode 100644 index 25cebc2f..00000000 --- a/lib/src/predictor/retrainable.dart +++ /dev/null @@ -1,8 +0,0 @@ -import 'package:ml_algo/src/predictor/predictor.dart'; -import 'package:ml_dataframe/ml_dataframe.dart'; - -abstract class Retrainable { - /// Re-runs the process on new training [data]. The features, model algorithm, - /// and hyperparameters remain the same. - Predictor retrain(DataFrame data); -} diff --git a/lib/src/regressor/_mixins/assessable_regressor_mixin.dart b/lib/src/regressor/_mixins/assessable_regressor_mixin.dart index 8ea48296..fe60c71e 100644 --- a/lib/src/regressor/_mixins/assessable_regressor_mixin.dart +++ b/lib/src/regressor/_mixins/assessable_regressor_mixin.dart @@ -1,14 +1,10 @@ import 'package:ml_algo/src/di/injector.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; -mixin AssessableRegressorMixin implements - Assessable, - Predictor { - +mixin AssessableRegressorMixin implements Predictor { @override double assess( DataFrame samples, diff --git a/lib/src/regressor/knn_regressor/knn_regressor.dart b/lib/src/regressor/knn_regressor/knn_regressor.dart index 03bda804..ed1e2292 100644 --- a/lib/src/regressor/knn_regressor/knn_regressor.dart +++ b/lib/src/regressor/knn_regressor/knn_regressor.dart @@ -1,8 +1,6 @@ import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_algo/src/regressor/knn_regressor/_init_module.dart'; import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor_factory.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; @@ -22,12 +20,7 @@ import 'package:ml_linalg/dtype.dart'; /// To get a more precise result, one may use weighted average of found labels - /// the farther a found observation from the target one, the lower the weight of /// the observation is. To obtain these weights one may use a kernel function. -abstract class KnnRegressor - implements - Assessable, - Serializable, - Retrainable, - Predictor { +abstract class KnnRegressor implements Serializable, Predictor { /// Parameters: /// /// [fittingData] Labelled observations, among which will be searched [k] diff --git a/lib/src/regressor/linear_regressor/linear_regressor.dart b/lib/src/regressor/linear_regressor/linear_regressor.dart index 183579f3..2c68720d 100644 --- a/lib/src/regressor/linear_regressor/linear_regressor.dart +++ b/lib/src/regressor/linear_regressor/linear_regressor.dart @@ -3,9 +3,7 @@ import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_ge import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart'; import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart'; import 'package:ml_algo/src/linear_optimizer/regularization_type.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; -import 'package:ml_algo/src/predictor/retrainable.dart'; import 'package:ml_algo/src/regressor/linear_regressor/_init_module.dart'; import 'package:ml_algo/src/regressor/linear_regressor/linear_regressor_factory.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; @@ -22,12 +20,7 @@ import 'package:ml_linalg/vector.dart'; /// regressor should predict, and since all the `x` values are known (since they /// are the input for the algorithm), the regressor should find the best /// coefficients (weights) for each `x`-es to make a best prediction of `y` term. -abstract class LinearRegressor - implements - Assessable, - Serializable, - Retrainable, - Predictor { +abstract class LinearRegressor implements Serializable, Predictor { /// Parameters: /// /// [fittingData] A [DataFrame] with observations that is used by the diff --git a/test/mocks.dart b/test/mocks.dart index faed8a6e..0575d073 100644 --- a/test/mocks.dart +++ b/test/mocks.dart @@ -9,6 +9,7 @@ import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart' import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor_factory.dart'; import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator.dart'; import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator_factory.dart'; +import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/cost_function/cost_function.dart'; import 'package:ml_algo/src/cost_function/cost_function_factory.dart'; import 'package:ml_algo/src/knn_kernel/kernel.dart'; @@ -28,9 +29,7 @@ import 'package:ml_algo/src/math/randomizer/randomizer.dart'; import 'package:ml_algo/src/math/randomizer/randomizer_factory.dart'; import 'package:ml_algo/src/metric/metric.dart'; import 'package:ml_algo/src/metric/metric_factory.dart'; -import 'package:ml_algo/src/model_selection/assessable.dart'; import 'package:ml_algo/src/model_selection/model_assessor/classifier_assessor.dart'; -import 'package:ml_algo/src/model_selection/serializable_predictor.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; @@ -118,9 +117,8 @@ class DataSplitterMock extends Mock implements SplitIndicesProvider {} class DataSplitterFactoryMock extends Mock implements SplitIndicesProviderFactory {} -class AssessableMock extends Mock implements Assessable {} - -class SerializablePredictorMock extends Mock implements SerializablePredictor {} +class SerializablePredictorMock extends Mock + implements Serializable, Predictor {} class TreeSplitAssessorMock extends Mock implements TreeSplitAssessor {} From c4871a1af0d5e78fa307b86ac7e00d2de9fa8065 Mon Sep 17 00:00:00 2001 From: Ilya Gyrdymov Date: Sat, 26 Dec 2020 20:39:07 +0200 Subject: [PATCH 3/5] CrossValidator: isolates used --- lib/src/classifier/classifier.dart | 2 +- .../decision_tree_classifier/decision_tree_classifier.dart | 3 +-- lib/src/classifier/knn_classifier/knn_classifier.dart | 3 +-- lib/src/classifier/logistic_regressor/logistic_regressor.dart | 3 +-- lib/src/classifier/softmax_regressor/softmax_regressor.dart | 3 +-- lib/src/predictor/predictor.dart | 3 ++- lib/src/regressor/knn_regressor/knn_regressor.dart | 3 +-- lib/src/regressor/linear_regressor/linear_regressor.dart | 3 +-- 8 files changed, 9 insertions(+), 14 deletions(-) diff --git a/lib/src/classifier/classifier.dart b/lib/src/classifier/classifier.dart index 6b5a87b5..59f95d3f 100644 --- a/lib/src/classifier/classifier.dart +++ b/lib/src/classifier/classifier.dart @@ -3,7 +3,7 @@ import 'package:ml_dataframe/ml_dataframe.dart'; /// An interface for any classifier (linear, non-linear, parametric, /// non-parametric, etc.) -abstract class Classifier extends Predictor { +abstract class Classifier implements Predictor { /// Returns predicted distribution of probabilities for each observation in /// the passed [testFeatures] DataFrame predictProbabilities(DataFrame testFeatures); diff --git a/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart b/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart index 228c07e3..d34f9166 100644 --- a/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart +++ b/lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart @@ -1,7 +1,6 @@ import 'package:ml_algo/src/classifier/classifier.dart'; import 'package:ml_algo/src/classifier/decision_tree_classifier/_init_module.dart'; import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart'; -import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; @@ -14,7 +13,7 @@ import 'package:ml_linalg/dtype.dart'; /// decision tree learning. Once a decision tree learned, it may be used to /// classify new samples with the same features that were used to learn the /// tree. -abstract class DecisionTreeClassifier implements Serializable, Classifier { +abstract class DecisionTreeClassifier implements Classifier { /// Parameters: /// diff --git a/lib/src/classifier/knn_classifier/knn_classifier.dart b/lib/src/classifier/knn_classifier/knn_classifier.dart index c39c5fb6..7e3305d3 100644 --- a/lib/src/classifier/knn_classifier/knn_classifier.dart +++ b/lib/src/classifier/knn_classifier/knn_classifier.dart @@ -1,7 +1,6 @@ import 'package:ml_algo/src/classifier/classifier.dart'; import 'package:ml_algo/src/classifier/knn_classifier/_init_module.dart'; import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart'; -import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/distance.dart'; @@ -19,7 +18,7 @@ import 'package:ml_linalg/dtype.dart'; /// imprecise result. Thus the weighted version of KNN algorithm is used in the /// classifier. To get weight of a particular observation one may use a kernel /// function. -abstract class KnnClassifier implements Serializable, Classifier { +abstract class KnnClassifier implements Classifier { /// Parameters: /// /// [trainData] Labelled observations. Must contain [targetName] column. diff --git a/lib/src/classifier/logistic_regressor/logistic_regressor.dart b/lib/src/classifier/logistic_regressor/logistic_regressor.dart index 808e0593..824170ab 100644 --- a/lib/src/classifier/logistic_regressor/logistic_regressor.dart +++ b/lib/src/classifier/logistic_regressor/logistic_regressor.dart @@ -1,7 +1,6 @@ import 'package:ml_algo/src/classifier/linear_classifier.dart'; import 'package:ml_algo/src/classifier/logistic_regressor/_init_module.dart'; import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor_factory.dart'; -import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_generator/learning_rate_type.dart'; import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart'; import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart'; @@ -17,7 +16,7 @@ import 'package:ml_linalg/vector.dart'; /// In other words, the regressor iteratively tries to select coefficients /// that makes combination of passed features and these coefficients most /// likely. -abstract class LogisticRegressor implements Serializable, LinearClassifier { +abstract class LogisticRegressor implements LinearClassifier { /// Parameters: /// diff --git a/lib/src/classifier/softmax_regressor/softmax_regressor.dart b/lib/src/classifier/softmax_regressor/softmax_regressor.dart index 3441b5f1..b8b181f2 100644 --- a/lib/src/classifier/softmax_regressor/softmax_regressor.dart +++ b/lib/src/classifier/softmax_regressor/softmax_regressor.dart @@ -1,7 +1,6 @@ import 'package:ml_algo/src/classifier/linear_classifier.dart'; import 'package:ml_algo/src/classifier/softmax_regressor/_init_module.dart'; import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor_factory.dart'; -import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_generator/learning_rate_type.dart'; import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart'; import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart'; @@ -24,7 +23,7 @@ import 'package:ml_linalg/matrix.dart'; /// /// Also it is worth to mention that the algorithm is a generalization of /// [Logistic regression](https://en.wikipedia.org/wiki/Logistic_regression)) -abstract class SoftmaxRegressor implements Serializable, LinearClassifier { +abstract class SoftmaxRegressor implements LinearClassifier { /// Parameters: /// diff --git a/lib/src/predictor/predictor.dart b/lib/src/predictor/predictor.dart index df157eee..9e5793ef 100644 --- a/lib/src/predictor/predictor.dart +++ b/lib/src/predictor/predictor.dart @@ -1,9 +1,10 @@ +import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; import 'package:ml_linalg/dtype.dart'; /// A common interface for all types of classifiers and regressors -abstract class Predictor { +abstract class Predictor implements Serializable { /// A collection of target column names of a dataset used to learn the ML /// model Iterable get targetNames; diff --git a/lib/src/regressor/knn_regressor/knn_regressor.dart b/lib/src/regressor/knn_regressor/knn_regressor.dart index ed1e2292..85c2d3c3 100644 --- a/lib/src/regressor/knn_regressor/knn_regressor.dart +++ b/lib/src/regressor/knn_regressor/knn_regressor.dart @@ -1,4 +1,3 @@ -import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/knn_kernel/kernel_type.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; import 'package:ml_algo/src/regressor/knn_regressor/_init_module.dart'; @@ -20,7 +19,7 @@ import 'package:ml_linalg/dtype.dart'; /// To get a more precise result, one may use weighted average of found labels - /// the farther a found observation from the target one, the lower the weight of /// the observation is. To obtain these weights one may use a kernel function. -abstract class KnnRegressor implements Serializable, Predictor { +abstract class KnnRegressor implements Predictor { /// Parameters: /// /// [fittingData] Labelled observations, among which will be searched [k] diff --git a/lib/src/regressor/linear_regressor/linear_regressor.dart b/lib/src/regressor/linear_regressor/linear_regressor.dart index 2c68720d..c7f03a2e 100644 --- a/lib/src/regressor/linear_regressor/linear_regressor.dart +++ b/lib/src/regressor/linear_regressor/linear_regressor.dart @@ -1,4 +1,3 @@ -import 'package:ml_algo/src/common/serializable/serializable.dart'; import 'package:ml_algo/src/linear_optimizer/gradient_optimizer/learning_rate_generator/learning_rate_type.dart'; import 'package:ml_algo/src/linear_optimizer/initial_coefficients_generator/initial_coefficients_type.dart'; import 'package:ml_algo/src/linear_optimizer/linear_optimizer_type.dart'; @@ -20,7 +19,7 @@ import 'package:ml_linalg/vector.dart'; /// regressor should predict, and since all the `x` values are known (since they /// are the input for the algorithm), the regressor should find the best /// coefficients (weights) for each `x`-es to make a best prediction of `y` term. -abstract class LinearRegressor implements Serializable, Predictor { +abstract class LinearRegressor implements Predictor { /// Parameters: /// /// [fittingData] A [DataFrame] with observations that is used by the From f78e0aa2526803b282e410a19c72baf87e7ebe6a Mon Sep 17 00:00:00 2001 From: Ilya Gyrdymov Date: Sun, 27 Dec 2020 00:26:10 +0200 Subject: [PATCH 4/5] CrossValidator: isolates used --- .../common/constants/common_json_keys.dart | 1 + .../metric/metric_type_encoded_values.dart | 6 ++ .../metric/metric_type_json_converter.dart | 59 +++++++++++++ .../_helpers/decode_predictor.dart | 35 ++++++++ .../_helpers/decode_predictor_type.dart | 27 ++++++ .../_helpers/encode_predictor_type.dart | 27 ++++++ .../from_serializable_predictor_Json.dart | 5 -- .../_helpers/get_predictor_type.dart | 37 ++++++++ .../cross_validator/cross_validator_impl.dart | 3 + .../cross_validator_isolate_message.dart | 51 ++++++----- .../cross_validator_isolate_message.g.dart | 86 ------------------- ...s_validator_isolate_message_json_keys.dart | 1 + .../cross_validator/predictor_type.dart | 8 ++ .../predictor_type_encoded_values.dart | 6 ++ 14 files changed, 241 insertions(+), 111 deletions(-) create mode 100644 lib/src/metric/metric_type_encoded_values.dart create mode 100644 lib/src/metric/metric_type_json_converter.dart create mode 100644 lib/src/model_selection/cross_validator/_helpers/decode_predictor.dart create mode 100644 lib/src/model_selection/cross_validator/_helpers/decode_predictor_type.dart create mode 100644 lib/src/model_selection/cross_validator/_helpers/encode_predictor_type.dart delete mode 100644 lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart create mode 100644 lib/src/model_selection/cross_validator/_helpers/get_predictor_type.dart delete mode 100644 lib/src/model_selection/cross_validator/cross_validator_isolate_message.g.dart create mode 100644 lib/src/model_selection/cross_validator/predictor_type.dart create mode 100644 lib/src/model_selection/cross_validator/predictor_type_encoded_values.dart diff --git a/lib/src/common/constants/common_json_keys.dart b/lib/src/common/constants/common_json_keys.dart index 2d3e1e65..a53afd48 100644 --- a/lib/src/common/constants/common_json_keys.dart +++ b/lib/src/common/constants/common_json_keys.dart @@ -1 +1,2 @@ const jsonSchemaVersionJsonKey = '\$V'; +const predictorTypeJsonKey = '\$PT'; diff --git a/lib/src/metric/metric_type_encoded_values.dart b/lib/src/metric/metric_type_encoded_values.dart new file mode 100644 index 00000000..84c51a97 --- /dev/null +++ b/lib/src/metric/metric_type_encoded_values.dart @@ -0,0 +1,6 @@ +const mapeMetricTypeEncodedValue = 'MAPE'; +const rmseMetricTypeEncodedValue = 'RMSE'; +const rssMetricTypeEncodedValue = 'RSS'; +const accuracyMetricTypeEncodedValue = 'ACR'; +const precisionMetricTypeEncodedValue = 'PRC'; +const recallMetricTypeEncodedValue = 'RC'; diff --git a/lib/src/metric/metric_type_json_converter.dart b/lib/src/metric/metric_type_json_converter.dart new file mode 100644 index 00000000..ac350407 --- /dev/null +++ b/lib/src/metric/metric_type_json_converter.dart @@ -0,0 +1,59 @@ +import 'package:json_annotation/json_annotation.dart'; +import 'package:ml_algo/src/metric/metric_type.dart'; +import 'package:ml_algo/src/metric/metric_type_encoded_values.dart'; + +class MetricTypeJsonConverter implements JsonConverter { + const MetricTypeJsonConverter(); + + @override + MetricType fromJson(String json) { + switch (json) { + case mapeMetricTypeEncodedValue: + return MetricType.mape; + + case rmseMetricTypeEncodedValue: + return MetricType.rmse; + + case rssMetricTypeEncodedValue: + return MetricType.rss; + + case accuracyMetricTypeEncodedValue: + return MetricType.accuracy; + + case precisionMetricTypeEncodedValue: + return MetricType.precision; + + case recallMetricTypeEncodedValue: + return MetricType.recall; + + default: + throw UnsupportedError('Unsupported encoded metric value - $json'); + } + } + + @override + String toJson(MetricType metricType) { + switch (metricType) { + case MetricType.mape: + return mapeMetricTypeEncodedValue; + + case MetricType.rmse: + return rmseMetricTypeEncodedValue; + + case MetricType.rss: + return rssMetricTypeEncodedValue; + + case MetricType.accuracy: + return accuracyMetricTypeEncodedValue; + + case MetricType.precision: + return precisionMetricTypeEncodedValue; + + case MetricType.recall: + return recallMetricTypeEncodedValue; + + default: + throw UnsupportedError('Unsupported metric type - $metricType'); + } + } +} diff --git a/lib/src/model_selection/cross_validator/_helpers/decode_predictor.dart b/lib/src/model_selection/cross_validator/_helpers/decode_predictor.dart new file mode 100644 index 00000000..c6d6e762 --- /dev/null +++ b/lib/src/model_selection/cross_validator/_helpers/decode_predictor.dart @@ -0,0 +1,35 @@ +import 'dart:convert'; + +import 'package:ml_algo/ml_algo.dart'; +import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor.dart'; +import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart'; +import 'package:ml_algo/src/predictor/predictor.dart'; + +Predictor decodePredictor( + PredictorType predictorType, + Map json, +) { + switch (predictorType) { + case PredictorType.softmaxRegressor: + return SoftmaxRegressor.fromJson(jsonEncode(json)); + + case PredictorType.logisticRegressor: + return LogisticRegressor.fromJson(jsonEncode(json)); + + case PredictorType.knnClassifier: + return KnnClassifier.fromJson(jsonEncode(json)); + + case PredictorType.decisionTreeClassifier: + return DecisionTreeClassifier.fromJson(jsonEncode(json)); + + case PredictorType.knnRegressor: + return KnnRegressor.fromJson(jsonEncode(json)); + + case PredictorType.linearRegressor: + return LinearRegressor.fromJson(jsonEncode(json)); + + default: + throw UnsupportedError('Unsupported predictor type - ${predictorType}'); + } +} diff --git a/lib/src/model_selection/cross_validator/_helpers/decode_predictor_type.dart b/lib/src/model_selection/cross_validator/_helpers/decode_predictor_type.dart new file mode 100644 index 00000000..12c07fb0 --- /dev/null +++ b/lib/src/model_selection/cross_validator/_helpers/decode_predictor_type.dart @@ -0,0 +1,27 @@ +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type_encoded_values.dart'; + +PredictorType decodePredictorType(String json) { + switch (json) { + case knnRegressorPredictorTypeEncodedValue: + return PredictorType.knnRegressor; + + case linearRegressorPredictorTypeEncodedValue: + return PredictorType.linearRegressor; + + case decisionTreeClassifierPredictorTypeEncodedValue: + return PredictorType.decisionTreeClassifier; + + case knnClassifierPredictorTypeEncodedValue: + return PredictorType.knnClassifier; + + case logisticRegressorPredictorTypeEncodedValue: + return PredictorType.logisticRegressor; + + case softmaxRegressorPredictorTypeEncodedValue: + return PredictorType.softmaxRegressor; + + default: + throw UnsupportedError('Unsupported predictor encoded value - $json'); + } +} diff --git a/lib/src/model_selection/cross_validator/_helpers/encode_predictor_type.dart b/lib/src/model_selection/cross_validator/_helpers/encode_predictor_type.dart new file mode 100644 index 00000000..68da10d2 --- /dev/null +++ b/lib/src/model_selection/cross_validator/_helpers/encode_predictor_type.dart @@ -0,0 +1,27 @@ +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type_encoded_values.dart'; + +String encodePredictorType(PredictorType predictorType) { + switch (predictorType) { + case PredictorType.knnRegressor: + return knnRegressorPredictorTypeEncodedValue; + + case PredictorType.linearRegressor: + return linearRegressorPredictorTypeEncodedValue; + + case PredictorType.decisionTreeClassifier: + return decisionTreeClassifierPredictorTypeEncodedValue; + + case PredictorType.knnClassifier: + return knnClassifierPredictorTypeEncodedValue; + + case PredictorType.logisticRegressor: + return logisticRegressorPredictorTypeEncodedValue; + + case PredictorType.softmaxRegressor: + return softmaxRegressorPredictorTypeEncodedValue; + + default: + throw UnsupportedError('Unsupported predictor type - $predictorType'); + } +} diff --git a/lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart b/lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart deleted file mode 100644 index 69457146..00000000 --- a/lib/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart +++ /dev/null @@ -1,5 +0,0 @@ -import 'package:ml_algo/src/predictor/predictor.dart'; - -Predictor fromSerializablePredictorJson(Map json) { - return null; -} diff --git a/lib/src/model_selection/cross_validator/_helpers/get_predictor_type.dart b/lib/src/model_selection/cross_validator/_helpers/get_predictor_type.dart new file mode 100644 index 00000000..ad9026f0 --- /dev/null +++ b/lib/src/model_selection/cross_validator/_helpers/get_predictor_type.dart @@ -0,0 +1,37 @@ +import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart'; +import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier.dart'; +import 'package:ml_algo/src/classifier/logistic_regressor/logistic_regressor.dart'; +import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart'; +import 'package:ml_algo/src/predictor/predictor.dart'; +import 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart'; +import 'package:ml_algo/src/regressor/linear_regressor/linear_regressor.dart'; + +PredictorType getPredictorType(Predictor predictor) { + if (predictor is LogisticRegressor) { + return PredictorType.logisticRegressor; + } + + if (predictor is SoftmaxRegressor) { + return PredictorType.softmaxRegressor; + } + + if (predictor is DecisionTreeClassifier) { + return PredictorType.decisionTreeClassifier; + } + + if (predictor is KnnClassifier) { + return PredictorType.knnClassifier; + } + + if (predictor is LinearRegressor) { + return PredictorType.linearRegressor; + } + + if (predictor is KnnRegressor) { + return PredictorType.knnRegressor; + } + + throw UnsupportedError( + 'Unsupported predictor type - ${predictor.runtimeType}'); +} diff --git a/lib/src/model_selection/cross_validator/cross_validator_impl.dart b/lib/src/model_selection/cross_validator/cross_validator_impl.dart index 9f7cbf43..f8796153 100644 --- a/lib/src/model_selection/cross_validator/cross_validator_impl.dart +++ b/lib/src/model_selection/cross_validator/cross_validator_impl.dart @@ -2,6 +2,7 @@ import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_ex import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; import 'package:ml_algo/src/model_selection/cross_validator/_helpers/assess_predictor.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/_helpers/get_predictor_type.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message.dart'; import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart'; @@ -124,6 +125,7 @@ class CrossValidatorImpl implements CrossValidator { final samplesPrototype = samples.sampleFromRows([0]); final predictorPrototype = predictorFactoryFn(samplesPrototype); + final predictorType = getPredictorType(predictorPrototype); return _workerManager .executor @@ -132,6 +134,7 @@ class CrossValidatorImpl implements CrossValidator { predictorPrototype, trainData, testData, + predictorType, metricType, ).toJson(), fun1: assessPredictor, diff --git a/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart b/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart index f08911a5..5492ea97 100644 --- a/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart +++ b/lib/src/model_selection/cross_validator/cross_validator_isolate_message.dart @@ -1,39 +1,50 @@ -import 'package:json_annotation/json_annotation.dart'; import 'package:ml_algo/src/metric/metric_type.dart'; -import 'package:ml_algo/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart'; +import 'package:ml_algo/src/metric/metric_type_json_converter.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/_helpers/decode_predictor.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/_helpers/decode_predictor_type.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/_helpers/encode_predictor_type.dart'; import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart'; +import 'package:ml_algo/src/model_selection/cross_validator/predictor_type.dart'; import 'package:ml_algo/src/predictor/predictor.dart'; import 'package:ml_dataframe/ml_dataframe.dart'; -part 'cross_validator_isolate_message.g.dart'; - -@JsonSerializable() class CrossValidatorIsolateMessage { - @JsonKey( - name: predictorPrototypeJsonKey, - fromJson: fromSerializablePredictorJson, - ) final Predictor predictorPrototype; - - @JsonKey(name: trainDataJsonKey) final DataFrame trainData; - - @JsonKey(name: testDataJsonKey) final DataFrame testData; - - @JsonKey(name: metricTypeJsonKey) + final PredictorType predictorType; final MetricType metricType; CrossValidatorIsolateMessage( this.predictorPrototype, this.trainData, this.testData, + this.predictorType, this.metricType, ); - factory CrossValidatorIsolateMessage.fromJson(Map json) => - _$CrossValidatorIsolateMessageFromJson(json); - - Map toJson() => - _$CrossValidatorIsolateMessageToJson(this); + static CrossValidatorIsolateMessage fromJson(Map json) { + final predictorType = decodePredictorType( + json[predictorTypeJsonKey] as String); + + return CrossValidatorIsolateMessage( + decodePredictor( + predictorType, + json[predictorPrototypeJsonKey] as Map, + ), + DataFrame.fromJson(json[trainDataJsonKey] as Map), + DataFrame.fromJson(json[testDataJsonKey] as Map), + predictorType, + MetricTypeJsonConverter() + .fromJson(json[metricTypeJsonKey] as String), + ); + } + + Map toJson() => { + predictorPrototypeJsonKey: predictorPrototype.toJson(), + trainDataJsonKey: trainData.toJson(), + testDataJsonKey: testData.toJson(), + predictorTypeJsonKey: encodePredictorType(predictorType), + metricTypeJsonKey: MetricTypeJsonConverter().toJson(metricType), + }; } diff --git a/lib/src/model_selection/cross_validator/cross_validator_isolate_message.g.dart b/lib/src/model_selection/cross_validator/cross_validator_isolate_message.g.dart deleted file mode 100644 index 6d6d423f..00000000 --- a/lib/src/model_selection/cross_validator/cross_validator_isolate_message.g.dart +++ /dev/null @@ -1,86 +0,0 @@ -// GENERATED CODE - DO NOT MODIFY BY HAND - -part of 'cross_validator_isolate_message.dart'; - -// ************************************************************************** -// JsonSerializableGenerator -// ************************************************************************** - -CrossValidatorIsolateMessage _$CrossValidatorIsolateMessageFromJson( - Map json) { - return $checkedNew('CrossValidatorIsolateMessage', json, () { - $checkKeys(json, allowedKeys: const ['P', 'T', 'TE', 'M']); - final val = CrossValidatorIsolateMessage( - $checkedConvert(json, 'P', - (v) => fromSerializablePredictorJson(v as Map)), - $checkedConvert( - json, - 'T', - (v) => - v == null ? null : DataFrame.fromJson(v as Map)), - $checkedConvert( - json, - 'TE', - (v) => - v == null ? null : DataFrame.fromJson(v as Map)), - $checkedConvert( - json, 'M', (v) => _$enumDecodeNullable(_$MetricTypeEnumMap, v)), - ); - return val; - }, fieldKeyMap: const { - 'predictorPrototype': 'P', - 'trainData': 'T', - 'testData': 'TE', - 'metricType': 'M' - }); -} - -Map _$CrossValidatorIsolateMessageToJson( - CrossValidatorIsolateMessage instance) => - { - 'P': instance.predictorPrototype, - 'T': instance.trainData, - 'TE': instance.testData, - 'M': _$MetricTypeEnumMap[instance.metricType], - }; - -T _$enumDecode( - Map enumValues, - dynamic source, { - T unknownValue, -}) { - if (source == null) { - throw ArgumentError('A value must be provided. Supported values: ' - '${enumValues.values.join(', ')}'); - } - - final value = enumValues.entries - .singleWhere((e) => e.value == source, orElse: () => null) - ?.key; - - if (value == null && unknownValue == null) { - throw ArgumentError('`$source` is not one of the supported values: ' - '${enumValues.values.join(', ')}'); - } - return value ?? unknownValue; -} - -T _$enumDecodeNullable( - Map enumValues, - dynamic source, { - T unknownValue, -}) { - if (source == null) { - return null; - } - return _$enumDecode(enumValues, source, unknownValue: unknownValue); -} - -const _$MetricTypeEnumMap = { - MetricType.mape: 'mape', - MetricType.rmse: 'rmse', - MetricType.rss: 'rss', - MetricType.accuracy: 'accuracy', - MetricType.precision: 'precision', - MetricType.recall: 'recall', -}; diff --git a/lib/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart b/lib/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart index 878a3e53..8cdb41f7 100644 --- a/lib/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart +++ b/lib/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart @@ -2,4 +2,5 @@ const predictorPrototypeJsonKey = 'P'; const trainDataJsonKey = 'T'; const testDataJsonKey = 'TE'; const testRowsIndicesJsonKey = 'R'; +const predictorTypeJsonKey = 'PT'; const metricTypeJsonKey = 'M'; diff --git a/lib/src/model_selection/cross_validator/predictor_type.dart b/lib/src/model_selection/cross_validator/predictor_type.dart new file mode 100644 index 00000000..fed3f116 --- /dev/null +++ b/lib/src/model_selection/cross_validator/predictor_type.dart @@ -0,0 +1,8 @@ +enum PredictorType { + logisticRegressor, + softmaxRegressor, + decisionTreeClassifier, + knnClassifier, + linearRegressor, + knnRegressor, +} diff --git a/lib/src/model_selection/cross_validator/predictor_type_encoded_values.dart b/lib/src/model_selection/cross_validator/predictor_type_encoded_values.dart new file mode 100644 index 00000000..4f58bc36 --- /dev/null +++ b/lib/src/model_selection/cross_validator/predictor_type_encoded_values.dart @@ -0,0 +1,6 @@ +const logisticRegressorPredictorTypeEncodedValue = 'LOGR'; +const softmaxRegressorPredictorTypeEncodedValue = 'SR'; +const decisionTreeClassifierPredictorTypeEncodedValue = 'DTC'; +const knnClassifierPredictorTypeEncodedValue = 'KNNC'; +const knnRegressorPredictorTypeEncodedValue = 'KNNR'; +const linearRegressorPredictorTypeEncodedValue = 'LINR'; From 00298ae5c0f328e294341e18b0c226552daf3bab Mon Sep 17 00:00:00 2001 From: Ilya Gyrdymov Date: Tue, 29 Dec 2020 17:44:03 +0200 Subject: [PATCH 5/5] CrossValidator: isolates used --- .../cross_validator/cross_validator_impl.dart | 14 +++++++------- .../worker_manager/worker_manager_impl.dart | 11 ++++++++--- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/lib/src/model_selection/cross_validator/cross_validator_impl.dart b/lib/src/model_selection/cross_validator/cross_validator_impl.dart index f8796153..e3a375c6 100644 --- a/lib/src/model_selection/cross_validator/cross_validator_impl.dart +++ b/lib/src/model_selection/cross_validator/cross_validator_impl.dart @@ -34,6 +34,8 @@ class CrossValidatorImpl implements CrossValidator { DataPreprocessFn onDataSplit, } ) async { + await _workerManager.init(); + final samplesAsMatrix = samples.toMatrix(dtype); final sourceColumnsNum = samplesAsMatrix.columnsNum; final discreteColumns = enumerate(samples.series) @@ -116,13 +118,11 @@ class CrossValidatorImpl implements CrossValidator { } Future _assessPredictor( - PredictorFactoryFn predictorFactoryFn, - DataFrame trainData, - DataFrame testData, - MetricType metricType, - ) async { - await _workerManager.init(); - + PredictorFactoryFn predictorFactoryFn, + DataFrame trainData, + DataFrame testData, + MetricType metricType, + ) async { final samplesPrototype = samples.sampleFromRows([0]); final predictorPrototype = predictorFactoryFn(samplesPrototype); final predictorType = getPredictorType(predictorPrototype); diff --git a/lib/src/service/worker_manager/worker_manager_impl.dart b/lib/src/service/worker_manager/worker_manager_impl.dart index 7527c4ae..9d1ea648 100644 --- a/lib/src/service/worker_manager/worker_manager_impl.dart +++ b/lib/src/service/worker_manager/worker_manager_impl.dart @@ -1,19 +1,24 @@ +import 'dart:async'; import 'package:ml_algo/src/service/worker_manager/worker_manager.dart'; import 'package:worker_manager/src/executor.dart'; class WorkerManagerImpl implements WorkerManager { - bool _isInitialized = false; + Completer _initCompleter; @override Executor get executor => Executor(); @override Future init() async { - if (_isInitialized) { + final isWarmedUp = (await _initCompleter?.future) ?? false; + + if (isWarmedUp) { return; } + _initCompleter ??= Completer(); + await Executor().warmUp(); - _isInitialized = true; + _initCompleter.complete(true); } }