Skip to content

Commit f0be6a5

Browse files
authored
Fixed suggested_sampling for imbalanced data in RecommendationTransformer (#193)
1 parent f79a4d6 commit f0be6a5

File tree

2 files changed

+62
-15
lines changed

2 files changed

+62
-15
lines changed

ads/dataset/recommendation_transformer.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -253,14 +253,15 @@ def _get_recommendations(self, df):
253253
)[1]
254254
minor_majority_ratio = minority_class_len / majority_class_len
255255

256-
# Suggest down sampling if minor_majority_ratio is 1:1000
257-
suggested_sampling = (
258-
"Do nothing"
259-
if len(df) <= utils.MAX_LEN_FOR_UP_SAMPLING
260-
else "Down-sample"
261-
if minor_majority_ratio >= utils.MIN_RATIO_FOR_DOWN_SAMPLING
262-
else "Do nothing"
263-
)
256+
# up-sample if length of dataframe is less than or equal to MAX_LEN_FOR_UP_SAMPLING = 5000
257+
# otherwise, down-sample if minor_majority_ratio is greater than or equal to MIN_RATIO_FOR_DOWN_SAMPLING = 1/20; else do nothing
258+
if len(df) <= utils.MAX_LEN_FOR_UP_SAMPLING:
259+
suggested_sampling = "Up-sample"
260+
elif minor_majority_ratio >= utils.MIN_RATIO_FOR_DOWN_SAMPLING:
261+
suggested_sampling = "Down-sample"
262+
else:
263+
suggested_sampling = "Do nothing"
264+
264265
self._build_recommendation(
265266
recommendations,
266267
"fix_imbalance",

tests/unitary/with_extras/dataset/test_dataset_target.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,31 @@
33
# Copyright (c) 2023 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

6-
import os
6+
from mock import patch
77
from typing import Tuple
8+
import os
89
import pandas as pd
910
import pytest
11+
12+
from ads.common import utils
1013
from ads.dataset.classification_dataset import BinaryClassificationDataset
1114
from ads.dataset.dataset_with_target import ADSDatasetWithTarget
1215
from ads.dataset.pipeline import TransformerPipeline
1316
from ads.dataset.target import TargetVariable
1417

1518

1619
class TestADSDatasetTarget:
20+
def get_data_path(self):
21+
current_dir = os.path.dirname(os.path.abspath(__file__))
22+
return os.path.join(current_dir, "data", "orcl_attrition.csv")
23+
1724
def test_initialize_dataset_target(self):
1825
employees = ADSDatasetWithTarget(
1926
df=pd.read_csv(self.get_data_path()),
2027
target="Attrition",
2128
name="test_dataset",
2229
description="test_description",
23-
storage_options={'config':{},'region':'us-ashburn-1'}
30+
storage_options={"config": {}, "region": "us-ashburn-1"},
2431
)
2532

2633
assert isinstance(employees, ADSDatasetWithTarget)
@@ -32,8 +39,8 @@ def test_dataset_target_from_dataframe(self):
3239
employees = ADSDatasetWithTarget.from_dataframe(
3340
df=pd.read_csv(self.get_data_path()),
3441
target="Attrition",
35-
storage_options={'config':{},'region':'us-ashburn-1'}
36-
).set_positive_class('Yes')
42+
storage_options={"config": {}, "region": "us-ashburn-1"},
43+
).set_positive_class("Yes")
3744

3845
assert isinstance(employees, BinaryClassificationDataset)
3946
self.assert_dataset(employees)
@@ -65,6 +72,45 @@ def assert_dataset(self, dataset):
6572
assert "type_discovery" in dataset.init_kwargs
6673
assert isinstance(dataset.transformer_pipeline, TransformerPipeline)
6774

68-
def get_data_path(self):
69-
current_dir = os.path.dirname(os.path.abspath(__file__))
70-
return os.path.join(current_dir, "data", "orcl_attrition.csv")
75+
def test_seggested_sampling_for_imbalanced_dataset(self):
76+
employees = ADSDatasetWithTarget.from_dataframe(
77+
df=pd.read_csv(self.get_data_path()),
78+
target="Attrition",
79+
).set_positive_class("Yes")
80+
81+
rt = employees._get_recommendations_transformer(
82+
fix_imbalance=True, correlation_threshold=1
83+
)
84+
rt.fit(employees)
85+
86+
## Assert with default setup for thresholds MAX_LEN_FOR_UP_SAMPLING and MIN_RATIO_FOR_DOWN_SAMPLING
87+
assert utils.MAX_LEN_FOR_UP_SAMPLING == 5000
88+
assert utils.MIN_RATIO_FOR_DOWN_SAMPLING == 1 / 20
89+
90+
assert (
91+
rt.reco_dict_["fix_imbalance"]["Attrition"]["Message"]
92+
== "Imbalanced Target(33.33%)"
93+
)
94+
# up-sample if length of dataframe is less than or equal to MAX_LEN_FOR_UP_SAMPLING
95+
assert len(employees) < utils.MAX_LEN_FOR_UP_SAMPLING
96+
assert (
97+
rt.reco_dict_["fix_imbalance"]["Attrition"]["Selected Action"]
98+
== "Up-sample"
99+
)
100+
101+
# manipulate MAX_LEN_FOR_UP_SAMPLING, MIN_RATIO_FOR_DOWN_SAMPLING to get other recommendations
102+
with patch("ads.common.utils.MAX_LEN_FOR_UP_SAMPLING", 5):
103+
assert utils.MAX_LEN_FOR_UP_SAMPLING == 5
104+
rt.fit(employees)
105+
# expect down-sample suggested, because minor_majority_ratio is greater than MIN_RATIO_FOR_DOWN_SAMPLING
106+
assert (
107+
rt.reco_dict_["fix_imbalance"]["Attrition"]["Selected Action"]
108+
== "Down-sample"
109+
)
110+
with patch("ads.common.utils.MIN_RATIO_FOR_DOWN_SAMPLING", 0.35):
111+
rt.fit(employees)
112+
# expect "Do nothing": the dataset is longer than the patched MAX_LEN_FOR_UP_SAMPLING and minor_majority_ratio is below the patched MIN_RATIO_FOR_DOWN_SAMPLING
113+
assert (
114+
rt.reco_dict_["fix_imbalance"]["Attrition"]["Selected Action"]
115+
== "Do nothing"
116+
)

0 commit comments

Comments
 (0)