Skip to content

Commit a14b4e3

Browse files
committed
Fixing integration tests
1 parent c552745 commit a14b4e3

File tree

5 files changed

+123
-20
lines changed

5 files changed

+123
-20
lines changed

ads/feature_store/dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -865,10 +865,11 @@ def _update_from_oci_dataset_model(self, oci_dataset: OCIDataset) -> "Dataset":
865865
features_list.append(output_feature)
866866

867867
value = {self.CONST_ITEMS: features_list}
868-
else:
868+
elif infra_attr == self.CONST_FEATURE_GROUP:
869869
value = getattr(self.oci_dataset, dsc_attr)
870+
else:
871+
value = dataset_details[infra_attr]
870872
self.set_spec(infra_attr, value)
871-
872873
return self
873874

874875
def materialise(

ads/feature_store/execution_strategy/spark/spark_execution.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
3-
import json
4-
53
# Copyright (c) 2023 Oracle and/or its affiliates.
64
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
75

@@ -29,8 +27,6 @@
2927
raise
3028

3129
from ads.feature_store.common.enums import (
32-
FeatureStoreJobType,
33-
LifecycleState,
3430
EntityType,
3531
ExpectationType,
3632
)
@@ -47,6 +43,11 @@
4743

4844
from ads.feature_store.feature_statistics.statistics_service import StatisticsService
4945
from ads.feature_store.common.utils.utility import validate_input_feature_details
46+
from typing import TYPE_CHECKING
47+
48+
if TYPE_CHECKING:
49+
from ads.feature_store.feature_group import FeatureGroup
50+
from ads.feature_store.dataset import Dataset
5051

5152
logger = logging.getLogger(__name__)
5253

@@ -76,7 +77,10 @@ def __init__(self, metastore_id: str = None):
7677
self._jvm = self._spark_context._jvm
7778

7879
def ingest_feature_definition(
79-
self, feature_group, feature_group_job: FeatureGroupJob, dataframe
80+
self,
81+
feature_group: "FeatureGroup",
82+
feature_group_job: FeatureGroupJob,
83+
dataframe,
8084
):
8185
try:
8286
self._save_offline_dataframe(dataframe, feature_group, feature_group_job)
@@ -90,7 +94,7 @@ def ingest_dataset(self, dataset, dataset_job: DatasetJob):
9094
raise SparkExecutionException(e).with_traceback(e.__traceback__)
9195

9296
def delete_feature_definition(
93-
self, feature_group, feature_group_job: FeatureGroupJob
97+
self, feature_group: "FeatureGroup", feature_group_job: FeatureGroupJob
9498
):
9599
"""
96100
Deletes a feature definition from the system.
@@ -122,7 +126,7 @@ def delete_feature_definition(
122126
output_details=output_details,
123127
)
124128

125-
def delete_dataset(self, dataset, dataset_job: DatasetJob):
129+
def delete_dataset(self, dataset: "Dataset", dataset_job: DatasetJob):
126130
"""
127131
Deletes a dataset from the system.
128132
@@ -154,7 +158,7 @@ def delete_dataset(self, dataset, dataset_job: DatasetJob):
154158
)
155159

