Skip to content

Commit 02413d0

Browse files
Merge pull request #1 from onnela-lab/reorg
Reorganize repository to use `uv`.
2 parents ae4eb20 + e0d8e72 commit 02413d0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+4455
-566
lines changed

.coveragerc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[report]
2+
exclude_lines =
3+
if __name__ == "__main__":
4+
pragma: no cover
5+
raise NotImplementedError

.github/workflows/main.yaml

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,43 @@
1-
name: Summary Statistics
1+
name: CI
2+
23
on:
34
push:
4-
branches: [main]
5+
branches: ["main"]
6+
tags: ["*"]
57
pull_request:
6-
branches: [main]
8+
branches: ["main"]
9+
workflow_dispatch:
710

811
env:
9-
cmdstanVersion: "2.31.0"
12+
cmdstanVersion: "2.36.0"
1013

1114
jobs:
1215
build:
13-
runs-on: ubuntu-latest
16+
name: Continuous Integration
17+
runs-on: "ubuntu-latest"
1418
steps:
15-
- name: Check out the repository
16-
uses: actions/checkout@v3
17-
- name: Setup python
18-
uses: actions/setup-python@v4
19+
- uses: "actions/checkout@v4"
20+
- name: Install uv
21+
uses: astral-sh/setup-uv@v5
22+
with:
23+
enable-cache: true
24+
cache-dependency-glob: uv.lock
25+
- name: Set up Python
26+
uses: actions/setup-python@v5
1927
with:
20-
python-version: '3.10'
21-
cache: pip
22-
- name: Pre-install packages that are build requirements.
23-
run: |
24-
pip install `grep pybind11== requirements.txt` `grep wheel== requirements.txt`
25-
pip install --no-build-isolation `grep -oE 'git\+.*?fasttr.git@\w+' requirements.txt`
26-
pip install --extra-index-url https://download.pytorch.org/whl/cpu --no-compile --no-deps \
27-
`grep torch== requirements.txt` \
28-
`grep typing-extensions== requirements.txt` \
29-
`grep mpmath== requirements.txt` \
30-
`grep sympy== requirements.txt`
31-
- name: Install Python dependencies.
32-
run: pip install --extra-index-url https://download.pytorch.org/whl/cpu --no-compile --no-deps -r requirements.txt
33-
- name: Cache cmdstan.
34-
uses: actions/cache@v3
28+
python-version-file: .python-version
29+
- name: Install project
30+
run: uv sync --all-groups
31+
- name: Cache cmdstan
32+
uses: actions/cache@v4
3533
with:
3634
path: /home/runner/.cmdstan
3735
key: cmdstan-${{ env.cmdstanVersion }}
38-
- name: Install cmdstan.
39-
run: python -m cmdstanpy.install_cmdstan --version ${{ env.cmdstanVersion }}
40-
- name: Lint the code.
41-
run: flake8
42-
- name: Run the tests.
43-
run: pytest --cov=summaries --cov-report=term-missing --cov-fail-under=100
36+
- name: Install cmdstan
37+
run: uv run python -m cmdstanpy.install_cmdstan --version ${{ env.cmdstanVersion }}
38+
- name: Lint the code
39+
run: true
40+
- name: Run test_infer_mdn_tree separately
41+
run: uv run pytest tests/scripts/test_infer_mdn_tree.py --cov=summaries --cov-report=term-missing
42+
- name: Run all other tests
43+
run: uv run pytest --cov=summaries --cov-report=term-missing --cov-fail-under=100 --cov-append --ignore=tests/scripts/test_infer_mdn_tree.py

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,5 +163,5 @@ workspace/
163163
*.ipynb
164164
*.tmp
165165
*.hpp
166-
summaries/scripts/infer_benchmark
166+
src/summaries/scripts/infer_benchmark
167167
????-??-??-*

.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.12

notebooks/benchmark.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import torch
1111

1212

1313
mpl.style.use("../.mplstyle")
14+
figwidth, figheight = mpl.rcParams["figure.figsize"]
1415
```
1516

1617
```python
@@ -67,7 +68,7 @@ X = torch.as_tensor(observed["data"][idx], dtype=dtype)
6768
x = X[:, 0]
6869
samples = mdnabc["samples"][idx].squeeze()
6970

70-
fig, axes = plt.subplots(2, 2, sharex="col", sharey="row")
71+
fig, axes = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(figwidth, figheight))
7172

7273
ax = axes[0, 0] # ----- ----- ----- -----
7374

notebooks/experiments.md

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ figwidth, figheight = mpl.rcParams["figure.figsize"]
1313
```
1414

1515
```python
16-
fig, ax = plt.subplots()
16+
fig, ax = plt.subplots(figsize=(figwidth, figheight))
1717

