Skip to content

Commit a4f76fe

Browse files
committed
ODSC-47682: Adding unit/int tests for statistics
1 parent b72b5fe commit a4f76fe

15 files changed

+309
-71
lines changed

ads/feature_store/statistics/charts/abstract_feature_stat.py renamed to ads/feature_store/statistics/abs_feature_value.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,40 +16,23 @@
1616
)
1717

1818

19-
class AbsFeatureStat:
20-
class ValidationFailedException(Exception):
21-
def __init__(self):
22-
pass
23-
19+
class AbsFeatureValue:
2420
def __init__(self):
2521
self.__validate__()
2622

2723
@abstractmethod
2824
def __validate__(self):
2925
pass
3026

31-
@abstractmethod
32-
def add_to_figure(self, fig: Figure, xaxis: int, yaxis: int):
33-
pass
34-
3527
@classmethod
3628
@abstractmethod
3729
def __from_json__(cls, json_dict: dict):
3830
pass
3931

40-
@staticmethod
41-
def get_x_y_str_axes(xaxis: int, yaxis: int) -> ():
42-
return (
43-
("xaxis" + str(xaxis + 1)),
44-
("yaxis" + str(yaxis + 1)),
45-
("x" + str(xaxis + 1)),
46-
("y" + str(yaxis + 1)),
47-
)
48-
4932
@classmethod
5033
def from_json(
5134
cls, json_dict: dict, ignore_errors: bool = False
52-
) -> Union["AbsFeatureStat", None]:
35+
) -> Union["AbsFeatureValue", None]:
5336
try:
5437
return cls.__from_json__(json_dict=json_dict)
5538
except Exception as e:
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8; -*-
3+
# Copyright (c) 2023 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
import abc
6+
7+
from abc import abstractmethod
8+
from ads.feature_store.statistics.abs_feature_value import AbsFeatureValue
9+
from ads.common.decorator.runtime_dependency import OptionalDependency
10+
11+
try:
12+
from plotly.graph_objs import Figure
13+
except ModuleNotFoundError:
14+
raise ModuleNotFoundError(
15+
f"The `plotly` module was not found. Please run `pip install "
16+
f"{OptionalDependency.FEATURE_STORE}`."
17+
)
18+
19+
20+
class AbsFeaturePlot(abc.ABC, AbsFeatureValue):
    """Abstract base for feature values that can be drawn on a plotly figure.

    Concrete subclasses must implement ``add_to_figure`` and
    ``__from_json__``.
    """

    @abstractmethod
    def add_to_figure(self, fig: Figure, xaxis: int, yaxis: int):
        """Draw this plot on ``fig``.

        Args:
            fig: The plotly figure to add the plot to.
            xaxis: Index of the x axis to draw on.
            yaxis: Index of the y axis to draw on.
        """
        pass

    @classmethod
    @abstractmethod
    def __from_json__(cls, json_dict: dict):
        """Build an instance from ``json_dict``; implemented by subclasses."""
        pass

    @staticmethod
    def get_x_y_str_axes(xaxis: int, yaxis: int) -> tuple:
        """Return the axis identifier strings for the given axis indices.

        One is added to each index before it is appended to the prefix, so
        index ``0`` yields ``"xaxis1"`` / ``"x1"``.

        Args:
            xaxis: Index of the x axis.
            yaxis: Index of the y axis.

        Returns:
            A 4-tuple ``("xaxis<N>", "yaxis<M>", "x<N>", "y<M>")`` where
            ``N = xaxis + 1`` and ``M = yaxis + 1``.
        """
        x_suffix = str(xaxis + 1)
        y_suffix = str(yaxis + 1)
        return (
            "xaxis" + x_suffix,
            "yaxis" + y_suffix,
            "x" + x_suffix,
            "y" + y_suffix,
        )

ads/feature_store/statistics/charts/box_plot.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
# -*- coding: utf-8; -*-
33
# Copyright (c) 2023 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5-
from typing import List
5+
from typing import List, Union
6+
7+
from ads.feature_store.statistics.abs_feature_value import AbsFeatureValue
68

79
from ads.common.decorator.runtime_dependency import OptionalDependency
8-
from ads.feature_store.statistics.charts.abstract_feature_stat import AbsFeatureStat
10+
from ads.feature_store.statistics.charts.abstract_feature_plot import AbsFeaturePlot
911
from ads.feature_store.statistics.charts.frequency_distribution import (
1012
FrequencyDistribution,
1113
)
@@ -20,7 +22,7 @@
2022
)
2123

2224

