Skip to content

Commit ecedbe2

Browse files
committed
CrossValidator: isolates used
1 parent a9c744f commit ecedbe2

17 files changed

+317
-41
lines changed

lib/src/di/common/init_common_module.dart

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import 'package:ml_algo/src/model_selection/model_assessor/classifier_assessor.d
2424
import 'package:ml_algo/src/model_selection/model_assessor/model_assessor.dart';
2525
import 'package:ml_algo/src/model_selection/model_assessor/regressor_assessor.dart';
2626
import 'package:ml_algo/src/predictor/predictor.dart';
27+
import 'package:ml_algo/src/service/worker_manager/worker_manager.dart';
28+
import 'package:ml_algo/src/service/worker_manager/worker_manager_impl.dart';
2729
import 'package:ml_dataframe/ml_dataframe.dart';
2830
import 'package:ml_preprocessing/ml_preprocessing.dart';
2931

@@ -76,5 +78,8 @@ void initCommonModule() {
7678
RegressorAssessor(
7779
injector.get<MetricFactory>(),
7880
featuresTargetSplit,
79-
));
81+
))
82+
83+
..registerSingletonIf<WorkerManager>(
84+
() => WorkerManagerImpl());
8085
}
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import 'package:ml_algo/src/di/common/init_common_module.dart';
12
import 'package:ml_algo/src/model_selection/_injector.dart';
23
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart';
34
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory_impl.dart';
5+
import 'package:ml_algo/src/extensions/injector.dart';
46

