Skip to content

Commit 7cd8bfb

Browse files
committed
CrossValidator: isolates used
1 parent a9c744f commit 7cd8bfb

17 files changed

+331
-46
lines changed

lib/src/di/common/init_common_module.dart

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import 'package:ml_algo/src/model_selection/model_assessor/classifier_assessor.d
2424
import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart';
2525
import 'package:ml_algo/src/model_selection/model_assessor/regressor_assessor.dart';
2626
import 'package:ml_algo/src/predictor/predictor.dart';
27+
import 'package:ml_algo/src/service/worker_manager/worker_manager.dart';
28+
import 'package:ml_algo/src/service/worker_manager/worker_manager_impl.dart';
2729
import 'package:ml_dataframe/ml_dataframe.dart';
2830
import 'package:ml_preprocessing/ml_preprocessing.dart';
2931

@@ -76,5 +78,8 @@ void initCommonModule() {
7678
RegressorAssessor(
7779
injector.get<MetricFactory>(),
7880
featuresTargetSplit,
79-
));
81+
))
82+
83+
..registerSingletonIf<WorkerManager>(
84+
() => WorkerManagerImpl());
8085
}
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import 'package:ml_algo/src/di/common/init_common_module.dart';
12
import 'package:ml_algo/src/model_selection/_injector.dart';
23
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart';
34
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory_impl.dart';
5+
import 'package:ml_algo/src/extensions/injector.dart';
46

