Skip to content

Commit a711efe

Browse files
authored
Merge pull request #37 from chanind/pypi-package
feat: pypi packaging and auto-release with semantic release
2 parents 43421f5 + 0ff8888 commit a711efe

27 files changed

+298
-26
lines changed

.github/workflows/build.yml

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
name: build
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
pull_request:
8+
branches:
9+
- main
10+
11+
jobs:
12+
build:
13+
runs-on: ubuntu-latest
14+
strategy:
15+
matrix:
16+
python-version: ["3.10", "3.11", "3.12"]
17+
18+
steps:
19+
- uses: actions/checkout@v4
20+
- name: Set up Python ${{ matrix.python-version }}
21+
uses: actions/setup-python@v5
22+
with:
23+
python-version: ${{ matrix.python-version }}
24+
- name: Cache Huggingface assets
25+
uses: actions/cache@v4
26+
with:
27+
key: huggingface-0-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
28+
path: ~/.cache/huggingface
29+
restore-keys: |
30+
huggingface-0-${{ runner.os }}-${{ matrix.python-version }}-
31+
- name: Load cached Poetry installation
32+
id: cached-poetry
33+
uses: actions/cache@v4
34+
with:
35+
path: ~/.local # the path depends on the OS
36+
key: poetry-${{ runner.os }}-${{ matrix.python-version }}-1 # increment to reset cache
37+
- name: Install Poetry
38+
if: steps.cached-poetry.outputs.cache-hit != 'true'
39+
uses: snok/install-poetry@v1
40+
with:
41+
virtualenvs-create: true
42+
virtualenvs-in-project: true
43+
installer-parallel: true
44+
- name: Load cached venv
45+
id: cached-poetry-dependencies
46+
uses: actions/cache@v4
47+
with:
48+
path: .venv
49+
key: venv-0-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
50+
restore-keys: |
51+
venv-0-${{ runner.os }}-${{ matrix.python-version }}-
52+
- name: Install dependencies
53+
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
54+
run: poetry install --no-interaction
55+
- name: Run Unit Tests
56+
run: poetry run pytest tests/unit
57+
- name: Build package
58+
run: poetry build
59+
60+
release:
61+
needs: build
62+
permissions:
63+
contents: write
64+
id-token: write
65+
# https://github.community/t/how-do-i-specify-job-dependency-running-in-another-workflow/16482
66+
if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):')
67+
runs-on: ubuntu-latest
68+
concurrency: release
69+
environment:
70+
name: pypi
71+
steps:
72+
- uses: actions/checkout@v4
73+
with:
74+
fetch-depth: 0
75+
- uses: actions/setup-python@v5
76+
with:
77+
python-version: "3.11"
78+
- name: Semantic Release
79+
id: release
80+
uses: python-semantic-release/python-semantic-release@v8.0.7
81+
with:
82+
github_token: ${{ secrets.GITHUB_TOKEN }}
83+
- name: Publish package distributions to PyPI
84+
uses: pypa/gh-action-pypi-publish@release/v1
85+
if: steps.release.outputs.released == 'true'
86+
- name: Publish package distributions to GitHub Releases
87+
uses: python-semantic-release/upload-to-gh-release@main
88+
if: steps.release.outputs.released == 'true'
89+
with:
90+
github_token: ${{ secrets.GITHUB_TOKEN }}

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ ipython_config.py
9999
# This is especially recommended for binary packages to ensure reproducibility, and is more
100100
# commonly ignored for libraries.
101101
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102-
#poetry.lock
102+
poetry.lock
103103

104104
# pdm
105105
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.