57
/// Prepares the dependencies used by the model selection module.
///
/// Runs [initCommonModule] first and then asks [modelSelectionInjector] to
/// register [SplitIndicesProviderFactoryImpl] as the
/// [SplitIndicesProviderFactory] singleton via `registerSingletonIf`.
///
/// NOTE(review): `registerSingletonIf` comes from the injector extension and
/// presumably registers only when no such dependency exists yet — confirm
/// against the extension's implementation.
void initModelSelectionModule() {
  initCommonModule();

  modelSelectionInjector.registerSingletonIf<SplitIndicesProviderFactory>(
    () => const SplitIndicesProviderFactoryImpl(),
  );
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import 'package:ml_algo/src/common/exception/invalid_test_data_columns_number_exception.dart';
2+
import 'package:ml_algo/src/common/exception/invalid_train_data_columns_number_exception.dart';
3+
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_assessing_data_model.dart';
4+
import 'package:ml_dataframe/ml_dataframe.dart';
5+
import 'package:ml_linalg/dtype.dart';
6+
import 'package:ml_linalg/matrix.dart';
7+
import 'package:ml_linalg/vector.dart';
8+
import 'package:quiver/iterables.dart';
9+
10+
/// Splits serialized samples into train and test parts and returns a score.
///
/// [assessingDataJson] is decoded with
/// [CrossValidatorAssessingDataModel.fromJson]. Rows whose indices are listed
/// in the model's `testRowsIndices` form the test data frame, all remaining
/// rows form the train data frame. Both frames keep the header and the
/// discrete column indices of the source samples.
///
/// Throws a [RangeError] if a test row index does not address a row of the
/// samples matrix.
///
/// Throws an [InvalidTrainDataColumnsNumberException] or an
/// [InvalidTestDataColumnsNumberException] if the respective data frame does
/// not have as many columns as the source samples matrix.
///
/// NOTE(review): the returned score is a hard-coded placeholder (`0.3`); the
/// predictor prototype is currently read only for its `dtype` — the actual
/// assessment still has to be implemented.
num assessPredictor(Map<String, dynamic> assessingDataJson) {
  final assessingDataModel =
      CrossValidatorAssessingDataModel.fromJson(assessingDataJson);
  final dtype = assessingDataModel.predictorPrototype.dtype;
  final samples = assessingDataModel.samplesPrototype;

  // Materialised once: the same indices are handed to both data frames
  // below, a lazy iterable would be re-evaluated for each of them.
  final discreteColumns = enumerate(samples.series)
      .where((indexedSeries) => indexedSeries.value.isDiscrete)
      .map((indexedSeries) => indexedSeries.index)
      .toList();

  final samplesAsMatrix = samples.toMatrix(dtype);
  final sourceColumnsNum = samplesAsMatrix.columnsNum;
  final rowsNum = samplesAsMatrix.rowsNum;
  final testRowsIndicesAsSet =
      Set<int>.from(assessingDataModel.testRowsIndices);

  // Fail fast with a descriptive error instead of producing a test set with
  // missing rows when an index points outside the samples.
  for (final index in testRowsIndicesAsSet) {
    RangeError.checkValueInInterval(index, 0, rowsNum - 1, 'testRowsIndices');
  }

  final trainSamples = <Vector>[];
  final testSamples = <Vector>[];

  for (final i in samplesAsMatrix.rowIndices) {
    if (testRowsIndicesAsSet.contains(i)) {
      testSamples.add(samplesAsMatrix[i]);
    } else {
      trainSamples.add(samplesAsMatrix[i]);
    }
  }

  final transformedTrainData = DataFrame.fromMatrix(
    Matrix.fromRows(trainSamples, dtype: dtype),
    header: samples.header,
    discreteColumns: discreteColumns,
  );
  final transformedTestData = DataFrame.fromMatrix(
    Matrix.fromRows(testSamples, dtype: dtype),
    header: samples.header,
    discreteColumns: discreteColumns,
  );

  if (transformedTrainData.header.length != sourceColumnsNum) {
    throw InvalidTrainDataColumnsNumberException(
        sourceColumnsNum, transformedTrainData.header.length);
  }

  if (transformedTestData.header.length != sourceColumnsNum) {
    throw InvalidTestDataColumnsNumberException(
        sourceColumnsNum, transformedTestData.header.length);
  }

  // TODO: replace the placeholder with the real assessment result.
  return 0.3;
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import 'package:ml_algo/src/model_selection/serializable_predictor.dart';
2+
3+
/// Restores a [SerializablePredictor] from its [json] representation.
///
/// NOTE(review): this is a stub — [json] is ignored and `null` is always
/// returned.
SerializablePredictor fromSerializablePredictorJson(
        Map<String, dynamic> json) =>
    null;
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import 'package:ml_algo/src/di/common/init_common_module.dart';
2+
3+
/// Prepares the dependencies of the cross validator module.
///
/// Currently this only delegates to [initCommonModule].
void initCrossValidatorModule() => initCommonModule();
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import 'package:injector/injector.dart';
2+
3+
/// Injector instance declared for the cross validator module.
///
/// NOTE(review): nothing is registered on it at the point of declaration —
/// confirm where (and whether) it gets populated.
final crossValidatorInjector = Injector();

lib/src/model_selection/cross_validator/cross_validator.dart

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1+
import 'package:ml_algo/src/di/injector.dart';
12
import 'package:ml_algo/src/metric/metric_type.dart';
23
import 'package:ml_algo/src/model_selection/_init_module.dart';
34
import 'package:ml_algo/src/model_selection/_injector.dart';
4-
import 'package:ml_algo/src/model_selection/assessable.dart';
55
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_impl.dart';
6+
import 'package:ml_algo/src/model_selection/serializable_predictor.dart';
67
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_factory.dart';
78
import 'package:ml_algo/src/model_selection/split_indices_provider/split_indices_provider_type.dart';
9+
import 'package:ml_algo/src/service/worker_manager/worker_manager.dart';
810
import 'package:ml_dataframe/ml_dataframe.dart';
911
import 'package:ml_linalg/dtype.dart';
1012
import 'package:ml_linalg/linalg.dart';
1113

12-
typedef PredictorFactory = Assessable Function(DataFrame observations);
14+
/// Signature of a function that creates a [SerializablePredictor] from the
/// given [observations].
typedef PredictorFactoryFn = SerializablePredictor Function(DataFrame observations);
1315

1416
/// Signature of a callback that receives a train and a test [DataFrame] and
/// returns a list of [DataFrame]s.
///
/// NOTE(review): presumably the returned list holds the processed train and
/// test data, in that order — confirm against the caller.
typedef DataPreprocessFn = List<DataFrame> Function(DataFrame trainData,
1517
DataFrame testData);
@@ -39,11 +41,13 @@ abstract class CrossValidator {
3941

4042
final dataSplitterFactory = modelSelectionInjector
4143
.get<SplitIndicesProviderFactory>();
44+
final workerManager = injector.get<WorkerManager>();
4245
final dataSplitter = dataSplitterFactory
4346
.createByType(SplitIndicesProviderType.kFold, numberOfFolds: numberOfFolds);
4447

4548
return CrossValidatorImpl(
4649
samples,
50+
workerManager,
4751
dataSplitter,
4852
dtype,
4953
);
@@ -71,11 +75,13 @@ abstract class CrossValidator {
7175

7276
final dataSplitterFactory = modelSelectionInjector
7377
.get<SplitIndicesProviderFactory>();
78+
final workerManager = injector.get<WorkerManager>();
7479
final dataSplitter = dataSplitterFactory
7580
.createByType(SplitIndicesProviderType.lpo, p: p);
7681

7782
return CrossValidatorImpl(
7883
samples,
84+
workerManager,
7985
dataSplitter,
8086
dtype,
8187
);
@@ -133,7 +139,7 @@ abstract class CrossValidator {
133139
/// print(averageScore);
134140
/// ````
135141
Future<Vector> evaluate(
136-
PredictorFactory predictorFactory,
142+
PredictorFactoryFn predictorFactory,
137143
MetricType metricType,
138144
{
139145
DataPreprocessFn onDataSplit,
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import 'package:json_annotation/json_annotation.dart';
2+
import 'package:ml_algo/src/model_selection/cross_validator/_helpers/fromSerializablePredictorJson.dart';
3+
import 'package:ml_algo/src/model_selection/cross_validator/cross_validator_assessing_data_model_json_keys.dart';
4+
import 'package:ml_algo/src/model_selection/serializable_predictor.dart';
5+
import 'package:ml_dataframe/ml_dataframe.dart';
6+
7+
part 'cross_validator_assessing_data_model.g.dart';
8+
9+
/// JSON-serializable bundle of the data needed to assess a predictor.
///
/// It holds a predictor prototype, the samples and the indices of the rows
/// that make up the test part of the samples.
@JsonSerializable()
class CrossValidatorAssessingDataModel {
  /// Creates a model from an already built predictor prototype, the samples
  /// and the test row indices.
  CrossValidatorAssessingDataModel(
    this.predictorPrototype,
    this.samplesPrototype,
    this.testRowsIndices,
  );

  /// Restores a model from its [json] representation using the generated
  /// deserializer.
  factory CrossValidatorAssessingDataModel.fromJson(
      Map<String, dynamic> json) {
    return _$CrossValidatorAssessingDataModelFromJson(json);
  }

  /// The predictor prototype.
  ///
  /// Stored under [predictorPrototypeJsonKey] and decoded with
  /// [fromSerializablePredictorJson].
  @JsonKey(
    name: predictorPrototypeJsonKey,
    fromJson: fromSerializablePredictorJson,
  )
  final SerializablePredictor predictorPrototype;

  /// The samples, stored under [samplesJsonKey].
  @JsonKey(name: samplesJsonKey)
  final DataFrame samplesPrototype;

  /// Indices of the test rows, stored under [testRowsIndicesJsonKey].
  @JsonKey(name: testRowsIndicesJsonKey)
  final List<int> testRowsIndices;

  /// Converts this model to its JSON representation using the generated
  /// serializer.
  Map<String, dynamic> toJson() {
    return _$CrossValidatorAssessingDataModelToJson(this);
  }
}

lib/src/model_selection/cross_validator/cross_validator_assessing_data_model.g.dart

Lines changed: 38 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
const predictorPrototypeJsonKey = 'P';
2+
const samplesJsonKey = 'S';
3+
const testRowsIndicesJsonKey = 'R';

0 commit comments

Comments
 (0)