156160
@staticmethod
157-
def _validate_expectation(expectation_type, validation_output):
161+
def _validate_expectation(expectation_type, validation_output: dict):
158162
"""
159163
Validates the expectation based on the given expectation type and the validation output.
160164

ads/feature_store/statistics.py

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,94 @@
1-
import pandas as pd
2-
from typing import Dict
3-
from copy import deepcopy
1+
import matplotlib.figure
2+
from matplotlib.gridspec import GridSpec
3+
from matplotlib.figure import Figure
44

55
from ads.feature_store.response.response_builder import ResponseBuilder
6-
from ads.jobs.builders.base import Builder
7-
from ads.common import utils
6+
7+
import matplotlib.pyplot as plt
8+
import matplotlib.font_manager as fm
9+
import json
10+
11+
12+
def add_plots_for_stat(fig: Figure, feature: str, stat: dict):
13+
freq_dist = stat.get(Statistics.CONST_FREQUENCY_DISTRIBUTION)
14+
top_k = stat.get(Statistics.CONST_TOP_K_FREQUENT)
15+
probability = stat.get(Statistics.CONST_PROBABILITY_DISTRIBUTION)
16+
17+
def subplot_generator():
18+
plot_count = 0
19+
if stat.get(Statistics.CONST_FREQUENCY_DISTRIBUTION) is not None:
20+
plot_count += 1
21+
# if stat.get(Statistics.CONST_TOP_K_FREQUENT) is not None:
22+
# plot_count += 1
23+
if stat.get(Statistics.CONST_PROBABILITY_DISTRIBUTION) is not None:
24+
plot_count += 1
25+
26+
for i in range(0, plot_count):
27+
yield fig.add_subplot(1, plot_count, i + 1)
28+
29+
subplots = subplot_generator()
30+
if freq_dist is not None:
31+
axs = next(subplots)
32+
fig.suptitle(feature, fontproperties=fm.FontProperties(weight="bold"))
33+
axs.hist(
34+
x=freq_dist.get("bins"),
35+
weights=freq_dist.get("frequency"),
36+
cumulative=False,
37+
color="teal",
38+
mouseover=True,
39+
animated=True,
40+
)
41+
42+
axs.set_xlabel(
43+
"Bins", fontdict={"fontproperties": fm.FontProperties(size="xx-small")}
44+
)
45+
axs.set_ylabel(
46+
"Frequency", fontdict={"fontproperties": fm.FontProperties(size="xx-small")}
47+
)
48+
axs.set_title(
49+
"Frequency Distribution",
50+
fontdict={"fontproperties": fm.FontProperties(size="small")},
51+
)
52+
axs.set_xticks(freq_dist.get("bins"))
53+
if probability is not None:
54+
axs = next(subplots)
55+
fig.suptitle(feature, fontproperties=fm.FontProperties(weight="bold"))
56+
axs.bar(
57+
probability.get("bins"),
58+
probability.get("density"),
59+
color="teal",
60+
mouseover=True,
61+
animated=True,
62+
)
63+
axs.set_xlabel(
64+
"Bins", fontdict={"fontproperties": fm.FontProperties(size="xx-small")}
65+
)
66+
axs.set_ylabel(
67+
"Density", fontdict={"fontproperties": fm.FontProperties(size="xx-small")}
68+
)
69+
axs.set_title(
70+
"Probability Distribution",
71+
fontdict={"fontproperties": fm.FontProperties(size="small")},
72+
)
73+
axs.set_xticks(probability.get("bins"))
74+
75+
76+
def subfigure_generator(count: int, fig: Figure):
77+
rows = count
78+
subfigs = fig.subfigures(rows, 1)
79+
for i in range(0, rows):
80+
yield subfigs[i]
881

982

1083
class Statistics(ResponseBuilder):
1184
"""
1285
Represents statistical information.
1386
"""
1487

88+
CONST_FREQUENCY_DISTRIBUTION = "FrequencyDistribution"
89+
CONST_PROBABILITY_DISTRIBUTION = "ProbabilityDistribution"
90+
CONST_TOP_K_FREQUENT = "TopKFrequentElements"
91+
1592
@property
1693
def kind(self) -> str:
1794
"""
@@ -23,3 +100,26 @@ def kind(self) -> str:
23100
The kind of the statistics object, which is always "statistics".
24101
"""
25102
return "statistics"
103+
104+
def to_viz(self):
105+
if self.content is not None:
106+
stats: dict = json.loads(self.content)
107+
fig: Figure = plt.figure(figsize=(20, 20), dpi=150)
108+
plt.subplots_adjust(hspace=3)
109+
110+
stats = {
111+
feature: stat
112+
for feature, stat in stats.items()
113+
if Statistics.__graph_exists__(stat)
114+
}
115+
subfigures = subfigure_generator(len(stats), fig)
116+
for feature, stat in stats.items():
117+
sub_figure = next(subfigures)
118+
add_plots_for_stat(sub_figure, feature, stat)
119+
120+
@staticmethod
121+
def __graph_exists__(stat: dict):
122+
return (
123+
stat.get(Statistics.CONST_FREQUENCY_DISTRIBUTION) != None
124+
or stat.get(Statistics.CONST_PROBABILITY_DISTRIBUTION) != None
125+
)

tests/integration/feature_store/test_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222

2323

2424
client_kwargs = dict(
25-
retry_strategy=oci.retry.NoneRetryStrategy,
26-
service_endpoint=os.getenv("service_endpoint"),
25+
retry_strategy=oci.retry.NoneRetryStrategy(),
26+
fs_service_endpoint=os.getenv("service_endpoint"),
2727
)
2828
ads.set_auth(client_kwargs=client_kwargs)
2929

tests/integration/feature_store/test_dataset_complex.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ def test_manual_dataset(
7070
).create()
7171
assert len(dataset_resource.feature_groups) == 1
7272
assert dataset_resource.feature_groups[0].id == feature_group.id
73-
assert dataset_resource.get_spec(
74-
Dataset.CONST_FEATURE_GROUP
75-
).is_manual_association
73+
assert dataset_resource.is_manual_association
7674
dataset_resource.delete()
7775
return dataset_resource

0 commit comments

Comments
 (0)