23-
class BoxPlot(AbsFeatureStat):
25+
class BoxPlot(AbsFeaturePlot):
2426
CONST_MIN = "Min"
2527
CONST_MAX = "Max"
2628
CONST_QUARTILES = "Quartiles"
@@ -30,7 +32,7 @@ class BoxPlot(AbsFeatureStat):
3032
CONST_IQR = "IQR"
3133
CONST_BOX_POINTS = "box_points"
3234

33-
class Quartiles:
35+
class Quartiles(AbsFeatureValue):
3436
CONST_Q1 = "q1"
3537
CONST_Q2 = "q2"
3638
CONST_Q3 = "q3"
@@ -39,15 +41,20 @@ def __init__(self, q1: float, q2: float, q3: float):
3941
self.q1 = q1
4042
self.q2 = q2
4143
self.q3 = q3
44+
super().__init__()
4245

4346
@classmethod
44-
def from_json(cls, json_dict: dict) -> "BoxPlot.Quartiles":
47+
def __from_json__(cls, json_dict: dict) -> "BoxPlot.Quartiles":
4548
return cls(
4649
json_dict.get(cls.CONST_Q1),
4750
json_dict.get(cls.CONST_Q2),
4851
json_dict.get(cls.CONST_Q3),
4952
)
5053

54+
def __validate__(self):
    """Check that the quartiles are numeric and in non-decreasing order.

    Raises:
        AssertionError: If ``q1``, ``q2`` or ``q3`` is not an ``int`` or
            ``float``, or if ``q1 <= q2 <= q3`` does not hold.
    """
    for name in ("q1", "q2", "q3"):
        value = getattr(self, name)
        # Each value is checked on its own: an expression of the form
        # ``type(a) == ... == int or float`` is always truthy because the
        # trailing ``or float`` yields the (truthy) ``float`` class.
        # ``bool`` is a subclass of ``int`` and is rejected explicitly.
        assert isinstance(value, (int, float)) and not isinstance(value, bool), (
            f"{name} must be an int or float, got {type(value).__name__}"
        )
    assert self.q3 >= self.q2 >= self.q1, "quartiles must satisfy q1 <= q2 <= q3"