1818
vspace = 1
1919
hspace = 2.1
2020
kwargs = {
2121
"ha": "center",
22-
"va": "center",
22+
"va": "center",
2323
"fontsize": "small",
2424
}
2525
fkwargs = kwargs | {
@@ -64,7 +64,7 @@ texts = {
6464
"mdn-compressed-samples": (2 * hspace, -7 * vspace, "MDN-compressed samples\n$\\tilde\\theta\\in\\mathbb{R}^p\\sim \\tilde f\\left(\\theta\\mid t(y)\\right)$", vkwargs),
6565
"estimator": (0, -6 * vspace, "mixture density\nnetwork $h:\\mathbb{R}^q\\rightarrow\\mathcal{F}$", fkwargs),
6666
"estimate": (0, -7 * vspace, "density estimate\n$\\hat f\\left(\\theta\\mid t(z)\\right)\\in \\mathcal{F}$", vkwargs),
67-
"loss": (- hspace / 2, -8 * vspace, "NLP loss", fkwargs),
67+
"loss": (- hspace / 2, -8 * vspace, "NLP loss", fkwargs),
6868
}
6969

7070
elements = {}
@@ -80,12 +80,12 @@ if True:
8080
ax.set_xlim(xmin, xmax)
8181
ax.set_ylim(ymin, ymax)
8282
ax.set_aspect("equal")
83-
83+
8484
# Then adjust based on the actual extent of the box containing the text.
8585
fig.tight_layout()
8686
fig.draw_without_rendering()
8787
transform = ax.transData.inverted()
88-
extents = np.asarray([transform.transform(element.get_bbox_patch().get_window_extent())
88+
extents = np.asarray([transform.transform(element.get_bbox_patch().get_window_extent())
8989
for element in elements.values()])
9090
xmin, ymin = extents.min(axis=0)[0]
9191
xmax, ymax = extents.max(axis=0)[1]
@@ -105,12 +105,12 @@ connections = [
105105
[("params", 6), ("simulator", 12)],
106106
[("simulator", 6), ("simulated_data", 12)],
107107
[
108-
("simulated_data", 3),
108+
("simulated_data", 3),
109109
(get_anchor(elements["compressor"], 11).x, get_anchor(elements["simulated_data"], 3).y),
110110
("compressor", 11),
111111
],
112112
[
113-
("compressor", 7),
113+
("compressor", 7),
114114
(get_anchor(elements["compressor"], 7).x, get_anchor(elements["simulated_summaries"], 2.75).y),
115115
("simulated_summaries", 2.75),
116116
],
@@ -119,7 +119,7 @@ connections = [
119119
(get_anchor(elements["compressor"], 6).x, get_anchor(elements["simulated_summaries"], 3.25).y),
120120
(get_anchor(elements["compressor"], 6).x, get_anchor(elements["abc"], 9).y),
121121
("abc", 9),
122-
122+
123123
],
124124
[("simulated_summaries", 6), ("estimator", 12)],
125125
[("estimator", 6), ("estimate", 12)],
@@ -136,12 +136,12 @@ connections = [
136136

137137
# Observed.
138138
[
139-
("observed_data", 9),
139+
("observed_data", 9),
140140
(get_anchor(elements["compressor"], 1).x, get_anchor(elements["observed_data"], 9).y),
141141
("compressor", 1),
142142
],
143143
[
144-
("compressor", 5),
144+
("compressor", 5),
145145
(get_anchor(elements["compressor"], 5).x, get_anchor(elements["observed_summaries"], 9).y),
146146
("observed_summaries", 9),
147147
],
@@ -198,7 +198,7 @@ for (row, col), cell in celld.items():
198198
x = cell.get_x()
199199
y = cell.get_y() - (row + 1) * (height - original_height)
200200
phantom = mpl.patches.Rectangle(
201-
(x, y), cell.get_width(), cell.get_height(), transform=ax.transAxes,
201+
(x, y), cell.get_width(), cell.get_height(), transform=ax.transAxes,
202202
facecolor=color, zorder=-1)
203203
ax.add_patch(phantom)
204204

notebooks/illustrations.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ figwidth, figheight = mpl.rcParams["figure.figsize"]
1515
# Illustration of different methods and classes of summaries
1616

1717
```python
18-
fig, axes = plt.subplots(1, 2, figsize=(figwidth, 3), width_ratios=(3, 2))
18+
fig, axes = plt.subplots(1, 2, figsize=(figwidth, 2.55), width_ratios=(3.2, 2))
1919

2020
ax = axes[1]
2121
sets = {
@@ -126,7 +126,7 @@ for key, value in sets.items():
126126
ax.text(*value["textxy"], value["text"], **textkwargs, zorder=9)
127127

128128
ax.set_xlim(0, 11)
129-
ax.set_ylim(-0.2, 10.6)
129+
ax.set_ylim(-3, 10.6)
130130
ax.set_aspect("equal")
131131
ax.set_axis_off()
132132

@@ -157,10 +157,10 @@ elements = {}
157157
for key, (x, y, text) in texts.items():
158158
element = ax.text(x, y, text, ha="center", va="center", fontsize="small", bbox=bbox)
159159
elements[key] = element
160-
160+
161161
hpad = 1.2
162162
ax.set_xlim(-hpad, hspace + hpad)
163-
ax.set_ylim(-2.5, 3.1)
163+
ax.set_ylim(-2.5, 2.5)
164164
ax.set_aspect("equal")
165165

166166
ax.set_axis_off()
@@ -170,7 +170,7 @@ fig.draw_without_rendering()
170170
scale = ax.transAxes.get_matrix()[0, 0] / ax.transData.get_matrix()[0, 0]
171171
width = 2
172172
info = mpl.patches.FancyBboxPatch(
173-
(-width / 2, -3.25 * info_vspace), width, 5.75 * info_vspace, boxstyle=boxstyle, mutation_scale=1 / scale,
173+
(-width / 2, -3.25 * info_vspace), width, 5.75 * info_vspace, boxstyle=boxstyle, mutation_scale=1 / scale,
174174
color="C0", alpha=0.2,
175175
)
176176
ax.add_patch(info)
@@ -181,8 +181,8 @@ fig.draw_without_rendering()
181181
info_right = get_anchor(info, 3).x
182182
paths = [
183183
("--", [
184-
get_anchor(info, 0),
185-
(get_anchor(info, 0).x, get_anchor(elements["approx"], 9).y),
184+
get_anchor(info, 0),
185+
(get_anchor(info, 0).x, get_anchor(elements["approx"], 9).y),
186186
get_anchor(elements["approx"], 9)
187187
]),
188188
("--", [
@@ -212,7 +212,7 @@ for ls, path in paths:
212212
ax.add_patch(patch)
213213
arrow = mpl.patches.PathPatch(arrow_path(path, 0.075), color="gray")
214214
ax.add_patch(arrow)
215-
215+
216216
# Add the bidirectional arrow for large-sample correspondence.
217217
if ls == "-":
218218
arrow = mpl.patches.PathPatch(arrow_path(path, 0.075, backward=True), color="gray")
@@ -227,7 +227,7 @@ labels = [
227227
"special case",
228228
"large-sample limit",
229229
]
230-
ax.legend(handles, labels, loc="upper center")
230+
ax.legend(handles, labels, loc="lower left", bbox_to_anchor=(.9, 0))
231231

232232
label_axes(axes)
233233
fig.savefig("../workspace/figures/illustration.pdf")

notebooks/prior-dependence.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ from summaries.entropy import estimate_mutual_information
77
from scipy import stats
88

99
mpl.style.use("../.mplstyle")
10+
figwidth, figheight = mpl.rcParams["figure.figsize"]
1011
```
1112

1213
```python
@@ -55,7 +56,7 @@ num_points = 200 # Number of points in the figure (we sample more for MI estima
5556

5657
results = generate_data(m, n, scale)
5758

58-
fig, axes = plt.subplots(2, 2, sharex=True)
59+
fig, axes = plt.subplots(2, 2, sharex=True, figsize=(figwidth, 0.8 * figheight))
5960

6061
# Show the two priors.
6162
ax = axes[1, 0]
@@ -73,7 +74,7 @@ ax = axes[0, 0]
7374
lin = np.linspace(-1, 1, 100) * (1 + 3 * scale)
7475
ax.plot(lin, np.maximum(0, lin), label=r'location', color='k')
7576
ax.plot(lin, np.minimum(np.exp(lin / 2), 1), label=r'scale', color='k', ls='--')
76-
ax.set_ylabel('likelihood parameters')
77+
ax.set_ylabel('likelihood\nparameters')
7778

7879
# Plot the scatter of summaries against parameter value for both priors.
7980
step = m // num_points # Only plot `num_points` for better visualisation.

notebooks/statistics.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ from pathlib import Path
99
from scipy import stats
1010

1111
mpl.style.use("../.mplstyle")
12+
figwidth, figheight = mpl.rcParams["figure.figsize"]
1213
```
1314

1415
```python
@@ -173,7 +174,7 @@ pass
173174
```
174175

175176
```python
176-
fig = plt.figure()
177+
fig = plt.figure(figsize=(figwidth, 0.9 * figheight))
177178
gs = fig.add_gridspec(2, 2)
178179

179180
# Declare how each point should be visualized.

notebooks/tough-nuts.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from snippets.plot import label_axes
88

99

1010
mpl.style.use("../.mplstyle")
11+
figwidth, figheight = mpl.rcParams["figure.figsize"]
1112
```
1213

1314
```python
@@ -57,7 +58,7 @@ second_moment = np.linspace(.2, .725, 101)
5758
aa, ss = np.meshgrid(a, second_moment)
5859
entropy_gain = evaluate_entropy_gain(aa, b, n, ss)
5960

60-
fig, axes = plt.subplots(1, 2)
61+
fig, axes = plt.subplots(1, 2, figsize=(figwidth, figheight * 0.8))
6162

6263
# Plot entropy gain with "centered" colorbar.
6364
ax = axes[0]

0 commit comments

Comments
 (0)