From 6c266607d9232f98540bd201a3177b2dfdaf0ee1 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Wed, 3 Apr 2024 14:26:58 -0400 Subject: [PATCH 1/5] Add option to treat nans as mcar Signed-off-by: Adam Li --- sklearn/ensemble/_forest.py | 7 +++ sklearn/tree/_classes.py | 8 +++ sklearn/tree/_splitter.pyx | 102 ++++++++++++++++++++++++++---------- 3 files changed, 88 insertions(+), 29 deletions(-) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index b5ee64b6e708c..fd4c428880360 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -320,6 +320,7 @@ def __init__( max_samples=None, max_bins=None, store_leaf_values=False, + missing_car=False, ): super().__init__( estimator=estimator, @@ -337,6 +338,7 @@ def __init__( self.max_samples = max_samples self.max_bins = max_bins self.store_leaf_values = store_leaf_values + self.missing_car = missing_car def apply(self, X): """ @@ -1085,6 +1087,7 @@ def __init__( max_samples=None, max_bins=None, store_leaf_values=False, + missing_car=False, ): super().__init__( estimator=estimator, @@ -1100,6 +1103,7 @@ def __init__( max_samples=max_samples, max_bins=max_bins, store_leaf_values=store_leaf_values, + missing_car=missing_car, ) @staticmethod @@ -2111,6 +2115,7 @@ def __init__( max_bins=None, store_leaf_values=False, monotonic_cst=None, + missing_car=False, ): super().__init__( estimator=DecisionTreeClassifier(), @@ -2128,6 +2133,7 @@ def __init__( "ccp_alpha", "store_leaf_values", "monotonic_cst", + "missing_car", ), bootstrap=bootstrap, oob_score=oob_score, @@ -2139,6 +2145,7 @@ def __init__( max_samples=max_samples, max_bins=max_bins, store_leaf_values=store_leaf_values, + missing_car=missing_car, ) self.criterion = criterion diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 2124cd76c69c8..711132dd4dc1b 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -149,6 +149,7 @@ def __init__( ccp_alpha=0.0, store_leaf_values=False, monotonic_cst=None, + missing_car=False, ): self.criterion = criterion self.splitter = splitter @@ -164,6 +165,7 @@ def __init__( self.ccp_alpha = ccp_alpha self.store_leaf_values = store_leaf_values self.monotonic_cst = monotonic_cst + self.missing_car = missing_car def get_depth(self): """Return the depth of the decision tree. 
@@ -532,6 +534,7 @@ def _build_tree(
             min_weight_leaf,
             random_state,
             monotonic_cst,
+            self.missing_car,
         )
 
         if is_classifier(self):
@@ -614,6 +617,7 @@ def _update_tree(self, X, y, sample_weight):
             min_weight_leaf,
             random_state,
             monotonic_cst,
+            self.missing_car,
         )
 
         # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
@@ -1280,6 +1284,7 @@ def __init__(
         ccp_alpha=0.0,
         store_leaf_values=False,
         monotonic_cst=None,
+        missing_car=False,
     ):
         super().__init__(
             criterion=criterion,
@@ -1296,6 +1301,7 @@ def __init__(
             monotonic_cst=monotonic_cst,
             ccp_alpha=ccp_alpha,
             store_leaf_values=store_leaf_values,
+            missing_car=missing_car,
         )
 
     @_fit_context(prefer_skip_nested_validation=True)
@@ -1784,6 +1790,7 @@ def __init__(
         ccp_alpha=0.0,
         store_leaf_values=False,
         monotonic_cst=None,
+        missing_car=False,
     ):
         super().__init__(
             criterion=criterion,
@@ -1799,6 +1806,7 @@ def __init__(
             ccp_alpha=ccp_alpha,
             store_leaf_values=store_leaf_values,
             monotonic_cst=monotonic_cst,
+            missing_car=missing_car,
         )
 
     @_fit_context(prefer_skip_nested_validation=True)
diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx
index d3c8fa1f98e83..6919638b77769 100644
--- a/sklearn/tree/_splitter.pyx
+++ b/sklearn/tree/_splitter.pyx
@@ -148,6 +148,7 @@ cdef class Splitter(BaseSplitter):
         float64_t min_weight_leaf,
         object random_state,
         const int8_t[:] monotonic_cst,
+        bint missing_car,
         *argv
     ):
         """
@@ -173,8 +174,17 @@ cdef class Splitter(BaseSplitter):
             The user inputted random state to be used for pseudo-randomness
 
         monotonic_cst : const int8_t[:]
-            Monotonicity constraints
+            Indicates the monotonicity constraint to enforce on each feature.
+              - 1: monotonic increase
+              - 0: no constraint
+              - -1: monotonic decrease
+            If monotonic_cst is None, no constraints are applied.
+
+        missing_car : bool
+            Indicates whether missing values should be assumed to be missing
+            completely at random (MCAR). If so, the missing values are
+            randomly assigned to the left or right child of the split.
         """
         self.criterion = criterion
 
@@ -187,14 +197,18 @@ cdef class Splitter(BaseSplitter):
         self.random_state = random_state
         self.monotonic_cst = monotonic_cst
         self.with_monotonic_cst = monotonic_cst is not None
+        self.missing_car = missing_car
 
     def __reduce__(self):
-        return (type(self), (self.criterion,
-                             self.max_features,
-                             self.min_samples_leaf,
-                             self.min_weight_leaf,
-                             self.random_state,
-                             self.monotonic_cst.base if self.monotonic_cst is not None else None), self.__getstate__())
+        return (type(self), (
+            self.criterion,
+            self.max_features,
+            self.min_samples_leaf,
+            self.min_weight_leaf,
+            self.random_state,
+            self.monotonic_cst.base if self.monotonic_cst is not None else None,
+            self.missing_car,
+        ), self.__getstate__())
 
     cdef int init(
         self,
@@ -562,10 +576,13 @@ cdef inline intp_t node_split_best(
     # The second search will have all the missing values going to the left node.
     # If there are no missing values, then we search only once for the most
     # optimal split.
-    n_searches = 2 if has_missing else 1
+    n_searches = 2 if has_missing and not self.missing_car else 1
 
     for i in range(n_searches):
-        missing_go_to_left = i == 1
+        if self.missing_car:
+            missing_go_to_left = rand_int(0, 2, random_state)
+        else:
+            missing_go_to_left = i == 1
         criterion.missing_go_to_left = missing_go_to_left
         criterion.reset()
 
@@ -645,26 +662,18 @@ cdef inline intp_t node_split_best(
 
     # Evaluate when there are missing values and all missing values goes
     # to the right node and non-missing values goes to the left node.
- if has_missing: - n_left, n_right = end - start - n_missing, n_missing - p = end - n_missing - missing_go_to_left = 0 - - if not (n_left < min_samples_leaf or n_right < min_samples_leaf): - criterion.missing_go_to_left = missing_go_to_left - criterion.update(p) - - if not ((criterion.weighted_n_left < min_weight_leaf) or - (criterion.weighted_n_right < min_weight_leaf)): - current_proxy_improvement = criterion.proxy_impurity_improvement() - - if current_proxy_improvement > best_proxy_improvement: - best_proxy_improvement = current_proxy_improvement - current_split.threshold = INFINITY - current_split.missing_go_to_left = missing_go_to_left - current_split.n_missing = n_missing - current_split.pos = p - best_split = current_split + if has_missing and not self.missing_car: + evaluate_missing_values_to_right( + start, + end, + n_missing, + min_samples_leaf, + min_weight_leaf, + criterion, + current_split, + best_split, + best_proxy_improvement + ) # Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end] if best_split.pos < end: @@ -706,6 +715,41 @@ cdef inline intp_t node_split_best( return 0 +cdef inline void evaluate_missing_values_to_right( + intp_t start, + intp_t end, + intp_t n_missing, + intp_t min_samples_leaf, + double min_weight_leaf, + BaseCriterion criterion, + SplitRecord current_split, + SplitRecord best_split, + double best_proxy_improvement +) nogil: + cdef intp_t n_left, n_right, p + cdef intp_t missing_go_to_left = 0 + + n_left = end - start - n_missing + n_right = n_missing + p = end - n_missing + + if not (n_left < min_samples_leaf or n_right < min_samples_leaf): + criterion.missing_go_to_left = missing_go_to_left + criterion.update(p) + + if not ((criterion.weighted_n_left < min_weight_leaf) or + (criterion.weighted_n_right < min_weight_leaf)): + current_proxy_improvement = criterion.proxy_impurity_improvement() + + if current_proxy_improvement > best_proxy_improvement: + best_proxy_improvement = current_proxy_improvement + current_split.threshold = INFINITY + current_split.missing_go_to_left = missing_go_to_left + current_split.n_missing = n_missing + current_split.pos = p + best_split = current_split + + # Sort n-element arrays pointed to by feature_values and samples, simultaneously, # by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). cdef inline void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: From b9c6bb7b54395d461cf4e2bef16360144edd1b7a Mon Sep 17 00:00:00 2001 From: Adam Li Date: Wed, 3 Apr 2024 18:08:04 -0400 Subject: [PATCH 2/5] Add feature to allow treating nans randomly Signed-off-by: Adam Li --- sklearn/tree/_splitter.pxd | 1 + sklearn/tree/_splitter.pyx | 13 +++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 9a8ae9da81b52..94dce58254663 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -96,6 +96,7 @@ cdef class Splitter(BaseSplitter): cdef public Criterion criterion # Impurity criterion cdef const float64_t[:, ::1] y + cdef bint missing_car # Monotonicity constraints for each feature. # The encoding is as follows: diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 6919638b77769..d91d4ece21468 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -576,10 +576,10 @@ cdef inline intp_t node_split_best( # The second search will have all the missing values going to the left node. 
     # If there are no missing values, then we search only once for the most
     # optimal split.
-    n_searches = 2 if has_missing and not self.missing_car else 1
+    n_searches = 2 if has_missing and not splitter.missing_car else 1
 
     for i in range(n_searches):
-        if self.missing_car:
+        if splitter.missing_car:
             missing_go_to_left = rand_int(0, 2, random_state)
         else:
             missing_go_to_left = i == 1
@@ -662,7 +662,7 @@ cdef inline intp_t node_split_best(
 
     # Evaluate when there are missing values and all missing values goes
     # to the right node and non-missing values goes to the left node.
-    if has_missing and not self.missing_car:
+    if has_missing and not splitter.missing_car:
         evaluate_missing_values_to_right(
             start,
             end,
@@ -720,14 +720,15 @@ cdef inline void evaluate_missing_values_to_right(
     intp_t end,
     intp_t n_missing,
     intp_t min_samples_leaf,
-    double min_weight_leaf,
-    BaseCriterion criterion,
+    float64_t min_weight_leaf,
+    Criterion criterion,
     SplitRecord current_split,
     SplitRecord best_split,
-    double best_proxy_improvement
+    float64_t best_proxy_improvement
 ) nogil:
     cdef intp_t n_left, n_right, p
     cdef intp_t missing_go_to_left = 0
+    cdef float64_t current_proxy_improvement = -INFINITY
 
     n_left = end - start - n_missing
     n_right = n_missing
     p = end - n_missing

From 64738a3e18626941ff5df37a06c7a2f4ba559f89 Mon Sep 17 00:00:00 2001
From: Adam Li
Date: Wed, 3 Apr 2024 20:46:30 -0400
Subject: [PATCH 3/5] Fix unit tests

Signed-off-by: Adam Li
---
 sklearn/ensemble/_forest.py     |  9 +++++
 sklearn/tree/_classes.py        |  9 +++++
 sklearn/tree/_splitter.pyx      | 72 ++++++++++-----------------
 sklearn/tree/tests/test_tree.py |  4 +-
 4 files changed, 43 insertions(+), 51 deletions(-)

diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py
index fd4c428880360..c5627c60a38ca 100644
--- a/sklearn/ensemble/_forest.py
+++ b/sklearn/ensemble/_forest.py
@@ -1974,6 +1974,9 @@ class RandomForestClassifier(ForestClassifier):
 
         .. versionadded:: 1.4
 
+    missing_car : bool, default=False
+        Whether the missing values are missing completely at random (MCAR).
+
     Attributes
     ----------
     estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier`
@@ -2749,6 +2752,9 @@ class ExtraTreesClassifier(ForestClassifier):
 
         .. versionadded:: 1.4
 
+    missing_car : bool, default=False
+        Whether the missing values are missing completely at random (MCAR).
+
     Attributes
     ----------
     estimator_ : :class:`~sklearn.tree.ExtraTreeClassifier`
@@ -2879,6 +2885,7 @@ def __init__(
         max_bins=None,
         store_leaf_values=False,
         monotonic_cst=None,
+        missing_car=False,
     ):
         super().__init__(
             estimator=ExtraTreeClassifier(),
@@ -2896,6 +2903,7 @@ def __init__(
                 "ccp_alpha",
                 "store_leaf_values",
                 "monotonic_cst",
+                "missing_car",
             ),
             bootstrap=bootstrap,
             oob_score=oob_score,
@@ -2907,6 +2915,7 @@ def __init__(
             max_samples=max_samples,
             max_bins=max_bins,
             store_leaf_values=store_leaf_values,
+            missing_car=missing_car,
         )
 
         self.criterion = criterion
diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py
index 711132dd4dc1b..ea3683025cf6f 100644
--- a/sklearn/tree/_classes.py
+++ b/sklearn/tree/_classes.py
@@ -129,6 +129,7 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
         "ccp_alpha": [Interval(Real, 0.0, None, closed="left")],
         "store_leaf_values": ["boolean"],
         "monotonic_cst": ["array-like", None],
+        "missing_car": ["boolean"],
     }
 
     @abstractmethod
@@ -1156,6 +1157,9 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):
 
         .. versionadded:: 1.4
 
+    missing_car : bool, default=False
+        Whether the missing values are missing completely at random (MCAR).
+
     Attributes
     ----------
     classes_ : ndarray of shape (n_classes,) or list of ndarray
@@ -2062,6 +2066,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier):
 
         .. versionadded:: 1.4
 
+    missing_car : bool, default=False
+        Whether the missing values are missing completely at random (MCAR).
+
     Attributes
     ----------
     classes_ : ndarray of shape (n_classes,) or list of ndarray
@@ -2176,6 +2183,7 @@ def __init__(
         ccp_alpha=0.0,
         store_leaf_values=False,
         monotonic_cst=None,
+        missing_car=False,
    ):
         super().__init__(
             criterion=criterion,
@@ -2192,6 +2200,7 @@ def __init__(
             ccp_alpha=ccp_alpha,
             store_leaf_values=store_leaf_values,
             monotonic_cst=monotonic_cst,
+            missing_car=missing_car,
         )
diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx
index d91d4ece21468..5469845e8fe80 100644
--- a/sklearn/tree/_splitter.pyx
+++ b/sklearn/tree/_splitter.pyx
@@ -579,10 +579,10 @@ cdef inline intp_t node_split_best(
     n_searches = 2 if has_missing and not splitter.missing_car else 1
 
     for i in range(n_searches):
-        if splitter.missing_car:
-            missing_go_to_left = rand_int(0, 2, random_state)
-        else:
+        if not splitter.missing_car:
             missing_go_to_left = i == 1
+        else:
+            missing_go_to_left = rand_int(0, 2, random_state)
         criterion.missing_go_to_left = missing_go_to_left
         criterion.reset()
 
@@ -663,17 +663,25 @@ cdef inline intp_t node_split_best(
     # Evaluate when there are missing values and all missing values goes
     # to the right node and non-missing values goes to the left node.
     if has_missing and not splitter.missing_car:
-        evaluate_missing_values_to_right(
-            start,
-            end,
-            n_missing,
-            min_samples_leaf,
-            min_weight_leaf,
-            criterion,
-            current_split,
-            best_split,
-            best_proxy_improvement
-        )
+        n_left, n_right = end - start - n_missing, n_missing
+        p = end - n_missing
+        missing_go_to_left = 0
+
+        if not (n_left < min_samples_leaf or n_right < min_samples_leaf):
+            criterion.missing_go_to_left = missing_go_to_left
+            criterion.update(p)
+
+            if not ((criterion.weighted_n_left < min_weight_leaf) or
+                    (criterion.weighted_n_right < min_weight_leaf)):
+                current_proxy_improvement = criterion.proxy_impurity_improvement()
+
+                if current_proxy_improvement > best_proxy_improvement:
+                    best_proxy_improvement = current_proxy_improvement
+                    current_split.threshold = INFINITY
+                    current_split.missing_go_to_left = missing_go_to_left
+                    current_split.n_missing = n_missing
+                    current_split.pos = p
+                    best_split = current_split
 
     # Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end]
     if best_split.pos < end:
@@ -715,42 +723,6 @@ cdef inline intp_t node_split_best(
     return 0
 
 
-cdef inline void evaluate_missing_values_to_right(
-    intp_t start,
-    intp_t end,
-    intp_t n_missing,
-    intp_t min_samples_leaf,
-    float64_t min_weight_leaf,
-    Criterion criterion,
-    SplitRecord current_split,
-    SplitRecord best_split,
-    float64_t best_proxy_improvement
-) nogil:
-    cdef intp_t n_left, n_right, p
-    cdef intp_t missing_go_to_left = 0
-    cdef float64_t current_proxy_improvement = -INFINITY
-
-    n_left = end - start - n_missing
-    n_right = n_missing
-    p = end - n_missing
-
-    if not (n_left < min_samples_leaf or n_right < min_samples_leaf):
-        criterion.missing_go_to_left = missing_go_to_left
-        criterion.update(p)
-
-        if not ((criterion.weighted_n_left < min_weight_leaf) or
-                (criterion.weighted_n_right < min_weight_leaf)):
-            current_proxy_improvement = criterion.proxy_impurity_improvement()
-
-            if current_proxy_improvement > best_proxy_improvement:
-                best_proxy_improvement = current_proxy_improvement
-                current_split.threshold = INFINITY
-                current_split.missing_go_to_left = missing_go_to_left
-                current_split.n_missing = n_missing
-                current_split.pos = p
-                best_split = current_split
-
-
 # Sort n-element arrays pointed to by feature_values and samples, simultaneously,
 # by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997).
 cdef inline void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil:
diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index ef26ec1be0b1d..3b814ff35f038 100644
--- a/sklearn/tree/tests/test_tree.py
+++ b/sklearn/tree/tests/test_tree.py
@@ -2349,7 +2349,9 @@ def test_splitter_serializable(Splitter):
     n_outputs, n_classes = 2, np.array([3, 2], dtype=np.intp)
 
     criterion = CRITERIA_CLF["gini"](n_outputs, n_classes)
-    splitter = Splitter(criterion, max_features, 5, 0.5, rng, monotonic_cst=None)
+    splitter = Splitter(
+        criterion, max_features, 5, 0.5, rng, monotonic_cst=None, missing_car=False
+    )
 
     splitter_serialize = pickle.dumps(splitter)
 
     splitter_back = pickle.loads(splitter_serialize)

From 5f0ac4c61dcb8f14ad4066e3a0d741fe08debe0d Mon Sep 17 00:00:00 2001
From: Adam Li
Date: Wed, 3 Apr 2024 20:55:24 -0400
Subject: [PATCH 4/5] MCAR for decision tree

Signed-off-by: Adam Li
---
 sklearn/tree/tests/test_tree.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index 3b814ff35f038..bbb3d8bda9a48 100644
--- a/sklearn/tree/tests/test_tree.py
+++ b/sklearn/tree/tests/test_tree.py
@@ -2602,6 +2602,35 @@ def test_missing_value_is_predictive():
     assert tree.score(X_test, y_test) >= 0.85
 
 
+def test_missing_value_is_not_predictive_with_mcar():
+    """Check that the tree does not learn when missing values are forced to be unpredictive."""
+    rng = np.random.RandomState(0)
+    n_samples = 1000
+
+    X = rng.standard_normal(size=(n_samples, 10))
+    y = rng.randint(0, high=2, size=n_samples)
+
+    # Create a feature whose missingness pattern tracks `y`, up to some noise
+    X_random_mask = rng.choice([False, True], size=n_samples, p=[0.9, 0.1])
+    y_mask = y.copy().astype(bool)
+    y_mask[X_random_mask] = ~y_mask[X_random_mask]
+
+    X_predictive = rng.standard_normal(size=n_samples)
+    X_predictive[y_mask] = np.nan
+
+    X[:, 5] = X_predictive
+
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
+    tree = DecisionTreeClassifier(random_state=rng, missing_car=True).fit(
+        X_train, y_train
+    )
+    non_mcar_tree = DecisionTreeClassifier(random_state=rng, missing_car=False).fit(
+        X_train, y_train
+    )
+
+    assert non_mcar_tree.score(X_test, y_test) > tree.score(X_test, y_test) + 0.2
+
+
 @pytest.mark.parametrize(
     "make_data, Tree",
     [

From 9816760c5a801d9a5e1e96021e5cdf91f5a14260 Mon Sep 17 00:00:00 2001
From: Adam Li
Date: Wed, 3 Apr 2024 21:13:50 -0400
Subject: [PATCH 5/5] Fix lint

Signed-off-by: Adam Li
---
 sklearn/tree/tests/test_tree.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index bbb3d8bda9a48..0c722be827b36 100644
--- a/sklearn/tree/tests/test_tree.py
+++ b/sklearn/tree/tests/test_tree.py
@@ -2603,7 +2603,9 @@ def test_missing_value_is_predictive():
 
 
 def test_missing_value_is_not_predictive_with_mcar():
-    """Check that the tree does not learn when missing values are forced to be unpredictive."""
+    """Check that the tree does not learn when missing values are forced to be
+    unpredictive.
+    """
     rng = np.random.RandomState(0)
     n_samples = 1000
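
A minimal end-to-end sketch of the new `missing_car` flag (illustrative only; it assumes this patch series is applied on top of a tree build that already accepts NaN inputs, as the touched code does). It mirrors the new unit test: the missingness pattern of one feature encodes the label, so a default tree can exploit it, while a tree fit with missing_car=True routes NaNs to a random child at each split and should score near chance:

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

rng = np.random.RandomState(0)
n_samples = 1000

X = rng.standard_normal(size=(n_samples, 10))
y = rng.randint(0, high=2, size=n_samples)

# Make the missingness of feature 5 track the label, with 10% label noise,
# so the NaN pattern itself carries signal.
noise = rng.choice([False, True], size=n_samples, p=[0.9, 0.1])
informative = y.astype(bool)
informative[noise] = ~informative[noise]
X[informative, 5] = np.nan

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Default behavior: the split search tries sending NaNs both left and right,
# so the tree can exploit the informative missingness pattern.
informed = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# missing_car=True (added by this series): NaNs go to a random child at each
# candidate split, so the missingness pattern is unusable.
car = DecisionTreeClassifier(random_state=0, missing_car=True).fit(X_train, y_train)

print(informed.score(X_test, y_test))  # high: close to 1 minus the noise rate
print(car.score(X_test, y_test))       # near chance (~0.5)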