Skip to content

Commit cf3dff5

Browse files
authored
DecisionTree: Gini index split criteria added (#215)
1 parent 56bc063 commit cf3dff5

File tree

67 files changed

+971
-585
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+971
-585
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## 16.8.0
4+
- DecisionTreeClassifier:
5+
- Added Gini index assessor type
6+
37
## 16.7.2
48
- DecisionTreeClassifier:
59
- TreeNode fields renamed
Lines changed: 1 addition & 1 deletion
Loading

e2e/decision_tree_classifier/pima_indians_tree.svg

Lines changed: 1 addition & 1 deletion
Loading

lib/ml_algo.dart

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ export 'package:ml_algo/src/regressor/knn_regressor/knn_regressor.dart';
1414
export 'package:ml_algo/src/regressor/linear_regressor/linear_regressor.dart';
1515
export 'package:ml_algo/src/retrieval/kd_tree/kd_tree.dart';
1616
export 'package:ml_algo/src/retrieval/kd_tree/kd_tree_split_strategy.dart';
17+
export 'package:ml_algo/src/tree_trainer/tree_assessor/tree_assessor_type.dart';

lib/src/classifier/decision_tree_classifier/_init_module.dart

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@ import 'package:injector/injector.dart';
22
import 'package:ml_algo/src/classifier/decision_tree_classifier/_injector.dart';
33
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart';
44
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier_factory_impl.dart';
5+
import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator.dart';
56
import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator_factory.dart';
67
import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator_factory_impl.dart';
8+
import 'package:ml_algo/src/common/distribution_calculator/distribution_calculator_impl.dart';
79
import 'package:ml_algo/src/di/common/init_common_module.dart';
810
import 'package:ml_algo/src/extensions/injector.dart';
911
import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector_factory.dart';
1012
import 'package:ml_algo/src/tree_trainer/leaf_detector/leaf_detector_factory_impl.dart';
1113
import 'package:ml_algo/src/tree_trainer/leaf_label/leaf_label_factory_factory.dart';
1214
import 'package:ml_algo/src/tree_trainer/leaf_label/leaf_label_factory_factory_impl.dart';
13-
import 'package:ml_algo/src/tree_trainer/split_assessor/split_assessor_factory.dart';
14-
import 'package:ml_algo/src/tree_trainer/split_assessor/split_assessor_factory_impl.dart';
1515
import 'package:ml_algo/src/tree_trainer/split_selector/split_selector_factory.dart';
1616
import 'package:ml_algo/src/tree_trainer/split_selector/split_selector_factory_impl.dart';
1717
import 'package:ml_algo/src/tree_trainer/splitter/nominal_splitter/nominal_splitter_factory.dart';
@@ -20,6 +20,8 @@ import 'package:ml_algo/src/tree_trainer/splitter/numerical_splitter/numerical_s
2020
import 'package:ml_algo/src/tree_trainer/splitter/numerical_splitter/numerical_splitter_factory_impl.dart';
2121
import 'package:ml_algo/src/tree_trainer/splitter/splitter_factory.dart';
2222
import 'package:ml_algo/src/tree_trainer/splitter/splitter_factory_impl.dart';
23+
import 'package:ml_algo/src/tree_trainer/tree_assessor/tree_assessor_factory.dart';
24+
import 'package:ml_algo/src/tree_trainer/tree_assessor/tree_assessor_factory_impl.dart';
2325
import 'package:ml_algo/src/tree_trainer/tree_trainer_factory.dart';
2426
import 'package:ml_algo/src/tree_trainer/tree_trainer_factory_impl.dart';
2527

@@ -29,25 +31,27 @@ Injector initDecisionTreeModule() {
2931
return decisionTreeInjector
3032
..registerSingletonIf<DistributionCalculatorFactory>(
3133
() => const DistributionCalculatorFactoryImpl())
34+
..registerSingletonIf<DistributionCalculator>(
35+
() => const DistributionCalculatorImpl())
3236
..registerSingletonIf<NominalTreeSplitterFactory>(
3337
() => const NominalTreeSplitterFactoryImpl())
3438
..registerSingletonIf<NumericalTreeSplitterFactory>(
3539
() => const NumericalTreeSplitterFactoryImpl())
36-
..registerSingletonIf<TreeSplitAssessorFactory>(
37-
() => const TreeSplitAssessorFactoryImpl())
40+
..registerSingletonIf<TreeAssessorFactory>(() => TreeAssessorFactoryImpl(
41+
decisionTreeInjector.get<DistributionCalculator>()))
3842
..registerSingletonIf<TreeSplitterFactory>(() => TreeSplitterFactoryImpl(
39-
decisionTreeInjector.get<TreeSplitAssessorFactory>(),
43+
decisionTreeInjector.get<TreeAssessorFactory>(),
4044
decisionTreeInjector.get<NominalTreeSplitterFactory>(),
4145
decisionTreeInjector.get<NumericalTreeSplitterFactory>(),
4246
))
4347
..registerSingletonIf<TreeSplitSelectorFactory>(
4448
() => TreeSplitSelectorFactoryImpl(
45-
decisionTreeInjector.get<TreeSplitAssessorFactory>(),
49+
decisionTreeInjector.get<TreeAssessorFactory>(),
4650
decisionTreeInjector.get<TreeSplitterFactory>(),
4751
))
4852
..registerSingletonIf<TreeLeafDetectorFactory>(
4953
() => TreeLeafDetectorFactoryImpl(
50-
decisionTreeInjector.get<TreeSplitAssessorFactory>(),
54+
decisionTreeInjector.get<TreeAssessorFactory>(),
5155
))
5256
..registerSingletonIf<TreeLeafLabelFactoryFactory>(
5357
() => TreeLeafLabelFactoryFactoryImpl(

lib/src/classifier/decision_tree_classifier/decision_tree_classifier.dart

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@ import 'package:ml_algo/src/common/constants/default_parameters/common.dart';
77
import 'package:ml_algo/src/common/serializable/serializable.dart';
88
import 'package:ml_algo/src/model_selection/assessable.dart';
99
import 'package:ml_algo/src/predictor/retrainable.dart';
10+
import 'package:ml_algo/src/tree_trainer/tree_assessor/tree_assessor_type.dart';
1011
import 'package:ml_dataframe/ml_dataframe.dart';
1112
import 'package:ml_linalg/dtype.dart';
1213

1314
/// A class that performs decision tree-based classification
1415
///
1516
/// Decision tree is an algorithm that recursively splits the input data into
16-
/// subsets until the subsets conforming certain stop criteria are found.
17+
/// subsets until the subsets conforming certain stop criteria are met.
1718
///
1819
/// Process of forming such a recursive subsets structure is called
1920
/// decision tree learning. Once a decision tree learned, it may be used to
@@ -45,13 +46,26 @@ abstract class DecisionTreeClassifier
4546
/// equal to [minSamplesCount] observations, the node turns into the leaf.
4647
///
4748
/// [maxDepth] A maximum number of decision tree levels.
49+
///
50+
/// [assessorType] Defines an assessment type that will be applied to a subset
51+
/// of data in order to decide how to split the subset while building the tree.
52+
/// Default value is [TreeAssessorType.gini]
53+
///
54+
/// Possible values of [assessorType] :
55+
///
56+
/// [TreeAssessorType.gini] The algorithm makes a decision on how to split a
57+
/// subset of data based on the [Gini index](https://en.wikipedia.org/wiki/Gini_coefficient)
58+
///
59+
/// [TreeAssessorType.majority] The algorithm makes a decision on how to split a
60+
/// subset of data based on a major class.
4861
factory DecisionTreeClassifier(
4962
DataFrame trainData,
5063
String targetName, {
5164
num minError = 0.5,
5265
int minSamplesCount = 1,
5366
int maxDepth = 10,
5467
DType dtype = dTypeDefaultValue,
68+
TreeAssessorType assessorType = TreeAssessorType.gini,
5569
}) =>
5670
initDecisionTreeModule().get<DecisionTreeClassifierFactory>().create(
5771
trainData,
@@ -60,6 +74,7 @@ abstract class DecisionTreeClassifier
6074
minError,
6175
minSamplesCount,
6276
maxDepth,
77+
assessorType,
6378
);
6479

6580
/// Restores previously fitted classifier instance from the given [json]
@@ -125,6 +140,10 @@ abstract class DecisionTreeClassifier
125140
/// The value is read-only, it's a hyperparameter of the model
126141
int get maxDepth;
127142

143+
/// An assessment type that was applied to a subset of data in order to
144+
/// decide how to split the subset while building the tree
145+
TreeAssessorType get assessorType;
146+
128147
/// Saves tree as SVG-image. Example:
129148
///
130149
/// ```dart

lib/src/classifier/decision_tree_classifier/decision_tree_classifier_factory.dart

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import 'package:ml_algo/ml_algo.dart';
1+
import 'package:ml_algo/src/classifier/decision_tree_classifier/decision_tree_classifier.dart';
2+
import 'package:ml_algo/src/tree_trainer/tree_assessor/tree_assessor_type.dart';
23
import 'package:ml_dataframe/ml_dataframe.dart';
34
import 'package:ml_linalg/dtype.dart';
45

@@ -10,6 +11,7 @@ abstract class DecisionTreeClassifierFactory {
1011
num minError,
1112
int minSamplesCount,
1213
int maxDepth,
14+
TreeAssessorType assessorType,
1315
);
1416

1517
DecisionTreeClassifier fromJson(String json);

lib/src/classifier/decision_tree_classifier/decision_tree_classifier_factory_impl.dart

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import 'package:ml_algo/src/helpers/validate_tree_solver_max_depth.dart';
77
import 'package:ml_algo/src/helpers/validate_tree_solver_min_error.dart';
88
import 'package:ml_algo/src/helpers/validate_tree_solver_min_samples_count.dart';
99
import 'package:ml_algo/src/tree_trainer/leaf_label/leaf_label_factory_type.dart';
10-
import 'package:ml_algo/src/tree_trainer/split_assessor/split_assessor_type.dart';
10+
import 'package:ml_algo/src/tree_trainer/tree_assessor/tree_assessor_type.dart';
1111
import 'package:ml_algo/src/tree_trainer/split_selector/split_selector_type.dart';
1212
import 'package:ml_algo/src/tree_trainer/splitter/splitter_type.dart';
1313
import 'package:ml_algo/src/tree_trainer/tree_trainer_factory.dart';
@@ -29,6 +29,7 @@ class DecisionTreeClassifierFactoryImpl
2929
num minError,
3030
int minSamplesCount,
3131
int maxDepth,
32+
TreeAssessorType assessorType,
3233
) {
3334
validateTreeSolverMinError(minError);
3435
validateTreeSolversMinSamplesCount(minSamplesCount);
@@ -41,10 +42,10 @@ class DecisionTreeClassifierFactoryImpl
4142
minError,
4243
minSamplesCount,
4344
maxDepth,
44-
TreeSplitAssessorType.majority,
45+
assessorType,
4546
TreeLeafLabelFactoryType.majority,
4647
TreeSplitSelectorType.greedy,
47-
TreeSplitAssessorType.majority,
48+
assessorType,
4849
TreeSplitterType.greedy,
4950
);
5051
final treeRootNode = trainer.train(trainData.toMatrix(dtype));
@@ -55,6 +56,7 @@ class DecisionTreeClassifierFactoryImpl
5556
maxDepth,
5657
treeRootNode,
5758
targetName,
59+
assessorType,
5860
dtype,
5961
);
6062
}

lib/src/classifier/decision_tree_classifier/decision_tree_classifier_impl.dart

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ import 'package:ml_algo/src/common/constants/common_json_keys.dart';
1212
import 'package:ml_algo/src/common/json_converter/dtype_json_converter.dart';
1313
import 'package:ml_algo/src/common/serializable/serializable_mixin.dart';
1414
import 'package:ml_algo/src/tree_trainer/leaf_label/leaf_label.dart';
15+
import 'package:ml_algo/src/tree_trainer/tree_assessor/helpers/from_tree_assessor_type_json.dart';
16+
import 'package:ml_algo/src/tree_trainer/tree_assessor/helpers/to_tree_assessor_type_json.dart';
17+
import 'package:ml_algo/src/tree_trainer/tree_assessor/tree_assessor_type.dart';
1518
import 'package:ml_algo/src/tree_trainer/tree_node/tree_node.dart';
1619
import 'package:ml_dataframe/ml_dataframe.dart';
1720
import 'package:ml_linalg/dtype.dart';
@@ -34,6 +37,7 @@ class DecisionTreeClassifierImpl
3437
this.maxDepth,
3538
this.treeRootNode,
3639
this.targetColumnName,
40+
this.assessorType,
3741
this.dtype, {
3842
this.schemaVersion = decisionTreeClassifierJsonSchemaVersion,
3943
});
@@ -79,6 +83,13 @@ class DecisionTreeClassifierImpl
7983
@JsonKey(name: jsonSchemaVersionJsonKey)
8084
final int schemaVersion;
8185

86+
@override
87+
@JsonKey(
88+
name: decisionTreeClassifierAssessorTypeJsonKey,
89+
toJson: toTreeAssessorTypeJson,
90+
fromJson: fromTreeAssessorTypeJson)
91+
final TreeAssessorType assessorType;
92+
8293
@override
8394
DataFrame predict(DataFrame features) {
8495
final predictedLabels = features
@@ -144,6 +155,7 @@ class DecisionTreeClassifierImpl
144155
minError,
145156
minSamplesCount,
146157
maxDepth,
158+
assessorType,
147159
);
148160
}
149161

lib/src/classifier/decision_tree_classifier/decision_tree_classifier_impl.g.dart

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)