Skip to content

Commit 28f763f

Browse files
authored
Merge pull request #141 from ianhi/scatter-selector
create scatter selector widget
2 parents 9178a51 + 92d1164 commit 28f763f

23 files changed

+827
-101
lines changed

.github/workflows/publish.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
- name: Install Python
1515
uses: actions/setup-python@v2
1616
with:
17-
python-version: '3.8'
17+
python-version: '3.x'
1818
- name: Install dependencies
1919
run: |
2020
python -m pip install --upgrade pip

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ jobs:
1515
runs-on: ubuntu-latest
1616
strategy:
1717
matrix:
18-
python-version: ['3.7', '3.8']
19-
matplotlib-version: ['3.2', '3.3']
18+
python-version: ['3.8.x', '3.9.x']
19+
matplotlib-version: ['3.3']
2020
steps:
2121
- name: Checkout
2222
uses: actions/checkout@v2

docs/API.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,18 @@ Functions that make some features in Matplotlib a bit more convenient.
5252
~mpl_interactions.ioff
5353
~mpl_interactions.figure
5454
~mpl_interactions.nearest_idx
55+
~mpl_interactions.indexer
56+
57+
widgets
58+
-------
59+
60+
Custom matplotlib widgets made for use in this library.
61+
62+
.. currentmodule:: mpl_interactions
63+
.. autosummary::
64+
:toctree: autoapi
65+
:nosignatures:
66+
67+
~mpl_interactions.widgets.scatter_selector
68+
~mpl_interactions.widgets.scatter_selector_value
69+
~mpl_interactions.widgets.scatter_selector_index
Binary file not shown.

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ Further discussion of the behavior as a function of backend can be found on the
9595
examples/imshow.ipynb
9696
examples/hist.ipynb
9797
examples/mpl-sliders.ipynb
98+
examples/scatter-selector.ipynb
9899
examples/image-segmentation.ipynb
99100
examples/zoom-factory.ipynb
100101
examples/heatmap-slicer.ipynb

examples/Usage-Guide.ipynb

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,17 @@
301301
"3. `'fixed'` - Never automatically update the limits\n",
302302
"4. `[float, float]` - This value will be passed through to plt.xlim or plt.ylim\n",
303303
"\n",
304-
"### Reference parameters values in the Title\n",
305-
"You can make the Title automatically update with information about the values by using `title` argument. Use the name of one of the parameters as a format specifier in the string. For example use the following title string will put value of tau in the title and round it to two decimals: `'{'tau:.2f}'`"
304+
"### Reference parameters values in the Title, xlabel, or ylabel\n",
305+
"You can make the Title automatically update with information about the values by using `iplt.title`. You can either provide a function that returns a string, or you can provide a string with the names of one of the parameters as a format specifier in the string. Ultimately the string will be formatted using:\n",
306+
"```python\n",
307+
"if isinstance(title, Callable):\n",
308+
" title_str = title(**params)\n",
309+
"else:\n",
310+
" title_str = title\n",
311+
"ax.set_title(title_str.format(**params))\n",
312+
"```\n",
313+
"\n",
314+
"the same applies for the x and y labels."
306315
]
307316
},
308317
{
@@ -332,12 +341,15 @@
332341
" tau=tau,\n",
333342
" xlim=\"stretch\",\n",
334343
" ylim=\"auto\",\n",
335-
" title=\"the value of tau is: {tau:.2f}\",\n",
336344
" label=\"interactive!\",\n",
337345
")\n",
338346
"\n",
347+
"iplt.title(\"the value of tau is: {tau:.2f}\", controls=controls['tau'])\n",
339348
"# you can still use plt commands if this is the active figure\n",
340-
"plt.ylabel(\"yikes a ylabel!\")\n",
349+
"def ylabel(tau):\n",
350+
" return f\"tau/2 is {np.round(tau/2,3)}\"\n",
351+
"iplt.ylabel(ylabel, controls = controls['tau'])\n",
352+
"iplt.xlabel(\"This xlabel also changes with tau! tau~={tau:.0f}\", controls = controls['tau'])\n",
341353
"\n",
342354
"# you can new lines - though they won't be updated interactively.\n",
343355
"plt.plot(x, np.sin(x), label=\"Added after, not interactive\")\n",
@@ -381,7 +393,7 @@
381393
"name": "python",
382394
"nbconvert_exporter": "python",
383395
"pygments_lexer": "ipython3",
384-
"version": "3.7.8"
396+
"version": "3.9.0"
385397
}
386398
},
387399
"nbformat": 4,

examples/data/stock-metadata.pickle