README.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,9 @@ Some dictionaries trained using this repository (and associated training checkpo
88

99
Navigate to the location where you would like to clone this repo, clone and enter the repo, and install the requirements.
1010
```bash
11-
git clone https://github.com/saprmarks/dictionary_learning
12-
cd dictionary_learning
13-
pip install -r requirements.txt
11+
pip install dictionary-learning
1412
```
1513

16-
To use `dictionary_learning`, include it as a subdirectory in some project's directory and import it; see the examples below.
17-
1814
We also provide a [demonstration](https://github.com/adamkarvonen/dictionary_learning_demo), which trains and evaluates 2 SAEs in ~30 minutes before plotting the results.
1915

2016
# Using trained dictionaries

__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

dictionary_learning/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
__version__ = "0.1.0"
2+
3+
from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder
4+
from .buffer import ActivationBuffer
5+
6+
__all__ = ["AutoEncoder", "GatedAutoEncoder", "JumpReluAutoEncoder", "ActivationBuffer"]
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

interp.py renamed to dictionary_learning/interp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,4 +188,4 @@ def feature_umap(
188188
hover_name=df.index,
189189
color=colors,
190190
)
191-
raise ValueError("n_components must be 2 or 3")
191+
raise ValueError("n_components must be 2 or 3")

trainers/__init__.py renamed to dictionary_learning/trainers/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,15 @@
55
from .top_k import TopKTrainer
66
from .jumprelu import JumpReluTrainer
77
from .batch_top_k import BatchTopKTrainer, BatchTopKSAE
8+
9+
10+
__all__ = [
11+
"StandardTrainer",
12+
"GatedSAETrainer",
13+
"PAnnealTrainer",
14+
"GatedAnnealTrainer",
15+
"TopKTrainer",
16+
"JumpReluTrainer",
17+
"BatchTopKTrainer",
18+
"BatchTopKSAE",
19+
]
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

pyproject.toml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
[tool.poetry]
2+
name = "dictionary-learning"
3+
version = "0.1.0"
4+
description = "Dictionary learning via sparse autoencoders on neural network activations"
5+
authors = ["Samuel Marks", "Adam Karvonen", "Aaron Mueller"]
6+
packages = [{ include = "dictionary_learning" }]
7+
license = "MIT"
8+
readme = "README.md"
9+
keywords = [
10+
"deep-learning",
11+
"sparse-autoencoders",
12+
"mechanistic-interpretability",
13+
"PyTorch",
14+
]
15+
classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"]
16+
repository = "https://github.com/saprmarks/dictionary_learning"
17+
homepage = "https://github.com/saprmarks/dictionary_learning"
18+
19+
20+
[tool.poetry.dependencies]
21+
python = "^3.10"
22+
circuitsvis = ">=1.43.2"
23+
datasets = ">=2.18.0"
24+
einops = ">=0.7.0"
25+
nnsight = ">=0.3.0,<0.4.0"
26+
pandas = ">=2.2.1"
27+
plotly = ">=5.18.0"
28+
tqdm = ">=4.66.1"
29+
zstandard = ">=0.22.0"
30+
wandb = ">=0.12.0"
31+
umap-learn = ">=0.5.6"
32+
llvmlite = ">=0.40.0"
33+
34+
[tool.poetry.group.dev.dependencies]
35+
pytest = "^8.3.4"
36+
37+
[build-system]
38+
requires = ["poetry-core>=2.0.0,<3.0.0"]
39+
build-backend = "poetry.core.masonry.api"
40+
41+
[tool.semantic_release]
42+
version_variables = ["dictionary_learning/__init__.py:__version__"]
43+
version_toml = ["pyproject.toml:tool.poetry.version"]
44+
branch = "main"
45+
build_command = "pip install poetry && poetry build"

requirements.txt

Lines changed: 0 additions & 13 deletions
This file was deleted.

tests/test_end_to_end.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
from dictionary_learning.training import trainSAE
88
from dictionary_learning.trainers.standard import StandardTrainer
99
from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK
10-
from dictionary_learning.utils import hf_dataset_to_generator, get_nested_folders, load_dictionary
10+
from dictionary_learning.utils import (
11+
hf_dataset_to_generator,
12+
get_nested_folders,
13+
load_dictionary,
14+
)
1115
from dictionary_learning.buffer import ActivationBuffer
1216
from dictionary_learning.dictionary import (
1317
AutoEncoder,
@@ -62,10 +66,8 @@ def test_sae_training():
6266
"""End to end test for training an SAE. Takes ~2 minutes on an RTX 3090.
6367
This isn't a nice suite of unit tests, but it's better than nothing.
6468
I have observed that results can slightly vary with library versions. For full determinism,
65-
use pytorch 2.5.1 and nnsight 0.3.7.
69+
use pytorch 2.5.1 and nnsight 0.3.7."""
6670

67-
NOTE: `dictionary_learning` is meant to be used as a submodule. Thus, to run this test, you need to use `dictionary_learning` as a submodule
68-
and run the test from the root of the repository using `pytest -s`. Refer to https://github.com/adamkarvonen/dictionary_learning_demo for an example"""
6971
random.seed(RANDOM_SEED)
7072
t.manual_seed(RANDOM_SEED)
7173

tests/unit/test_dictionary.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import torch as t
2+
import pytest
3+
from dictionary_learning.dictionary import (
4+
AutoEncoder,
5+
GatedAutoEncoder,
6+
AutoEncoderNew,
7+
JumpReluAutoEncoder,
8+
)
9+
10+
11+
@pytest.mark.parametrize(
    "sae_cls", [AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder]
)
def test_forward_equals_decode_encode(sae_cls: type) -> None:
    """Check that the forward pass agrees with ``decode(encode(x))``.

    Runs once for each SAE class in the parametrize list and exercises the
    forward pass both with and without ``output_features=True``.

    Args:
        sae_cls: SAE class under test; constructed with ``activation_dim``
            and ``dict_size`` keyword arguments.
    """
    batch_size = 4
    act_dim = 8
    dict_size = 6
    x = t.randn(batch_size, act_dim)

    sae = sae_cls(activation_dim=act_dim, dict_size=dict_size)

    # Plain forward pass: only the reconstruction is returned.
    forward_out = sae(x)
    encode_decode = sae.decode(sae.encode(x))
    assert t.allclose(forward_out, encode_decode)

    # Forward pass that additionally returns the feature activations.
    forward_out_with_features, features = sae(x, output_features=True)
    assert t.allclose(features, sae.encode(x))
    # Previously the reconstruction returned on this code path was bound but
    # never checked; it must match decode(encode(x)) as well.
    assert t.allclose(forward_out_with_features, encode_decode)
32+
33+
34+
def test_simple_autoencoder() -> None:
    """Exercise AutoEncoder encode/decode with identity weights and zero biases."""
    sae = AutoEncoder(activation_dim=2, dict_size=2)

    # Identity encoder/decoder weights and zeroed biases keep the expected
    # values easy to read off by hand.
    with t.no_grad():
        sae.encoder.weight.data = t.eye(2)
        sae.decoder.weight.data = t.eye(2)
        sae.encoder.bias.data = t.zeros(2)
        sae.bias.data = t.zeros(2)

    inputs = t.tensor([[2.0, -1.0]])
    expected = t.tensor([[2.0, 0.0]])

    # Encoding: the negative component is expected to be clipped to zero.
    features = sae.encode(inputs)
    assert t.allclose(features, expected)

    # Decoding the clipped features should reproduce them unchanged.
    reconstruction = sae.decode(features)
    assert t.allclose(reconstruction, expected)
53+
54+
55+
def test_simple_gated_autoencoder() -> None:
    """Exercise GatedAutoEncoder.encode with identity weights and zeroed parameters."""
    sae = GatedAutoEncoder(activation_dim=2, dict_size=2)

    with t.no_grad():
        # Identity encoder/decoder weights.
        sae.encoder.weight.data = t.eye(2)
        sae.decoder.weight.data = t.eye(2)
        # Zero every remaining bias/scale parameter touched by the test.
        for param_name in ("gate_bias", "mag_bias", "r_mag", "decoder_bias"):
            getattr(sae, param_name).data = t.zeros(2)

    inputs = t.tensor([[2.0, -1.0]])
    features = sae.encode(inputs)

    # Only the positive component is expected to pass through.
    expected = t.tensor([[2.0, 0.0]])
    assert t.allclose(features, expected)
73+
74+
75+
def test_normalize_decoder() -> None:
    """Check that normalize_decoder yields unit-norm decoder weights without changing the output."""
    sae = AutoEncoder(activation_dim=4, dict_size=3)
    x = t.randn(2, 4)

    # Reference output taken before normalization.
    output_before = sae(x)

    sae.normalize_decoder()

    # Norms of the decoder weight taken along dim 0 must all equal one.
    weight_norms = sae.decoder.weight.norm(dim=0)
    assert t.allclose(weight_norms, t.ones_like(weight_norms))

    # The forward pass must be preserved up to a small absolute tolerance.
    output_after = sae(x)
    assert t.allclose(output_before, output_after, atol=1e-4)
93+
94+
95+
def test_scale_biases() -> None:
    """Check that scale_biases multiplies the encoder bias and the ``bias`` parameter by the factor."""
    sae = AutoEncoder(activation_dim=4, dict_size=3)
    factor = 2.0

    # Snapshot the biases and precompute what they should become.
    expected_encoder_bias = sae.encoder.bias.data.clone() * factor
    expected_bias = sae.bias.data.clone() * factor

    sae.scale_biases(factor)

    assert t.allclose(sae.encoder.bias.data, expected_encoder_bias)
    assert t.allclose(sae.bias.data, expected_bias)
108+
109+
110+
@pytest.mark.parametrize(
    "sae_cls", [AutoEncoder, GatedAutoEncoder, AutoEncoderNew, JumpReluAutoEncoder]
)
def test_output_shapes(sae_cls: type) -> None:
    """Check the shapes returned by encode, decode and both forward modes."""
    batch_size, act_dim, dict_size = 3, 4, 5
    feature_shape = (batch_size, dict_size)
    activation_shape = (batch_size, act_dim)

    x = t.randn(batch_size, act_dim)
    sae = sae_cls(activation_dim=act_dim, dict_size=dict_size)

    # encode maps activations to dictionary features.
    encoded = sae.encode(x)
    assert encoded.shape == feature_shape

    # decode maps features back to activation space.
    decoded = sae.decode(encoded)
    assert decoded.shape == activation_shape

    # Forward pass without features returns just the reconstruction.
    reconstruction = sae(x)
    assert reconstruction.shape == activation_shape

    # Forward pass with features returns (reconstruction, features).
    reconstruction, features = sae(x, output_features=True)
    assert reconstruction.shape == activation_shape
    assert features.shape == feature_shape

0 commit comments

Comments
 (0)