57+
5158
def __init__(
5259
self,
5360
mean: float,
@@ -67,14 +74,14 @@ def __init__(
6774
super().__init__()
6875

6976
def __validate__(self):
    """Check that every value needed for the box plot is present.

    Raises:
        AssertionError: If ``q1``, ``q3``, ``iqr``, ``median`` or ``mean`` is
            ``None``, or if ``box_points`` is not a non-empty list.
    """
    # Each required attribute is checked exactly once, in a fixed order, so
    # the failure message names the first missing value.
    for name in ("q1", "q3", "iqr", "median", "mean"):
        assert getattr(self, name) is not None, f"{name} must not be None"
    assert isinstance(self.box_points, list), "box_points must be a list"
    assert len(self.box_points) > 0, "box_points must not be empty"
7885

7986
def add_to_figure(self, fig: Figure, xaxis: int, yaxis: int):
8087
xaxis_str, yaxis_str, x_str, y_str = self.get_x_y_str_axes(xaxis, yaxis)

ads/feature_store/statistics/charts/frequency_distribution.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from typing import List
77
from ads.common.decorator.runtime_dependency import OptionalDependency
8-
from ads.feature_store.statistics.charts.abstract_feature_stat import AbsFeatureStat
8+
from ads.feature_store.statistics.charts.abstract_feature_plot import AbsFeaturePlot
99

1010
try:
1111
from plotly.graph_objs import Figure
@@ -16,24 +16,21 @@
1616
)
1717

1818

19-
class FrequencyDistribution(AbsFeatureStat):
19+
class FrequencyDistribution(AbsFeaturePlot):
2020
CONST_FREQUENCY = "frequency"
2121
CONST_BINS = "bins"
2222
CONST_FREQUENCY_DISTRIBUTION_TITLE = "Frequency Distribution"
2323

24-
def __validate__(self):
25-
if not (
26-
type(self.frequency) == list
27-
and type(self.bins) == list
28-
and 0 < len(self.frequency) == len(self.bins) > 0
29-
):
30-
raise self.ValidationFailedException()
31-
3224
def __init__(self, frequency: List, bins: List):
3325
self.frequency = frequency
3426
self.bins = bins
3527
super().__init__()
3628

29+
def __validate__(self):
    """Check that ``frequency`` and ``bins`` are non-empty lists of equal length.

    Raises:
        AssertionError: If either attribute is not a list, the lists are
            empty, or their lengths differ.
    """
    assert isinstance(self.frequency, list), "frequency must be a list"
    assert isinstance(self.bins, list), "bins must be a list"
    # Equal lengths plus one non-empty check covers both lists.
    assert len(self.frequency) == len(self.bins) > 0, (
        "frequency and bins must be non-empty and of equal length"
    )
33+
3734
@classmethod
3835
def __from_json__(cls, json_dict: dict) -> "FrequencyDistribution":
3936
return FrequencyDistribution(

ads/feature_store/statistics/charts/probability_distribution.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import List
66

77
from ads.common.decorator.runtime_dependency import OptionalDependency
8-
from ads.feature_store.statistics.charts.abstract_feature_stat import AbsFeatureStat
8+
from ads.feature_store.statistics.charts.abstract_feature_plot import AbsFeaturePlot
99

1010
try:
1111
from plotly.graph_objs import Figure
@@ -16,15 +16,7 @@
1616
)
1717

1818

19-
class ProbabilityDistribution(AbsFeatureStat):
20-
def __validate__(self):
21-
if not (
22-
type(self.density) == list
23-
and type(self.bins) == list
24-
and 0 < len(self.density) == len(self.bins) > 0
25-
):
26-
raise self.ValidationFailedException()
27-
19+
class ProbabilityDistribution(AbsFeaturePlot):
2820
CONST_DENSITY = "density"
2921
CONST_BINS = "bins"
3022
CONST_PROBABILITY_DISTRIBUTION_TITLE = "Probability Distribution"
@@ -34,6 +26,11 @@ def __init__(self, density: List, bins: List):
3426
self.bins = bins
3527
super().__init__()
3628

29+
def __validate__(self):
    """Check that ``density`` and ``bins`` are non-empty lists of equal length.

    Raises:
        AssertionError: If either attribute is not a list, the lists are
            empty, or their lengths differ.
    """
    assert isinstance(self.density, list), "density must be a list"
    assert isinstance(self.bins, list), "bins must be a list"
    # Equal lengths plus one non-empty check covers both lists.
    assert len(self.density) == len(self.bins) > 0, (
        "density and bins must be non-empty and of equal length"
    )
33+
3734
@classmethod
3835
def __from_json__(cls, json_dict: dict) -> "ProbabilityDistribution":
3936
return cls(

ads/feature_store/statistics/charts/top_k_frequent_elements.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
# Copyright (c) 2023 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
from typing import List
6+
7+
from ads.feature_store.statistics.abs_feature_value import AbsFeatureValue
8+
69
from ads.common.decorator.runtime_dependency import OptionalDependency
710

8-
from ads.feature_store.statistics.charts.abstract_feature_stat import AbsFeatureStat
11+
from ads.feature_store.statistics.charts.abstract_feature_plot import AbsFeaturePlot
912

1013
try:
1114
from plotly.graph_objs import Figure
@@ -16,15 +19,11 @@
1619
)
1720

1821

19-
class TopKFrequentElements(AbsFeatureStat):
20-
def __validate__(self):
21-
if not (type(self.elements) == list and len(self.elements) > 0):
22-
raise self.ValidationFailedException
23-
22+
class TopKFrequentElements(AbsFeaturePlot):
2423
CONST_VALUE = "value"
2524
CONST_TOP_K_FREQUENT_TITLE = "Top K Frequent Elements"
2625

27-
class TopKFrequentElement:
26+
class TopKFrequentElement(AbsFeatureValue):
2827
CONST_VALUE = "value"
2928
CONST_ESTIMATE = "estimate"
3029
CONST_LOWER_BOUND = "lower_bound"
@@ -37,9 +36,14 @@ def __init__(
3736
self.estimate = estimate
3837
self.lower_bound = lower_bound
3938
self.upper_bound = upper_bound
39+
super().__init__()
40+
41+
def __validate__(self):
    """Check that ``value`` is a non-empty string and ``estimate`` a non-negative int.

    Raises:
        AssertionError: If ``value`` is not a non-empty ``str`` or
            ``estimate`` is not an ``int`` greater than or equal to zero.
    """
    assert isinstance(self.value, str) and len(self.value) > 0, (
        "value must be a non-empty string"
    )
    # ``bool`` is a subclass of ``int``; reject it so True/False are not
    # silently treated as counts.
    assert (
        isinstance(self.estimate, int)
        and not isinstance(self.estimate, bool)
        and self.estimate >= 0
    ), "estimate must be a non-negative int"
4044

4145
@classmethod
42-
def from_json(
46+
def __from_json__(
4347
cls, json_dict: dict
4448
) -> "TopKFrequentElements.TopKFrequentElement":
4549
return cls(
@@ -53,6 +57,12 @@ def __init__(self, elements: List[TopKFrequentElement]):
5357
self.elements = elements
5458
super().__init__()
5559

60+
def __validate__(self):
    """Check that ``elements`` is a non-empty list with no ``None`` entries.

    Raises:
        AssertionError: If ``elements`` is not a list, is empty, or contains
            ``None``.
    """
    assert isinstance(self.elements, list), "elements must be a list"
    assert len(self.elements) > 0, "elements must not be empty"
    assert all(element is not None for element in self.elements), (
        "elements must not contain None"
    )
65+
5666
@classmethod
5767
def __from_json__(cls, json_dict: dict) -> "TopKFrequentElements":
5868
elements = json_dict.get(cls.CONST_VALUE)

ads/feature_store/statistics/feature_stat.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from ads.common.decorator.runtime_dependency import OptionalDependency
77
from typing import List
8-
from ads.feature_store.statistics.charts.abstract_feature_stat import AbsFeatureStat
8+
from ads.feature_store.statistics.charts.abstract_feature_plot import AbsFeaturePlot
99
from ads.feature_store.statistics.charts.box_plot import BoxPlot
1010
from ads.feature_store.statistics.charts.frequency_distribution import (
1111
FrequencyDistribution,
@@ -79,7 +79,7 @@ def from_json(cls, feature_name: str, json_dict: dict) -> "FeatureStatistics":
7979
return cls(feature_name)
8080

8181
@property
82-
def __feature_stat_objects__(self) -> List[AbsFeatureStat]:
82+
def __feature_stat_objects__(self) -> List[AbsFeaturePlot]:
8383
return [
8484
stat
8585
for stat in [
@@ -101,14 +101,15 @@ def next_graph_position_generator():
101101
if len(self.__feature_stat_objects__) > 0:
102102
fig = make_subplots(cols=3, column_titles=[" "] * 3)
103103
for idx, stat in zip(
104-
next_graph_position_generator(),
105-
[stat for stat in self.__feature_stat_objects__ if stat is not None],
104+
next_graph_position_generator(), self.__feature_stat_objects__
106105
):
107106
stat.add_to_figure(fig, idx, idx)
108107

109108
fig.layout.title = self.CONST_TITLE_FORMAT.format(self.feature_name)
110109
fig.update_layout(title_font_size=20)
110+
# Center align the title
111111
fig.update_layout(title_x=0.5)
112+
# Disable legend for unrelated plots
112113
fig.update_layout(showlegend=False)
113114
plotly.offline.iplot(
114115
fig,
Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
3+
from ads.feature_store.statistics.abs_feature_value import AbsFeatureValue
34

45

56
# Copyright (c) 2023 Oracle and/or its affiliates.
67
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
7-
class GenericFeatureValue:
8+
class GenericFeatureValue(AbsFeatureValue):
    """Feature value that wraps a single value of any type."""

    # Key under which the value is stored in the JSON dictionary.
    CONST_VALUE = "value"

    def __init__(self, val: object):
        """Store ``val`` and run the base-class initialisation.

        Args:
            val: The value to wrap; any object, including ``None``.
        """
        self.val = val
        super().__init__()

    def __validate__(self):
        """Accept any value; nothing is checked."""
        pass

    @classmethod
    def __from_json__(cls, json_dict: dict) -> "GenericFeatureValue":
        """Build an instance from ``json_dict``.

        Args:
            json_dict: Dictionary expected to hold the value under
                ``CONST_VALUE``. Any non-dict input yields ``val=None``.

        Returns:
            An instance of ``cls`` wrapping the extracted value, or ``None``
            when the key is absent or ``json_dict`` is not a dict.
        """
        val = json_dict.get(cls.CONST_VALUE) if isinstance(json_dict, dict) else None
        # Construct via ``cls`` so subclasses get instances of their own type.
        return cls(val=val)

tests/integration/feature_store/test_dataset_statistics.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ def test_dataset_statistics_with_default_config(self):
6767
stat_obj = dataset.get_statistics()
6868
assert stat_obj is not None
6969
assert len(stat_obj.to_pandas().columns) == 6
70-
70+
# Validate visualisation is possible
71+
dataset.get_statistics().to_viz()
7172
self.clean_up_dataset(dataset)
7273
self.clean_up_feature_group(fg)
7374
self.clean_up_entity(entity)

tests/integration/feature_store/test_feature_group_statistics.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def test_feature_group_statistics_with_default_config(self):
6161
stat_obj = fg.get_statistics()
6262
assert stat_obj.content is not None
6363
assert len(stat_obj.to_pandas().columns) == 6
64-
64+
# Validate visualisation is possible
65+
fg.get_statistics().to_viz()
6566
self.clean_up_feature_group(fg)
6667
self.clean_up_entity(entity)
6768
self.clean_up_feature_store(fs)

0 commit comments

Comments
 (0)