Skip to content

Commit 318fe3a

Browse files
authored
DI logic: conditional dependency registering added (#162)
1 parent 7a20255 commit 318fe3a

File tree

9 files changed

+64
-41
lines changed

9 files changed

+64
-41
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## 15.3.4
4+
- `DI logic`:
5+
- conditional dependency registering added
6+
37
## 15.3.3
48
- FUNDING.yml created
59

lib/src/classifier/decision_tree_classifier/_init_module.dart

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_cl
44
import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator_factory.dart';
55
import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator_factory_impl.dart';
66
import 'package:ml_algo/src/di/common/init_common_module.dart';
7+
import 'package:ml_algo/src/extensions/injector.dart';
78
import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector_factory.dart';
89
import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector_factory_impl.dart';
910
import 'package:ml_algo/src/tree_trainer/leaf_label/leaf_label_factory_factory.dart';
@@ -25,50 +26,49 @@ void initDecisionTreeModule() {
2526
initCommonModule();
2627

2728
decisionTreeInjector
28-
..clearAll()
29-
..registerSingleton<DistributionCalculatorFactory>(
29+
..registerSingletonIf<DistributionCalculatorFactory>(
3030
() => const DistributionCalculatorFactoryImpl())
3131

32-
..registerSingleton<NominalTreeSplitterFactory>(
32+
..registerSingletonIf<NominalTreeSplitterFactory>(
3333
() => const NominalTreeSplitterFactoryImpl())
3434

35-
..registerSingleton<NumericalTreeSplitterFactory>(
35+
..registerSingletonIf<NumericalTreeSplitterFactory>(
3636
() => const NumericalTreeSplitterFactoryImpl())
3737

38-
..registerSingleton<TreeSplitAssessorFactory>(
38+
..registerSingletonIf<TreeSplitAssessorFactory>(
3939
() => const TreeSplitAssessorFactoryImpl())
4040

41-
..registerSingleton<TreeSplitterFactory>(
41+
..registerSingletonIf<TreeSplitterFactory>(
4242
() => TreeSplitterFactoryImpl(
4343
decisionTreeInjector.get<TreeSplitAssessorFactory>(),
4444
decisionTreeInjector.get<NominalTreeSplitterFactory>(),
4545
decisionTreeInjector.get<NumericalTreeSplitterFactory>(),
4646
))
4747

48-
..registerSingleton<TreeSplitSelectorFactory>(
48+
..registerSingletonIf<TreeSplitSelectorFactory>(
4949
() => TreeSplitSelectorFactoryImpl(
5050
decisionTreeInjector.get<TreeSplitAssessorFactory>(),
5151
decisionTreeInjector.get<TreeSplitterFactory>(),
5252
))
5353

54-
..registerSingleton<TreeLeafDetectorFactory>(
54+
..registerSingletonIf<TreeLeafDetectorFactory>(
5555
() => TreeLeafDetectorFactoryImpl(
5656
decisionTreeInjector.get<TreeSplitAssessorFactory>(),
5757
))
5858

59-
..registerSingleton<TreeLeafLabelFactoryFactory>(
59+
..registerSingletonIf<TreeLeafLabelFactoryFactory>(
6060
() => TreeLeafLabelFactoryFactoryImpl(
6161
decisionTreeInjector
6262
.get<DistributionCalculatorFactory>(),
6363
))
6464

65-
..registerSingleton<TreeTrainerFactory>(
65+
..registerSingletonIf<TreeTrainerFactory>(
6666
() => TreeTrainerFactoryImpl(
6767
decisionTreeInjector.get<TreeLeafDetectorFactory>(),
6868
decisionTreeInjector.get<TreeLeafLabelFactoryFactory>(),
6969
decisionTreeInjector.get<TreeSplitSelectorFactory>(),
7070
))
7171

72-
..registerSingleton<DecisionTreeClassifierFactory>(
72+
..registerSingletonIf<DecisionTreeClassifierFactory>(
7373
() => const DecisionTreeClassifierFactoryImpl());
7474
}

lib/src/classifier/knn_classifier/_init_module.dart

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import 'package:ml_algo/src/classifier/knn_classifier/_injector.dart';
22
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory.dart';
33
import 'package:ml_algo/src/classifier/knn_classifier/knn_classifier_factory_impl.dart';
44
import 'package:ml_algo/src/di/common/init_common_module.dart';
5+
import 'package:ml_algo/src/extensions/injector.dart';
56
import 'package:ml_algo/src/knn_kernel/kernel_factory.dart';
67
import 'package:ml_algo/src/knn_kernel/kernel_factory_impl.dart';
78
import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart';
@@ -13,17 +14,16 @@ void initKnnClassifierModule() {
1314
initCommonModule();
1415

1516
knnClassifierInjector
16-
..clearAll()
17-
..registerSingleton<KernelFactory>(
17+
..registerSingletonIf<KernelFactory>(
1818
() => const KernelFactoryImpl())
1919

20-
..registerDependency<KnnSolverFactory>(
20+
..registerSingletonIf<KnnSolverFactory>(
2121
() => const KnnSolverFactoryImpl())
2222

23-
..registerSingleton<KnnClassifierFactory>(
23+
..registerSingletonIf<KnnClassifierFactory>(
2424
() => const KnnClassifierFactoryImpl())
2525

26-
..registerSingleton<KnnRegressorFactory>(
26+
..registerSingletonIf<KnnRegressorFactory>(
2727
() => KnnRegressorFactoryImpl(
2828
knnClassifierInjector.get<KernelFactory>(),
2929
knnClassifierInjector.get<KnnSolverFactory>(),
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import 'package:ml_algo/src/classifier/logistic_regressor/_injector.dart';
22
import 'package:ml_algo/src/di/common/init_common_module.dart';
3+
import 'package:ml_algo/src/extensions/injector.dart';
34
import 'package:ml_algo/src/link_function/link_function.dart';
45
import 'package:ml_algo/src/link_function/link_function_dependency_tokens.dart';
56
import 'package:ml_algo/src/link_function/logit/float32_inverse_logit_function.dart';
@@ -9,12 +10,11 @@ void initLogisticRegressorModule() {
910
initCommonModule();
1011

1112
logisticRegressorInjector
12-
..clearAll()
13-
..registerSingleton<LinkFunction>(
13+
..registerSingletonIf<LinkFunction>(
1414
() => const Float32InverseLogitLinkFunction(),
1515
dependencyName: float32InverseLogitLinkFunctionToken)
1616

17-
..registerSingleton<LinkFunction>(
17+
..registerSingletonIf<LinkFunction>(
1818
() => const Float64InverseLogitLinkFunction(),
1919
dependencyName: float64InverseLogitLinkFunctionToken);
2020
}

lib/src/classifier/softmax_regressor/_init_module.dart

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import 'package:ml_algo/src/classifier/softmax_regressor/_injector.dart';
22
import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor_factory.dart';
33
import 'package:ml_algo/src/classifier/softmax_regressor/softmax_regressor_factory_impl.dart';
44
import 'package:ml_algo/src/di/common/init_common_module.dart';
5+
import 'package:ml_algo/src/extensions/injector.dart';
56
import 'package:ml_algo/src/link_function/link_function.dart';
67
import 'package:ml_algo/src/link_function/link_function_dependency_tokens.dart';
78
import 'package:ml_algo/src/link_function/softmax/float32_softmax_link_function.dart';
@@ -11,15 +12,14 @@ void initSoftmaxRegressorModule() {
1112
initCommonModule();
1213

1314
softmaxRegressorInjector
14-
..clearAll()
15-
..registerSingleton<LinkFunction>(
15+
..registerSingletonIf<LinkFunction>(
1616
() => const Float32SoftmaxLinkFunction(),
1717
dependencyName: float32SoftmaxLinkFunctionToken)
1818

19-
..registerSingleton<LinkFunction>(
19+
..registerSingletonIf<LinkFunction>(
2020
() => const Float64SoftmaxLinkFunction(),
2121
dependencyName: float64SoftmaxLinkFunctionToken)
2222

23-
..registerSingleton<SoftmaxRegressorFactory>(
23+
..registerSingletonIf<SoftmaxRegressorFactory>(
2424
() => const SoftmaxRegressorFactoryImpl());
2525
}

lib/src/di/common/init_common_module.dart

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import 'package:ml_algo/src/cost_function/cost_function_factory.dart';
33
import 'package:ml_algo/src/cost_function/cost_function_factory_impl.dart';
44
import 'package:ml_algo/src/di/dependency_keys.dart';
55
import 'package:ml_algo/src/di/injector.dart';
6+
import 'package:ml_algo/src/extensions/injector.dart';
67
import 'package:ml_algo/src/helpers/features_target_split.dart';
78
import 'package:ml_algo/src/helpers/features_target_split_interface.dart';
89
import 'package:ml_algo/src/helpers/normalize_class_labels.dart';
@@ -30,40 +31,39 @@ typedef EncoderFactory = Encoder Function(DataFrame, Iterable<String>);
3031

3132
void initCommonModule() {
3233
injector
33-
..clearAll()
34-
..registerSingleton<EncoderFactory>(
34+
..registerSingletonIf<EncoderFactory>(
3535
() => (DataFrame data, Iterable<String> targetNames) =>
3636
Encoder.oneHot(data, featureNames: targetNames),
3737
dependencyName: oneHotEncoderFactoryKey)
3838

39-
..registerSingleton<RandomizerFactory>(
39+
..registerSingletonIf<RandomizerFactory>(
4040
() => const RandomizerFactoryImpl())
4141

42-
..registerDependency<FeaturesTargetSplit>(
42+
..registerSingletonIf<FeaturesTargetSplit>(
4343
() => featuresTargetSplit)
4444

45-
..registerSingleton<MetricFactory>(
45+
..registerSingletonIf<MetricFactory>(
4646
() => const MetricFactoryImpl())
4747

48-
..registerDependency<NormalizeClassLabels>(
48+
..registerSingletonIf<NormalizeClassLabels>(
4949
() => normalizeClassLabels)
5050

51-
..registerSingleton<LinearOptimizerFactory>(
51+
..registerSingletonIf<LinearOptimizerFactory>(
5252
() => const LinearOptimizerFactoryImpl())
5353

54-
..registerSingleton<LearningRateGeneratorFactory>(
54+
..registerSingletonIf<LearningRateGeneratorFactory>(
5555
() => const LearningRateGeneratorFactoryImpl())
5656

57-
..registerSingleton<InitialCoefficientsGeneratorFactory>(
57+
..registerSingletonIf<InitialCoefficientsGeneratorFactory>(
5858
() => const InitialCoefficientsGeneratorFactoryImpl())
5959

60-
..registerDependency<ConvergenceDetectorFactory>(
60+
..registerSingletonIf<ConvergenceDetectorFactory>(
6161
() => const ConvergenceDetectorFactoryImpl())
6262

63-
..registerSingleton<CostFunctionFactory>(
63+
..registerSingletonIf<CostFunctionFactory>(
6464
() => const CostFunctionFactoryImpl())
6565

66-
..registerSingleton<ModelAssessor<Classifier>>(() =>
66+
..registerSingletonIf<ModelAssessor<Classifier>>(() =>
6767
ClassifierAssessor(
6868
injector.get<MetricFactory>(),
6969
injector.get<EncoderFactory>(
@@ -72,7 +72,7 @@ void initCommonModule() {
7272
normalizeClassLabels,
7373
))
7474

75-
..registerSingleton<ModelAssessor<Predictor>>(() =>
75+
..registerSingletonIf<ModelAssessor<Predictor>>(() =>
7676
RegressorAssessor(
7777
injector.get<MetricFactory>(),
7878
featuresTargetSplit,

lib/src/extensions/injector.dart

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import 'package:injector/injector.dart';
2+
3+
extension InjectorExtension on Injector {
4+
/// Registers a dependency only if it doesn't exist
5+
void registerSingletonIf<T>(Builder<T> builder, {
6+
bool override = false,
7+
String dependencyName = '',
8+
}) {
9+
if (exists<T>(dependencyName: dependencyName)) {
10+
return;
11+
}
12+
13+
registerSingleton<T>(
14+
builder,
15+
override: override,
16+
dependencyName: dependencyName,
17+
);
18+
}
19+
}

lib/src/regressor/knn_regressor/_init_module.dart

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import 'package:ml_algo/src/di/common/init_common_module.dart';
2+
import 'package:ml_algo/src/extensions/injector.dart';
23
import 'package:ml_algo/src/knn_kernel/kernel_factory.dart';
34
import 'package:ml_algo/src/knn_kernel/kernel_factory_impl.dart';
45
import 'package:ml_algo/src/knn_solver/knn_solver_factory.dart';
@@ -11,14 +12,13 @@ void initKnnRegressorModule() {
1112
initCommonModule();
1213

1314
knnRegressorInjector
14-
..clearAll()
15-
..registerSingleton<KernelFactory>(
15+
..registerSingletonIf<KernelFactory>(
1616
() => const KernelFactoryImpl())
1717

18-
..registerDependency<KnnSolverFactory>(
18+
..registerSingletonIf<KnnSolverFactory>(
1919
() => const KnnSolverFactoryImpl())
2020

21-
..registerSingleton<KnnRegressorFactory>(
21+
..registerSingletonIf<KnnRegressorFactory>(
2222
() => KnnRegressorFactoryImpl(
2323
knnRegressorInjector.get<KernelFactory>(),
2424
knnRegressorInjector.get<KnnSolverFactory>(),

pubspec.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: ml_algo
22
description: Machine learning algorithms, Machine learning models performance evaluation functionality
3-
version: 15.3.3
3+
version: 15.3.4
44
homepage: https://github.com/gyrdym/ml_algo
55

66
environment:

0 commit comments

Comments
 (0)