Skip to content

Commit 87c31bd

Browse files
authored
Uptake bokeh and fix ADS BokehHeatMap (#529)
1 parent 030824b commit 87c31bd

File tree

4 files changed

+162
-23
lines changed

4 files changed

+162
-23
lines changed

ads/dataset/correlation_plot.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2020, 2022 Oracle and/or its affiliates.
4+
# Copyright (c) 2020, 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
import os
@@ -26,7 +26,6 @@ class BokehHeatMap(object):
2626

2727
@runtime_dependency(module="bokeh", install_from=OptionalDependency.VIZ)
2828
def __init__(self, ds):
29-
3029
from bokeh.io import output_notebook
3130
from bokeh.palettes import BuPu
3231

@@ -120,8 +119,8 @@ def plot_heat_map(
120119
y_range=yrange,
121120
toolbar_location="below",
122121
toolbar_sticky=False,
123-
plot_width=600,
124-
plot_height=600,
122+
width=600,
123+
height=600,
125124
)
126125

127126
p.rect(
@@ -196,8 +195,8 @@ def plot_hbar(
196195
y_range=(0, len(matrix["Y"]) + 1),
197196
toolbar_location="below",
198197
toolbar_sticky=False,
199-
plot_width=600,
200-
plot_height=600,
198+
width=600,
199+
height=600,
201200
)
202201

203202
p.hbar(
@@ -229,7 +228,6 @@ def plot_hbar(
229228
level="glyph",
230229
y_offset=-5,
231230
source=source,
232-
render_mode="canvas",
233231
)
234232

235233
p.add_layout(color_bar, "right")
@@ -264,7 +262,7 @@ def generate_heatmap(
264262
from bokeh.plotting import figure
265263

266264
if len(corr_matrix) == 0:
267-
tab = bokeh.models.Panel(
265+
tab = bokeh.models.TabPanel(
268266
child=figure(title=msg + ", nothing to display"),
269267
title=title,
270268
)
@@ -286,7 +284,7 @@ def generate_heatmap(
286284
tool_tips=[("X", "@x"), ("Y", "@y"), ("Corr", "@corr")],
287285
)
288286

289-
tab = bokeh.models.Panel(child=p, title=title)
287+
tab = bokeh.models.TabPanel(child=p, title=title)
290288
return tab
291289

292290
@runtime_dependency(module="bokeh", install_from=OptionalDependency.VIZ)
@@ -323,7 +321,7 @@ def generate_target_heatmap(
323321
from bokeh.plotting import figure
324322

325323
if len(corr_matrix) == 0:
326-
tab = bokeh.models.Panel(
324+
tab = bokeh.models.TabPanel(
327325
child=figure(title=msg + ", nothing to display"),
328326
title=title,
329327
)
@@ -333,7 +331,7 @@ def generate_target_heatmap(
333331

334332
assert correlation_target, "Correlation target is required for this plot"
335333
if correlation_target not in corr_matrix.columns:
336-
tab = bokeh.models.Panel(
334+
tab = bokeh.models.TabPanel(
337335
child=figure(title="No Data to display"), title=title
338336
)
339337
return tab
@@ -382,7 +380,7 @@ def generate_target_heatmap(
382380
column_name=correlation_target,
383381
)
384382

385-
tab = bokeh.models.Panel(child=p, title=title)
383+
tab = bokeh.models.TabPanel(child=p, title=title)
386384
return tab
387385

388386
@runtime_dependency(module="bokeh", install_from=OptionalDependency.VIZ)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ torch = [
147147
"torchvision",
148148
]
149149
viz = [
150-
"bokeh>=2.3.0, <=2.4.3",
150+
"bokeh>=3.0.0, <3.2.0", # starting 3.2.0 bokeh not supporting python3.8; relax after ADS will drop py3.8 support
151151
"folium>=0.12.1",
152152
"graphviz<0.17",
153153
"scipy>=1.5.4",

tests/unitary/default_setup/telemetry/test_agent.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2022, 2024 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import importlib
@@ -50,7 +50,6 @@ def test_user_agent_rp(self, mock_signer, monkeypatch):
5050
monkeypatch.delenv("OCI_RESOURCE_PRINCIPAL_VERSION", raising=False)
5151
monkeypatch.delenv(EXTRA_USER_AGENT_INFO, raising=False)
5252
importlib.reload(ads.config)
53-
importlib.reload(ads.telemetry)
5453
auth_info = ads.auth.resource_principal()
5554
assert (
5655
auth_info["config"].get("additional_user_agent")
@@ -125,12 +124,7 @@ def test_user_agent_default_signer_known_resources(
125124
monkeypatch.delenv(EXTRA_USER_AGENT_INFO, raising=False)
126125
if INPUT_DATA[EXTRA_USER_AGENT_INFO] is not None:
127126
monkeypatch.setenv(EXTRA_USER_AGENT_INFO, INPUT_DATA[EXTRA_USER_AGENT_INFO])
128-
129127
importlib.reload(ads.config)
130-
importlib.reload(ads)
131-
importlib.reload(ads.auth)
132-
importlib.reload(ads.telemetry)
133-
134128
with patch("oci.config.from_file", return_value=self.test_config):
135129
auth_info = ads.auth.default_signer()
136130
assert (
@@ -151,10 +145,7 @@ def test_user_agent_default_signer_ociservice(
151145
):
152146
monkeypatch.setenv("OCI_RESOURCE_PRINCIPAL_VERSION", "1.1")
153147
monkeypatch.delenv(EXTRA_USER_AGENT_INFO, raising=False)
154-
155148
importlib.reload(ads.config)
156-
importlib.reload(ads.telemetry)
157-
158149
with patch("oci.config.from_file", return_value=self.test_config):
159150
auth_info = ads.auth.default_signer()
160151
assert (
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
import os
7+
8+
import pandas as pd
9+
import pytest
10+
import unittest
11+
12+
from ads.dataset.correlation_plot import BokehHeatMap
13+
from ads.dataset.dataset import ADSDataset
14+
from ads.dataset.exception import ValidationError
15+
16+
17+
class TestBokehHeatMap(unittest.TestCase):
18+
dataset = pd.DataFrame(
19+
columns=[
20+
"one",
21+
"two",
22+
"three",
23+
"four",
24+
"five",
25+
],
26+
data=[
27+
[
28+
3.4,
29+
999,
30+
1,
31+
None,
32+
95354,
33+
],
34+
[
35+
3,
36+
999,
37+
5,
38+
None,
39+
90421,
40+
],
41+
[
42+
2,
43+
999,
44+
1,
45+
15,
46+
89352,
47+
],
48+
[
49+
3,
50+
999,
51+
6,
52+
11,
53+
89427,
54+
],
55+
[
56+
2,
57+
999,
58+
1,
59+
46,
60+
94342,
61+
],
62+
],
63+
)
64+
bokeh_heatmap = BokehHeatMap(dataset)
65+
corr_matrix = dataset.corr()
66+
67+
def test_plot_heat_map(self):
68+
p = self.bokeh_heatmap.plot_heat_map(
69+
self.corr_matrix,
70+
xrange=self.corr_matrix.index.values.tolist(),
71+
yrange=self.corr_matrix.columns.values.tolist(),
72+
title="heat_map",
73+
)
74+
75+
self.assertEqual(str(type(p)), "<class 'bokeh.plotting._figure.figure'>")
76+
assert p.width == 600
77+
assert p.height == 600
78+
assert p.xaxis.major_label_orientation == "vertical"
79+
assert len(p.yaxis) == 1
80+
assert p.title.text == "heat_map"
81+
82+
def test_plot_hbar(self):
83+
rows = self.corr_matrix.index.values.tolist()
84+
columns = self.corr_matrix.columns
85+
corr_flatten = pd.DataFrame(
86+
[(r, c, self.corr_matrix[r][c]) for c in columns for r in rows],
87+
columns=["X", "Y", "corr"],
88+
)
89+
p = self.bokeh_heatmap.plot_hbar(
90+
corr_flatten, title="plot_hbar", column_name="name in title"
91+
)
92+
93+
self.assertEqual(str(type(p)), "<class 'bokeh.plotting._figure.figure'>")
94+
assert p.width == 600
95+
assert p.height == 600
96+
assert p.toolbar_location == "below"
97+
assert p.title.text == "plot_hbar (name in title)"
98+
99+
def test_generate_heatmap(self):
100+
tabs = self.bokeh_heatmap.generate_heatmap(
101+
self.corr_matrix, title="heatmap", msg="", correlation_threshold=-1
102+
)
103+
104+
self.assertEqual(str(type(tabs)), "<class 'bokeh.models.layouts.TabPanel'>")
105+
self.assertEqual(
106+
str(type(tabs.child)), "<class 'bokeh.plotting._figure.figure'>"
107+
)
108+
assert tabs.child.width == 600
109+
assert tabs.child.height == 600
110+
assert len(tabs.child.yaxis) == 1
111+
assert tabs.title == "heatmap"
112+
113+
def test_generate_target_heatmap(self):
114+
tabs = self.bokeh_heatmap.generate_target_heatmap(
115+
self.corr_matrix,
116+
title="target_heatmap",
117+
correlation_target="one",
118+
msg="",
119+
correlation_threshold=-1,
120+
)
121+
122+
self.assertEqual(str(type(tabs)), "<class 'bokeh.models.layouts.TabPanel'>")
123+
self.assertEqual(
124+
str(type(tabs.child)), "<class 'bokeh.plotting._figure.figure'>"
125+
)
126+
assert tabs.child.width == 600
127+
assert tabs.child.height == 600
128+
assert tabs.child.toolbar_location == "below"
129+
assert tabs.title == "target_heatmap"
130+
131+
def test_plot_correlation_heatmap(self):
132+
current_dir = os.path.dirname(os.path.abspath(__file__))
133+
data_file = os.path.join(current_dir, "data", "orcl_attrition.csv")
134+
df = pd.read_csv(data_file)
135+
ds = ADSDataset.from_dataframe(df)
136+
bokeh_heatmap = BokehHeatMap(ds)
137+
138+
with pytest.raises(ValidationError):
139+
bokeh_heatmap.plot_correlation_heatmap(
140+
ds=ds, correlation_methods="wrong_correlation_methods"
141+
)
142+
143+
bokeh_heatmap.plot_correlation_heatmap(ds=ds, correlation_methods="pearson")
144+
bokeh_heatmap.plot_correlation_heatmap(ds=ds, correlation_methods="cramers v")
145+
bokeh_heatmap.plot_correlation_heatmap(
146+
ds=ds, correlation_methods="correlation ratio"
147+
)
148+
149+
with pytest.raises(ValueError):
150+
bokeh_heatmap.plot_correlation_heatmap(ds=ds, plot_type="wrong_plot_type")

0 commit comments

Comments
 (0)