Skip to content

Commit 12a8d06

Browse files
authored
Modified ADS dataset class (#162)
2 parents 0b18dfe + 18f4bb5 commit 12a8d06

File tree

15 files changed

+411
-73
lines changed

15 files changed

+411
-73
lines changed

ads/dataset/classification_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2020, 2022 Oracle and/or its affiliates.
4+
# Copyright (c) 2020, 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import pandas as pd
@@ -22,7 +22,7 @@ class ClassificationDataset(ADSDatasetWithTarget):
2222

2323
def __init__(self, df, sampled_df, target, target_type, shape, **kwargs):
2424
ADSDatasetWithTarget.__init__(
25-
self, df, sampled_df, target, target_type, shape, **kwargs
25+
self, df=df, sampled_df=sampled_df, target=target, target_type=target_type, shape=shape, **kwargs
2626
)
2727

2828
def auto_transform(

ads/dataset/dataset.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*--
33

4-
# Copyright (c) 2020, 2022 Oracle and/or its affiliates.
4+
# Copyright (c) 2020, 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
from __future__ import print_function, absolute_import, division
@@ -16,7 +16,7 @@
1616

1717
from collections import Counter
1818
from sklearn.preprocessing import FunctionTransformer
19-
from typing import Iterable, Union
19+
from typing import Iterable, Tuple, Union
2020

2121
from ads import set_documentation_mode
2222
from ads.common import utils
@@ -71,8 +71,8 @@ class ADSDataset(PandasDataset):
7171
def __init__(
7272
self,
7373
df,
74-
sampled_df,
75-
shape,
74+
sampled_df=None,
75+
shape=None,
7676
name="",
7777
description=None,
7878
type_discovery=True,
@@ -88,6 +88,17 @@ def __init__(
8888
# to keep performance high and linear no matter the size of the distributed dataset we
8989
# create a pandas df that's used internally because this has a fixed upper size.
9090
#
91+
if shape is None:
92+
shape = df.shape
93+
94+
if sampled_df is None:
95+
sampled_df = generate_sample(
96+
df,
97+
shape[0],
98+
DatasetDefaults.sampling_confidence_level,
99+
DatasetDefaults.sampling_confidence_interval,
100+
**kwargs,
101+
)
91102
super().__init__(
92103
sampled_df,
93104
type_discovery=type_discovery,
@@ -134,6 +145,36 @@ def __repr__(self):
134145
def __len__(self):
135146
return self.shape[0]
136147

148+
@staticmethod
149+
def from_dataframe(
150+
df,
151+
sampled_df=None,
152+
shape=None,
153+
name="",
154+
description=None,
155+
type_discovery=True,
156+
types={},
157+
metadata=None,
158+
progress=DummyProgressBar(),
159+
transformer_pipeline=None,
160+
interactive=False,
161+
**kwargs,
162+
) -> "ADSDataset":
163+
return ADSDataset(
164+
df=df,
165+
sampled_df=sampled_df,
166+
shape=shape,
167+
name=name,
168+
description=description,
169+
type_discovery=type_discovery,
170+
types=types,
171+
metadata=metadata,
172+
progress=progress,
173+
transformer_pipeline=transformer_pipeline,
174+
interactive=interactive,
175+
**kwargs
176+
)
177+
137178
@property
138179
@deprecated(
139180
"2.5.2", details="The ddf attribute is deprecated. Use the df attribute."

ads/dataset/dataset_with_target.py

Lines changed: 170 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2020, 2022 Oracle and/or its affiliates.
4+
# Copyright (c) 2020, 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
from __future__ import absolute_import, print_function
@@ -10,7 +10,7 @@
1010
import importlib
1111
from collections import defaultdict
1212
from numbers import Number
13-
from typing import Union
13+
from typing import Tuple, Union
1414

1515
import pandas as pd
1616
from ads.common import utils, logger
@@ -23,14 +23,30 @@
2323
from ads.dataset.dataset import ADSDataset
2424
from ads.dataset.feature_engineering_transformer import FeatureEngineeringTransformer
2525
from ads.dataset.feature_selection import FeatureImportance
26-
from ads.dataset.helper import deprecate_default_value, deprecate_variable
26+
from ads.dataset.helper import (
27+
DatasetDefaults,
28+
deprecate_default_value,
29+
deprecate_variable,
30+
generate_sample,
31+
get_target_type,
32+
is_text_data,
33+
)
2734
from ads.dataset.label_encoder import DataFrameLabelEncoder
2835
from ads.dataset.pipeline import TransformerPipeline
2936
from ads.dataset.progress import DummyProgressBar
3037
from ads.dataset.recommendation import Recommendation
3138
from ads.dataset.recommendation_transformer import RecommendationTransformer
3239
from ads.dataset.target import TargetVariable
33-
from ads.type_discovery.typed_feature import DateTimeTypedFeature
40+
from ads.type_discovery.typed_feature import (
41+
CategoricalTypedFeature,
42+
ContinuousTypedFeature,
43+
DocumentTypedFeature,
44+
GISTypedFeature,
45+
OrdinalTypedFeature,
46+
TypedFeature,
47+
DateTimeTypedFeature,
48+
TypedFeature
49+
)
3450
from sklearn.model_selection import train_test_split
3551
from pandas.io.formats.printing import pprint_thing
3652
from sklearn.preprocessing import FunctionTransformer
@@ -45,10 +61,10 @@ class ADSDatasetWithTarget(ADSDataset, metaclass=ABCMeta):
4561
def __init__(
4662
self,
4763
df,
48-
sampled_df,
4964
target,
50-
target_type,
51-
shape,
65+
sampled_df=None,
66+
shape=None,
67+
target_type=None,
5268
sample_max_rows=-1,
5369
type_discovery=True,
5470
types={},
@@ -61,6 +77,16 @@ def __init__(
6177
**kwargs,
6278
):
6379
self.recommendation_transformer = None
80+
if shape is None:
81+
shape = df.shape
82+
if sampled_df is None:
83+
sampled_df = generate_sample(
84+
df,
85+
shape[0],
86+
DatasetDefaults.sampling_confidence_level,
87+
DatasetDefaults.sampling_confidence_interval,
88+
**kwargs,
89+
)
6490

6591
if parent is None:
6692
cols = sampled_df.columns.tolist()
@@ -135,6 +161,8 @@ def __init__(
135161
cols.insert(0, cols.pop(cols.index(target)))
136162
self.sampled_df = self.sampled_df[[*cols]]
137163

164+
if target_type is None:
165+
target_type = get_target_type(target, sampled_df, **kwargs)
138166
self.target = TargetVariable(self, target, target_type)
139167

140168
# remove target from type discovery conversion
@@ -145,6 +173,141 @@ def __init__(
145173
):
146174
step[1].kw_args["dtypes"].pop(self.target.name)
147175

176+
@staticmethod
177+
def from_dataframe(
178+
df: pd.DataFrame,
179+
target: str,
180+
sampled_df: pd.DataFrame = None,
181+
shape: Tuple[int, int] = None,
182+
target_type: TypedFeature = None,
183+
positive_class=None,
184+
**init_kwargs,
185+
):
186+
from ads.dataset.classification_dataset import (
187+
BinaryClassificationDataset,
188+
BinaryTextClassificationDataset,
189+
MultiClassClassificationDataset,
190+
MultiClassTextClassificationDataset
191+
)
192+
from ads.dataset.forecasting_dataset import ForecastingDataset
193+
from ads.dataset.regression_dataset import RegressionDataset
194+
195+
if sampled_df is None:
196+
sampled_df = generate_sample(
197+
df,
198+
(shape or df.shape)[0],
199+
DatasetDefaults.sampling_confidence_level,
200+
DatasetDefaults.sampling_confidence_interval,
201+
**init_kwargs,
202+
)
203+
204+
if target_type is None:
205+
target_type = get_target_type(target, sampled_df, **init_kwargs)
206+
207+
if len(df[target].dropna()) == 0:
208+
logger.warning(
209+
"It is not recommended to use an empty column as the target variable."
210+
)
211+
raise ValueError(
212+
f"We do not support using empty columns as the chosen target"
213+
)
214+
if utils.is_same_class(target_type, ContinuousTypedFeature):
215+
return RegressionDataset(
216+
df=df,
217+
sampled_df=sampled_df,
218+
target=target,
219+
target_type=target_type,
220+
shape=shape,
221+
**init_kwargs,
222+
)
223+
elif utils.is_same_class(
224+
target_type, DateTimeTypedFeature
225+
) or df.index.dtype.name.startswith("datetime"):
226+
return ForecastingDataset(
227+
df=df,
228+
sampled_df=sampled_df,
229+
target=target,
230+
target_type=target_type,
231+
shape=shape,
232+
**init_kwargs,
233+
)
234+
235+
# Adding ordinal typed feature, but ultimately we should rethink how we want to model this type
236+
elif utils.is_same_class(target_type, CategoricalTypedFeature) or utils.is_same_class(
237+
target_type, OrdinalTypedFeature
238+
):
239+
if target_type.meta_data["internal"]["unique"] == 2:
240+
if is_text_data(sampled_df, target):
241+
return BinaryTextClassificationDataset(
242+
df=df,
243+
sampled_df=sampled_df,
244+
target=target,
245+
shape=shape,
246+
target_type=target_type,
247+
positive_class=positive_class,
248+
**init_kwargs,
249+
)
250+
251+
return BinaryClassificationDataset(
252+
df=df,
253+
sampled_df=sampled_df,
254+
target=target,
255+
shape=shape,
256+
target_type=target_type,
257+
positive_class=positive_class,
258+
**init_kwargs,
259+
)
260+
else:
261+
if is_text_data(sampled_df, target):
262+
return MultiClassTextClassificationDataset(
263+
df=df,
264+
sampled_df=sampled_df,
265+
target=target,
266+
target_type=target_type,
267+
shape=shape,
268+
**init_kwargs,
269+
)
270+
return MultiClassClassificationDataset(
271+
df=df,
272+
sampled_df=sampled_df,
273+
target=target,
274+
target_type=target_type,
275+
shape=shape,
276+
**init_kwargs,
277+
)
278+
elif (
279+
utils.is_same_class(target, DocumentTypedFeature)
280+
or "text" in target_type["type"]
281+
or "text" in target
282+
):
283+
raise ValueError(
284+
f"The column {target} cannot be used as the target column."
285+
)
286+
elif (
287+
utils.is_same_class(target_type, GISTypedFeature)
288+
or "coord" in target_type["type"]
289+
or "coord" in target
290+
):
291+
raise ValueError(
292+
f"The column {target} cannot be used as the target column."
293+
)
294+
# This is to catch constant columns that are boolean. Added as a fix for pd.isnull(), and datasets with a
295+
# binary target, but only data on one instance
296+
elif target_type and target_type["low_level_type"] == "bool":
297+
return BinaryClassificationDataset(
298+
df=df,
299+
sampled_df=sampled_df,
300+
target=target,
301+
shape=shape,
302+
target_type=target_type,
303+
positive_class=positive_class,
304+
**init_kwargs,
305+
)
306+
raise ValueError(
307+
f"Unable to identify problem type. Specify the data type of {target} using 'types'. "
308+
f"For example, types = {{{target}: 'category'}}"
309+
)
310+
148311
def rename_columns(self, columns):
149312
"""
150313
Returns a dataset with columns renamed.

ads/dataset/forecasting_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(self, df, sampled_df, target, target_type, shape, **kwargs):
1414
if isinstance(target, DateTimeTypedFeature):
1515
df = df.set_index(target)
1616
ADSDatasetWithTarget.__init__(
17-
self, df, sampled_df, target, target_type, shape, **kwargs
17+
self, df=df, sampled_df=sampled_df, target=target, target_type=target_type, shape=shape, **kwargs
1818
)
1919

2020
def select_best_features(self, score_func=None, k=12):

ads/dataset/helper.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2020, 2022 Oracle and/or its affiliates.
4+
# Copyright (c) 2020, 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import ast
@@ -832,3 +832,20 @@ def _log_yscale_not_set():
832832
logger.info(
833833
"`yscale` parameter is not set. Valid values are `'linear'`, `'log'`, `'symlog'`."
834834
)
835+
836+
def infer_target_type(target, target_series, discover_target_type=True):
837+
# if type discovery is turned off, infer type from pandas dtype
838+
if discover_target_type:
839+
target_type = TypeDiscoveryDriver().discover(
840+
target, target_series, is_target=True
841+
)
842+
else:
843+
target_type = get_feature_type(target, target_series)
844+
return target_type
845+
846+
def get_target_type(target, sampled_df, **init_kwargs):
847+
discover_target_type = init_kwargs.get("type_discovery", True)
848+
if target in init_kwargs.get("types", {}):
849+
sampled_df[target] = sampled_df[target].astype(init_kwargs.get("types")[target])
850+
discover_target_type = False
851+
return infer_target_type(target, sampled_df[target], discover_target_type)

ads/dataset/regression_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@
1010
class RegressionDataset(ADSDatasetWithTarget):
1111
def __init__(self, df, sampled_df, target, target_type, shape, **kwargs):
1212
ADSDatasetWithTarget.__init__(
13-
self, df, sampled_df, target, target_type, shape, **kwargs
13+
self, df=df, sampled_df=sampled_df, target=target, target_type=target_type, shape=shape, **kwargs
1414
)

0 commit comments

Comments
 (0)