48.5 KB
Binary file not shown.

examples/data/stock-prices.npz

1.48 MB
Binary file not shown.

examples/mpl-sliders-same-figure.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
from matplotlib.widgets import Slider
44

5-
from mpl_interactions import interactive_plot_factory
5+
from mpl_interactions import ipyplot as iplt
66

77
fig, ax = plt.subplots()
88
plt.subplots_adjust(bottom=0.25)
@@ -15,5 +15,5 @@ def f(x, freq):
1515

1616
axfreq = plt.axes([0.25, 0.1, 0.65, 0.03])
1717
slider = Slider(axfreq, label="freq", valmin=0.05, valmax=10)
18-
controls = interactive_plot_factory(ax, f, x=x, freq=slider)
18+
controls = iplt.plot(x, f, freq=slider, ax=ax)
1919
plt.show()

examples/plot.ipynb

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,7 @@
194194
"3. `fixed`\n",
195195
" - always used the initial values of the limits\n",
196196
"4. a tuple\n",
197-
" - You can pass a value such as `[-4,5]` to have the limits not be updated by moving the sliders.\n",
198-
" \n",
199-
" \n",
200-
"### Title\n",
201-
"\n",
202-
"You can make the title auto update with information about the values by using the `title` argument. Just use the name of one of the parameters as in a format specifier in the string. e.g. to put the value of `tau` in and round it to two decimals use the following title string: `{'tau:.2f}'`"
197+
" - You can pass a value such as `[-4,5]` to have the limits not be updated by moving the sliders."
203198
]
204199
},
205200
{
@@ -227,10 +222,10 @@
227222
" tau=tau,\n",
228223
" xlim=\"stretch\",\n",
229224
" ylim=\"auto\",\n",
230-
" title=\"the value of tau is: {tau:.2f}\",\n",
231225
" label=\"interactive!\",\n",
232226
")\n",
233227
"\n",
228+
"iplt.title(\"the value of tau is: {tau:.2f}\", controls=controls[\"tau\"])\n",
234229
"# you can still use plt commands if this is the active figure\n",
235230
"plt.ylabel(\"yikes a ylabel!\")\n",
236231
"\n",
@@ -335,7 +330,7 @@
335330
"name": "python",
336331
"nbconvert_exporter": "python",
337332
"pygments_lexer": "ipython3",
338-
"version": "3.7.8"
333+
"version": "3.9.0"
339334
}
340335
},
341336
"nbformat": 4,

