
Commit ca6b8db

HookedSAETransformer (#536)
* implement HookedSAETransformer
* clean up imports
* apply format
* only recompute error if use_error_term
* add tests
* run format
* fix import
* match to hooks API
* improve doc strings
* improve demo
* address Arthur feedback
* try to fix indent
* try to fix indent again
* change doc code block
1 parent 1139caf commit ca6b8db

File tree

8 files changed (+19,799 / -1 lines changed)


README.md

Lines changed: 2 additions & 1 deletion

@@ -24,7 +24,7 @@ TransformerLens lets you load in 50+ different open source language models, and
 activations of the model to you. You can cache any internal activation in the model, and add in
 functions to edit, remove or replace these activations as the model runs.
 
-~~ [OCTOBER SURVEY HERE](https://forms.gle/bw7U3PfioacDtFmT8) ~~
+The library also now supports mechanistic interpretability with SAEs (sparse autoencoders)! With [HookedSAETransformer](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/hooked-sae-transformer/demos/HookedSAETransformerDemo.ipynb), you can splice in SAEs during inference and cache + intervene on SAE activations. We recommend [SAELens](https://github.com/jbloomAus/SAELens) (built on top of TransformerLens) for training SAEs.
 
 ## Quick Start
 
@@ -51,6 +51,7 @@ logits, activations = model.run_with_cache("Hello World")
 * [Introduction to the Library and Mech
 Interp](https://arena-ch1-transformers.streamlit.app/[1.2]_Intro_to_Mech_Interp)
 * [Demo of Main TransformerLens Features](https://neelnanda.io/transformer-lens-demo)
+* [Demo of HookedSAETransformer Features](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/hooked-sae-transformer/demos/HookedSAETransformerDemo.ipynb)
 
 ## Gallery
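
To make the new README claim concrete, here is a minimal sketch of the standalone-SAE pattern that the tests below exercise. The model name and hook point are just the examples used in the tests, and the HookedSAE is randomly initialised; in practice a trained SAE (e.g. from SAELens) would be used instead.

from transformer_lens import HookedSAE, HookedSAEConfig, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("solu-1l")
act_name = "blocks.0.hook_resid_pre"

# Cache the activations the SAE will reconstruct.
_, cache = model.run_with_cache("Hello World!", names_filter=act_name)
x = cache[act_name]

# A toy SAE at that hook point; d_in must match the activation width.
cfg = HookedSAEConfig(d_in=model.cfg.d_model, d_sae=2 * model.cfg.d_model, hook_name=act_name)
sae = HookedSAE(cfg)

# The SAE exposes its own hook points, so its internals can be cached (and intervened on).
sae_out, sae_cache = sae.run_with_cache(x)
assert sae_out.shape == x.shape
feature_acts = sae_cache["hook_sae_acts_post"]  # shape (batch, pos, d_sae)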

demos/HookedSAETransformerDemo.ipynb

Lines changed: 18616 additions & 0 deletions
Large diffs are not rendered by default.
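
The notebook is where splicing SAEs into the model's forward pass is demonstrated. As a rough, hedged sketch of that workflow: only reset_saes appears in this diff (in the test fixture), while add_sae and the prefixed cache keys below are assumed from the current TransformerLens API; consult the demo notebook for the exact entry points.

from transformer_lens import HookedSAE, HookedSAEConfig, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("solu-1l")
act_name = "blocks.0.hook_resid_pre"
sae = HookedSAE(HookedSAEConfig(d_in=model.cfg.d_model, d_sae=2 * model.cfg.d_model, hook_name=act_name))

model.add_sae(sae)  # splice the SAE in at its hook_name (assumed API, see demo notebook)
logits, cache = model.run_with_cache("Hello World!")
feature_acts = cache[f"{act_name}.hook_sae_acts_post"]  # SAE activations now appear in the model's cache
model.reset_saes()  # detach all spliced-in SAEs (this method is used by the new test fixture)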

tests/unit/test_hooked_sae.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@ (new file)

import einops
import pytest
import torch

from transformer_lens import HookedSAE, HookedSAEConfig, HookedSAETransformer

MODEL = "solu-1l"
prompt = "Hello World!"


class Counter:
    def __init__(self):
        self.count = 0

    def inc(self, *args, **kwargs):
        self.count += 1


@pytest.fixture(scope="module")
def model():
    model = HookedSAETransformer.from_pretrained(MODEL)
    yield model
    model.reset_saes()


def get_sae_config(model, act_name):
    site_to_size = {
        "hook_z": model.cfg.d_head * model.cfg.n_heads,
        "hook_mlp_out": model.cfg.d_model,
        "hook_resid_pre": model.cfg.d_model,
        "hook_post": model.cfg.d_mlp,
    }
    site = act_name.split(".")[-1]
    d_in = site_to_size[site]
    return HookedSAEConfig(d_in=d_in, d_sae=d_in * 2, hook_name=act_name)


@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_forward_reconstructs_input(model, act_name):
    """Verify that the HookedSAE returns an output with the same shape as the input activations."""
    sae_cfg = get_sae_config(model, act_name)
    hooked_sae = HookedSAE(sae_cfg)

    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    x = cache[act_name]

    sae_output = hooked_sae(x)
    assert sae_output.shape == x.shape


@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_run_with_cache(model, act_name):
    """Verifies that run_with_cache caches SAE activations"""
    sae_cfg = get_sae_config(model, act_name)
    hooked_sae = HookedSAE(sae_cfg)

    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    x = cache[act_name]

    sae_output, cache = hooked_sae.run_with_cache(x)
    assert sae_output.shape == x.shape

    assert "hook_sae_input" in cache
    assert "hook_sae_acts_pre" in cache
    assert "hook_sae_acts_post" in cache
    assert "hook_sae_recons" in cache
    assert "hook_sae_output" in cache


@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_run_with_hooks(model, act_name):
    """Verifies that run_with_hooks works with SAE activations"""
    c = Counter()
    sae_cfg = get_sae_config(model, act_name)
    hooked_sae = HookedSAE(sae_cfg)

    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    x = cache[act_name]

    sae_hooks = [
        "hook_sae_input",
        "hook_sae_acts_pre",
        "hook_sae_acts_post",
        "hook_sae_recons",
        "hook_sae_output",
    ]

    sae_output = hooked_sae.run_with_hooks(
        x, fwd_hooks=[(sae_hook_name, c.inc) for sae_hook_name in sae_hooks]
    )
    assert sae_output.shape == x.shape

    assert c.count == len(sae_hooks)


@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_error_term(model, act_name):
    """Verifies that if we use error terms, HookedSAE returns an output that is equal to the input activations."""
    sae_cfg = get_sae_config(model, act_name)
    sae_cfg.use_error_term = True
    hooked_sae = HookedSAE(sae_cfg)

    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    x = cache[act_name]

    sae_output = hooked_sae(x)
    assert sae_output.shape == x.shape
    assert torch.allclose(sae_output, x, atol=1e-6)


# %%
@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_feature_grads_with_error_term(model, act_name):
    """Verifies that pytorch backward computes the correct feature gradients when using error terms. Motivated by the need to compute feature gradients for attribution patching."""

    # Load SAE
    sae_cfg = get_sae_config(model, act_name)
    sae_cfg.use_error_term = True
    hooked_sae = HookedSAE(sae_cfg)

    # Get input activations
    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    x = cache[act_name]

    # Cache gradients with respect to feature acts
    hooked_sae.reset_hooks()
    grad_cache = {}

    def backward_cache_hook(act, hook):
        grad_cache[hook.name] = act.detach()

    hooked_sae.add_hook("hook_sae_acts_post", backward_cache_hook, "bwd")
    hooked_sae.add_hook("hook_sae_output", backward_cache_hook, "bwd")

    sae_output = hooked_sae(x)
    assert torch.allclose(sae_output, x, atol=1e-6)
    value = sae_output.sum()
    value.backward()
    hooked_sae.reset_hooks()

    # Compute gradient analytically
    if act_name.endswith("hook_z"):
        reshaped_output_grad = einops.rearrange(
            grad_cache["hook_sae_output"], "... n_heads d_head -> ... (n_heads d_head)"
        )
        analytic_grad = reshaped_output_grad @ hooked_sae.W_dec.T
    else:
        analytic_grad = grad_cache["hook_sae_output"] @ hooked_sae.W_dec.T

    # Compare analytic gradient with pytorch computed gradient
    assert torch.allclose(grad_cache["hook_sae_acts_post"], analytic_grad, atol=1e-6)
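
A short aside on why the two error-term tests hold, under the assumption implied by the analytic gradient above: the decoder is the affine map recons = acts_post @ W_dec + b_dec, and the error term is treated as a constant with respect to the feature activations. Then output = recons + (x - recons) = x exactly, while the gradient of a scalar loss with respect to acts_post is still grad_output @ W_dec.T. A self-contained toy check with made-up shapes (not the library's internals):

import torch

d_sae, d_in = 8, 4
W_dec = torch.randn(d_sae, d_in)
b_dec = torch.randn(d_in)

x = torch.randn(2, d_in)                                # stand-in for input activations
acts_post = torch.randn(2, d_sae, requires_grad=True)   # stand-in for SAE feature activations

recons = acts_post @ W_dec + b_dec
error = (x - recons).detach()   # error term held constant w.r.t. the features (assumption)
output = recons + error         # with the error term, output equals x

assert torch.allclose(output, x, atol=1e-6)

output.sum().backward()
# Chain rule: d(sum(output)) / d(acts_post) = ones_like(output) @ W_dec.T
analytic_grad = torch.ones_like(output) @ W_dec.T
assert torch.allclose(acts_post.grad, analytic_grad, atol=1e-6)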
