3
3
# Copyright (c) 2023 Oracle and/or its affiliates.
4
4
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5
5
6
- import os
6
+ from mock import patch
7
7
from typing import Tuple
8
+ import os
8
9
import pandas as pd
9
10
import pytest
11
+
12
+ from ads .common import utils
10
13
from ads .dataset .classification_dataset import BinaryClassificationDataset
11
14
from ads .dataset .dataset_with_target import ADSDatasetWithTarget
12
15
from ads .dataset .pipeline import TransformerPipeline
13
16
from ads .dataset .target import TargetVariable
14
17
15
18
16
19
class TestADSDatasetTarget :
20
+ def get_data_path (self ):
21
+ current_dir = os .path .dirname (os .path .abspath (__file__ ))
22
+ return os .path .join (current_dir , "data" , "orcl_attrition.csv" )
23
+
17
24
def test_initialize_dataset_target (self ):
18
25
employees = ADSDatasetWithTarget (
19
26
df = pd .read_csv (self .get_data_path ()),
20
27
target = "Attrition" ,
21
28
name = "test_dataset" ,
22
29
description = "test_description" ,
23
- storage_options = {' config' : {},' region' : ' us-ashburn-1' }
30
+ storage_options = {" config" : {}, " region" : " us-ashburn-1" },
24
31
)
25
32
26
33
assert isinstance (employees , ADSDatasetWithTarget )
@@ -32,8 +39,8 @@ def test_dataset_target_from_dataframe(self):
32
39
employees = ADSDatasetWithTarget .from_dataframe (
33
40
df = pd .read_csv (self .get_data_path ()),
34
41
target = "Attrition" ,
35
- storage_options = {' config' : {},' region' : ' us-ashburn-1' }
36
- ).set_positive_class (' Yes' )
42
+ storage_options = {" config" : {}, " region" : " us-ashburn-1" },
43
+ ).set_positive_class (" Yes" )
37
44
38
45
assert isinstance (employees , BinaryClassificationDataset )
39
46
self .assert_dataset (employees )
@@ -65,6 +72,45 @@ def assert_dataset(self, dataset):
65
72
assert "type_discovery" in dataset .init_kwargs
66
73
assert isinstance (dataset .transformer_pipeline , TransformerPipeline )
67
74
68
- def get_data_path (self ):
69
- current_dir = os .path .dirname (os .path .abspath (__file__ ))
70
- return os .path .join (current_dir , "data" , "orcl_attrition.csv" )
75
+ def test_seggested_sampling_for_imbalanced_dataset (self ):
76
+ employees = ADSDatasetWithTarget .from_dataframe (
77
+ df = pd .read_csv (self .get_data_path ()),
78
+ target = "Attrition" ,
79
+ ).set_positive_class ("Yes" )
80
+
81
+ rt = employees ._get_recommendations_transformer (
82
+ fix_imbalance = True , correlation_threshold = 1
83
+ )
84
+ rt .fit (employees )
85
+
86
+ ## Assert with default setup for thresholds MAX_LEN_FOR_UP_SAMPLING and MIN_RATIO_FOR_DOWN_SAMPLING
87
+ assert utils .MAX_LEN_FOR_UP_SAMPLING == 5000
88
+ assert utils .MIN_RATIO_FOR_DOWN_SAMPLING == 1 / 20
89
+
90
+ assert (
91
+ rt .reco_dict_ ["fix_imbalance" ]["Attrition" ]["Message" ]
92
+ == "Imbalanced Target(33.33%)"
93
+ )
94
+ # up-sample if length of dataframe is less than or equal to MAX_LEN_FOR_UP_SAMPLING
95
+ assert len (employees ) < utils .MAX_LEN_FOR_UP_SAMPLING
96
+ assert (
97
+ rt .reco_dict_ ["fix_imbalance" ]["Attrition" ]["Selected Action" ]
98
+ == "Up-sample"
99
+ )
100
+
101
+ # manipulate MAX_LEN_FOR_UP_SAMPLING, MIN_RATIO_FOR_DOWN_SAMPLING to get other recommendations
102
+ with patch ("ads.common.utils.MAX_LEN_FOR_UP_SAMPLING" , 5 ):
103
+ assert utils .MAX_LEN_FOR_UP_SAMPLING == 5
104
+ rt .fit (employees )
105
+ # expect down-sample suggested, because minor_majority_ratio is greater than MIN_RATIO_FOR_DOWN_SAMPLING
106
+ assert (
107
+ rt .reco_dict_ ["fix_imbalance" ]["Attrition" ]["Selected Action" ]
108
+ == "Down-sample"
109
+ )
110
+ with patch ("ads.common.utils.MIN_RATIO_FOR_DOWN_SAMPLING" , 0.35 ):
111
+ rt .fit (employees )
112
+ # expect "Do nothing" with both MAX_LEN_FOR_UP_SAMPLING, MIN_RATIO_FOR_DOWN_SAMPLING tweaked for sampled_df
113
+ assert (
114
+ rt .reco_dict_ ["fix_imbalance" ]["Attrition" ]["Selected Action" ]
115
+ == "Do nothing"
116
+ )
0 commit comments