examples/scatter-selector.ipynb

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# scatter_selector widget\n",
8+
"\n",
9+
"A set of custom matplotlib widgets that allow you to select points on a scatter plot as use that as input to other interactive plots. There are three variants that differ only in what they pass to their callbacks:\n",
10+
"\n",
11+
"1. `scatter_selector`: callbacks will receive `index, (x, y)` where `index` is the position of the point in the of the points.\n",
12+
"2. `scatter_selector_values`: callbacks will receive `x, y`\n",
13+
"3. `scatter_selector_index`: callbacks will receive `index`\n",
14+
"\n",
15+
"\n",
16+
"In this example we will use `scatter_selector_index` along with the `indexer` convenience function to make line plots of stock data. However, you can use custom functions for the interactive plots, or even attach your own callbacks to the scatter_selector widgets.\n",
17+
"\n",
18+
"\n",
19+
"## PCA of Stock Data\n",
20+
"\n",
21+
"For this example we will plot companies in SP500 in a scatter plot by principle components extracted from principal components analysis (PCA) an interactive visualization of companies in SP500 using [PCA](https://towardsdatascience.com/a-one-stop-shop-for-principal-component-analysis-5582fb7e0a9c). The data was originally obtained from https://www.kaggle.com/camnugent/sandp500 and the data was cleaned using code derived from https://github.com/Hekstra-Lab/scientific-python-bootcamp/tree/master/day3\n",
22+
"\n"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": null,
28+
"metadata": {},
29+
"outputs": [],
30+
"source": [
31+
"%matplotlib ipympl\n",
32+
"import pickle\n",
33+
"\n",
34+
"import ipywidgets as widgets\n",
35+
"import matplotlib.pyplot as plt\n",
36+
"import numpy as np\n",
37+
"\n",
38+
"import mpl_interactions.ipyplot as iplt\n",
39+
"from mpl_interactions import indexer, panhandler, zoom_factory\n",
40+
"from mpl_interactions.utils import indexer\n",
41+
"from mpl_interactions.widgets import scatter_selector_index"
42+
]
43+
},
44+
{
45+
"cell_type": "markdown",
46+
"metadata": {},
47+
"source": [
48+
"### Data loading/cleaning\n",
49+
"\n",
50+
"For this example we have pre-cleaned data that we will just load. If you are curious on how the data was originally processed you see the full code at the bottom of this notebook.\n",
51+
"\n",
52+
"The datafiles that we load for this example are available for download at https://github.com/ianhi/mpl-interactions/tree/master/examples/data"
53+
]
54+
},
55+
{
56+
"cell_type": "code",
57+
"execution_count": null,
58+
"metadata": {},
59+
"outputs": [],
60+
"source": [
61+
"import pickle\n",
62+
"\n",
63+
"with open(\"data/stock-metadata.pickle\", \"rb\") as f:\n",
64+
" meta = pickle.load(f)\n",
65+
"prices = np.load(\"data/stock-prices.npz\")[\"prices\"]\n",
66+
"names = meta[\"names\"]\n",
67+
"good_idx = meta[\n",
68+
" \"good_idx\"\n",
69+
"] # only plot the ones for which we were able to parse sector info\n",
70+
"data_colors = meta[\"data_colors\"]\n",
71+
"\n",
72+
"# calculate the daily price difference\n",
73+
"price_changes = np.diff(prices)\n",
74+
"\n",
75+
"# Below is a pretty standard way of normalizing numerical data\n",
76+
"normalized_price_changes = price_changes - price_changes.mean(axis=-1, keepdims=True)\n",
77+
"normalized_price_changes /= price_changes.std(axis=-1, keepdims=True)\n",
78+
"\n",
79+
"# calculate the covariance matrix\n",
80+
"covariance = np.cov(normalized_price_changes.T)\n",
81+
"\n",
82+
"# Calculate the eigenvectors (i.e. the principle components)\n",
83+
"evals, evecs = np.linalg.eig(covariance)\n",
84+
"evecs = np.real(evecs)\n",
85+
"\n",
86+
"# project the companies onto the principle components\n",
87+
"transformed = normalized_price_changes @ evecs\n",
88+
"\n",
89+
"# take only the first two components for plotting\n",
90+
"# we also take only the subset of companies for which it was easy to extract a sector and a name\n",
91+
"x, y = transformed[good_idx][:, 0], transformed[good_idx][:, 1]"
92+
]
93+
},
94+
{
95+
"cell_type": "markdown",
96+
"metadata": {},
97+
"source": [
98+
"### Making the plot\n",
99+
"\n",
100+
"We create the left scatter plot using the `scatter_selector_index` which will tell use the index of the company that was clicked on. Since this is just a Matplotlib `AxesWidget` it can be passed directly to `iplt.plot` as a kwarg and the `controls` object will handle it approriately.\n",
101+
"\n",
102+
"In this example we also make use of the function `mpl_interactions.utils.indexer`. This is a convenience function that handles indexing an array for you. So these two statements are equivalent:\n",
103+
"\n",
104+
"```python\n",
105+
"# set up data\n",
106+
"arr = np.random.randn(4,100).cumsum(-1)\n",
107+
"\n",
108+
"def f(idx):\n",
109+
" return arr[idx]\n",
110+
"iplt.plot(f, idx=np.arange(4))\n",
111+
"\n",
112+
"# or equivalently\n",
113+
"iplt.plot(indexer(arr), idx=np.arange(4))\n",
114+
"```"
115+
]
116+
},
117+
{
118+
"cell_type": "code",
119+
"execution_count": null,
120+
"metadata": {
121+
"gif": "scatter-selector-stocks.apng"
122+
},
123+
"outputs": [],
124+
"source": [
125+
"fig, axs = plt.subplots(1, 2, figsize=(10, 5), gridspec_kw={\"width_ratios\": [1.5, 1]})\n",
126+
"index = scatter_selector_index(axs[0], x, y, c=data_colors, cmap=\"tab20\")\n",
127+
"\n",
128+
"# plot all the stock traces in light gray\n",
129+
"plt.plot(prices.T, color=\"k\", alpha=0.05)\n",
130+
"\n",
131+
"# add interactive components to the subplot on the right\n",
132+
"# note the use of indexer\n",
133+
"controls = iplt.plot(indexer(prices), idx=index, color=\"r\")\n",
134+
"iplt.title(indexer(names), controls=controls[\"idx\"])\n",
135+
"\n",
136+
"# styling + zooming\n",
137+
"axs[0].set_xlabel(\"PC-1\")\n",
138+
"axs[0].set_ylabel(\"PC-2\")\n",
139+
"axs[1].set_xlabel(\"days\")\n",
140+
"axs[1].set_ylabel(\"Price in $\")\n",
141+
"axs[1].set_yscale(\"log\")\n",
142+
"cid = zoom_factory(axs[0])\n",
143+
"ph = panhandler(fig)"
144+
]
145+
},
146+
{
147+
"cell_type": "markdown",
148+
"metadata": {},
149+
"source": [
150+
"### Datacleaning\n",
151+
"\n",
152+
"Below is the code we used to clean and save the datasets. While we start out with 500 companies we end up with only 468 as some of them we were unable to easily and correctly parse so they were thrown away."
153+
]
154+
},
155+
{
156+
"cell_type": "code",
157+
"execution_count": null,
158+
"metadata": {},
159+
"outputs": [],
160+
"source": [
161+
"# NBVAL_SKIP\n",
162+
"# Download the data from https://www.kaggle.com/camnugent/sandp500\n",
163+
"# and save it into a folder named `data`\n",
164+
"import glob\n",
165+
"\n",
166+
"test = np.loadtxt(\"data/A_data.csv\", delimiter=\",\", skiprows=1, usecols=1)\n",
167+
"sp500_glob = glob.glob(\n",
168+
" \"data/*.csv\",\n",
169+
")\n",
170+
"names = []\n",
171+
"prices = np.zeros((len(sp500_glob), test.shape[0]))\n",
172+
"prices_good = []\n",
173+
"fails = []\n",
174+
"for i, f in enumerate(sp500_glob):\n",
175+
" fname = f.split(\"/\")[-1]\n",
176+
" names.append(fname.split(\"_\")[0])\n",
177+
" try:\n",
178+
" prices[i] = np.loadtxt(f, delimiter=\",\", skiprows=1, usecols=1)\n",
179+
" prices_good.append(True)\n",
180+
" except:\n",
181+
" fails.append(fname.split(\"_\")[0])\n",
182+
" prices_good.append(False)\n",
183+
" pass\n",
184+
"prices = prices[prices_good]\n",
185+
"np.savez_compressed(\"data/stock-prices.npz\", prices=prices)\n",
186+
"\n",
187+
"# processing names and sector info\n",
188+
"\n",
189+
"arr = np.loadtxt(\n",
190+
" \"data/SP500_names.csv\", delimiter=\"|\", skiprows=1, dtype=str, encoding=\"utf-8\"\n",
191+
")\n",
192+
"name_dict = {a[0].strip(): a[[1, 2, 3]] for a in arr}\n",
193+
"# idx_to_info = {i:name_dict[real_names[i]] for i in range(468)}\n",
194+
"good_names = []\n",
195+
"primary = []\n",
196+
"secondary = []\n",
197+
"good_idx = np.zeros(real_names.shape[0], dtype=bool)\n",
198+
"for i, name in enumerate(real_names):\n",
199+
" try:\n",
200+
" info = name_dict[name]\n",
201+
" good_idx[i] = True\n",
202+
" good_names.append(info[0])\n",
203+
" primary.append(info[1])\n",
204+
" secondary.append(info[2])\n",
205+
" except:\n",
206+
" pass\n",
207+
"psector_dict = {val: i for i, val in enumerate(np.unique(primary))}\n",
208+
"data_colors = np.array([psector_dict[val] for val in primary], dtype=int)\n",
209+
"import pickle\n",
210+
"\n",
211+
"meta = {\n",
212+
" \"good_idx\": good_idx,\n",
213+
" \"names\": good_names,\n",
214+
" \"sector\": psector_dict,\n",
215+
" \"data_colors\": data_colors,\n",
216+
"}\n",
217+
"with open(\"data/stock-metadata.pickle\", \"wb\") as outfile:\n",
218+
" pickle.dump(meta, outfile)"
219+
]
220+
}
221+
],
222+
"metadata": {
223+
"kernelspec": {
224+
"display_name": "Python 3",
225+
"language": "python",
226+
"name": "python3"
227+
},
228+
"language_info": {
229+
"codemirror_mode": {
230+
"name": "ipython",
231+
"version": 3
232+
},
233+
"file_extension": ".py",
234+
"mimetype": "text/x-python",
235+
"name": "python",
236+
"nbconvert_exporter": "python",
237+
"pygments_lexer": "ipython3",
238+
"version": "3.7.8"
239+
}
240+
},
241+
"nbformat": 4,
242+
"nbformat_minor": 4
243+
}

examples/test.dat

48.5 KB
Binary file not shown.

mpl_interactions/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
version_info = (0, 12, 0)
1+
version_info = (0, 13, 0)
22
__version__ = ".".join(map(str, version_info))

0 commit comments

Comments
 (0)