57
void initModelSelectionModule() {
6-
if (!modelSelectionInjector.exists<SplitIndicesProviderFactory>()) {
7-
modelSelectionInjector
8-
..registerSingleton<SplitIndicesProviderFactory>(
9-
() => const SplitIndicesProviderFactoryImpl());
10-
}
8+
initCommonModule();
9+
10+
modelSelectionInjector
11+
..registerSingletonIf<SplitIndicesProviderFactory>(
12+
() => const SplitIndicesProviderFactoryImpl());
1113
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import 'package:ml_algo/src/model_selection/assessable.dart';
2+
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message.dart';
3+
4+
num assessPredictor(Map<String, dynamic> encodedMessage) {
5+
final message = CrossValidatorIsolateMessage
6+
.fromJson(encodedMessage);
7+
final predictor = message
8+
.predictorPrototype
9+
.retrain(message.trainData);
10+
11+
return (predictor as Assessable)
12+
.assess(message.testData, message.metricType);
13+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import 'package:ml_algo/src/model_selection/serializable_predictor.dart';
2+
3+
SerializablePredictor fromSerializablePredictorJson(Map<String, dynamic> json) {
4+
return null;
5+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import 'package:ml_algo/src/di/common/init_common_module.dart';
2+
3+
void initCrossValidatorModule() {
4+
initCommonModule();
5+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import 'package:injector/injector.dart';
2+
3+
final crossValidatorInjector = Injector();

lib/src/model_selection/cross_validator/cross_validator.dart

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1+
import 'package:ml_algo/src/di/injector.dart';
12
import 'package:ml_algo/src/metric/metric_type.dart';
23
import 'package:ml_algo/src/model_selection/_init_module.dart';
34
import 'package:ml_algo/src/model_selection/_injector.dart';
4-
import 'package:ml_algo/src/model_selection/assessable.dart';
55
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
6+
import 'package:ml_algo/src/model_selection/serializable_predictor.dart';
67
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart';
78
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart';
9+
import 'package:ml_algo/src/service/worker_manager/worker_manager.dart';
810
import 'package:ml_dataframe/ml_dataframe.dart';
911
import 'package:ml_linalg/dtype.dart';
1012
import 'package:ml_linalg/linalg.dart';
1113

12-
typedef PredictorFactory = Assessable Function(DataFrame observations);
14+
typedef PredictorFactoryFn = SerializablePredictor Function(DataFrame observations);
1315

1416
typedef DataPreprocessFn = List<DataFrame> Function(DataFrame trainData,
1517
DataFrame testData);
@@ -39,11 +41,13 @@ abstract class CrossValidator {
3941

4042
final dataSplitterFactory = modelSelectionInjector
4143
.get<SplitIndicesProviderFactory>();
44+
final workerManager = injector.get<WorkerManager>();
4245
final dataSplitter = dataSplitterFactory
4346
.createByType(SplitIndicesProviderType.kFold, numberOfFolds: numberOfFolds);
4447

4548
return CrossValidatorImpl(
4649
samples,
50+
workerManager,
4751
dataSplitter,
4852
dtype,
4953
);
@@ -71,11 +75,13 @@ abstract class CrossValidator {
7175

7276
final dataSplitterFactory = modelSelectionInjector
7377
.get<SplitIndicesProviderFactory>();
78+
final workerManager = injector.get<WorkerManager>();
7479
final dataSplitter = dataSplitterFactory
7580
.createByType(SplitIndicesProviderType.lpo, p: p);
7681

7782
return CrossValidatorImpl(
7883
samples,
84+
workerManager,
7985
dataSplitter,
8086
dtype,
8187
);
@@ -133,7 +139,7 @@ abstract class CrossValidator {
133139
/// print(averageScore);
134140
/// ````
135141
Future<Vector> evaluate(
136-
PredictorFactory predictorFactory,
142+
PredictorFactoryFn predictorFactory,
137143
MetricType metricType,
138144
{
139145
DataPreprocessFn onDataSplit,

lib/src/model_selection/cross_validator/cross_validator_impl.dart

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_exception.dart';
22
import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart';
33
import 'package:ml_algo/src/metric/metric_type.dart';
4+
import 'package:ml_algo/src/model_selection/cross_validator/_helpers/assess_predictor.dart';
45
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator.dart';
6+
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message.dart';
57
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider.dart';
8+
import 'package:ml_algo/src/service/worker_manager/worker_manager.dart';
69
import 'package:ml_dataframe/ml_dataframe.dart';
710
import 'package:ml_linalg/dtype.dart';
811
import 'package:ml_linalg/matrix.dart';
@@ -12,40 +15,47 @@ import 'package:quiver/iterables.dart';
1215
class CrossValidatorImpl implements CrossValidator {
1316
CrossValidatorImpl(
1417
this.samples,
18+
this._workerManager,
1519
this._splitter,
1620
this.dtype,
1721
);
1822

1923
final DataFrame samples;
20-
final DType dtype;
24+
final WorkerManager _workerManager;
2125
final SplitIndicesProvider _splitter;
26+
final DType dtype;
2227

2328
@override
2429
Future<Vector> evaluate(
25-
PredictorFactory predictorFactory,
30+
PredictorFactoryFn predictorFactory,
2631
MetricType metricType,
2732
{
2833
DataPreprocessFn onDataSplit,
2934
}
30-
) {
35+
) async {
3136
final samplesAsMatrix = samples.toMatrix(dtype);
3237
final sourceColumnsNum = samplesAsMatrix.columnsNum;
3338
final discreteColumns = enumerate(samples.series)
3439
.where((indexedSeries) => indexedSeries.value.isDiscrete)
3540
.map((indexedSeries) => indexedSeries.index);
36-
final allIndicesGroups = _splitter.getIndices(samplesAsMatrix.rowsNum);
37-
final scores = allIndicesGroups
41+
final allIndicesGroups = _splitter
42+
.getIndices(samplesAsMatrix.rowsNum);
43+
final scoreFutures = allIndicesGroups
3844
.map((testRowsIndices) {
3945
final split = _makeSplit(testRowsIndices, discreteColumns);
4046
final trainDataFrame = split[0];
4147
final testDataFrame = split[1];
42-
final splits = onDataSplit != null
48+
final transformedData = onDataSplit != null
4349
? onDataSplit(trainDataFrame, testDataFrame)
4450
: [trainDataFrame, testDataFrame];
45-
final transformedTrainData = splits[0];
46-
final transformedTestData = splits[1];
47-
final transformedTrainDataColumnsNum = transformedTrainData.header.length;
48-
final transformedTestDataColumnsNum = transformedTestData.header.length;
51+
final transformedTrainData = transformedData[0];
52+
final transformedTestData = transformedData[1];
53+
final transformedTrainDataColumnsNum = transformedTrainData
54+
.header
55+
.length;
56+
final transformedTestDataColumnsNum = transformedTestData
57+
.header
58+
.length;
4959

5060
if (transformedTrainDataColumnsNum != sourceColumnsNum) {
5161
throw InvalidTrainDataColumnsNumberException(sourceColumnsNum,
@@ -57,12 +67,18 @@ class CrossValidatorImpl implements CrossValidator {
5767
transformedTestDataColumnsNum);
5868
}
5969

60-
return predictorFactory(transformedTrainData)
61-
.assess(transformedTestData, metricType);
70+
return _assessPredictor(
71+
predictorFactory,
72+
transformedTrainData,
73+
transformedTestData,
74+
metricType,
75+
);
6276
})
6377
.toList();
78+
final scores = await Future.wait(scoreFutures);
6479

65-
return Future.value(Vector.fromList(scores, dtype: dtype));
80+
return Future
81+
.value(Vector.fromList(scores, dtype: dtype));
6682
}
6783

6884
List<DataFrame> _makeSplit(Iterable<int> testRowsIndices,
@@ -97,4 +113,28 @@ class CrossValidatorImpl implements CrossValidator {
97113
),
98114
];
99115
}
116+
117+
Future<num> _assessPredictor(
118+
PredictorFactoryFn predictorFactoryFn,
119+
DataFrame trainData,
120+
DataFrame testData,
121+
MetricType metricType,
122+
) async {
123+
await _workerManager.init();
124+
125+
final samplesPrototype = samples.sampleFromRows([0]);
126+
final predictorPrototype = predictorFactoryFn(samplesPrototype);
127+
128+
return _workerManager
129+
.executor
130+
.execute(
131+
arg1: CrossValidatorIsolateMessage(
132+
predictorPrototype,
133+
trainData,
134+
testData,
135+
metricType,
136+
).toJson(),
137+
fun1: assessPredictor,
138+
);
139+
}
100140
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import 'package:json_annotation/json_annotation.dart';
2+
import 'package:ml_algo/src/metric/metric_type.dart';
3+
import 'package:ml_algo/src/model_selection/cross_validator/_helpers/from_serializable_predictor_Json.dart';
4+
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_isolate_message_json_keys.dart';
5+
import 'package:ml_algo/src/model_selection/serializable_predictor.dart';
6+
import 'package:ml_dataframe/ml_dataframe.dart';
7+
8+
part 'cross_validator_isolate_message.g.dart';
9+
10+
@JsonSerializable()
11+
class CrossValidatorIsolateMessage {
12+
@JsonKey(
13+
name: predictorPrototypeJsonKey,
14+
fromJson: fromSerializablePredictorJson,
15+
)
16+
final SerializablePredictor predictorPrototype;
17+
18+
@JsonKey(name: trainDataJsonKey)
19+
final DataFrame trainData;
20+
21+
@JsonKey(name: testDataJsonKey)
22+
final DataFrame testData;
23+
24+
@JsonKey(name: metricTypeJsonKey)
25+
final MetricType metricType;
26+
27+
CrossValidatorIsolateMessage(
28+
this.predictorPrototype,
29+
this.trainData,
30+
this.testData,
31+
this.metricType,
32+
);
33+
34+
factory CrossValidatorIsolateMessage.fromJson(Map<String, dynamic> json) =>
35+
_$CrossValidatorIsolateMessageFromJson(json);
36+
37+
Map<String, dynamic> toJson() =>
38+
_$CrossValidatorIsolateMessageToJson(this);
39+
}

lib/src/model_selection/cross_validator/cross_validator_isolate_message.g.dart

Lines changed: 86 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)