|
| 1 | +import einops |
| 2 | +import pytest |
| 3 | +import torch |
| 4 | + |
| 5 | +from transformer_lens import HookedSAE, HookedSAEConfig, HookedSAETransformer |
| 6 | + |
# Small one-layer model so these tests stay fast to download and run.
MODEL = "solu-1l"
# Shared input string used by every test in this module.
prompt = "Hello World!"
| 9 | + |
| 10 | + |
class Counter:
    """Minimal call counter used to verify that hooks actually fire."""

    def __init__(self):
        # Number of times ``inc`` has been invoked.
        self.count = 0

    def inc(self, *args, **kwargs):
        """Increment the counter; accepts (and ignores) any hook arguments."""
        self.count = self.count + 1
| 17 | + |
| 18 | + |
@pytest.fixture(scope="module")
def model():
    """Load the test model once per module; detach all SAEs on teardown."""
    loaded = HookedSAETransformer.from_pretrained(MODEL)
    yield loaded
    loaded.reset_saes()
| 23 | + model.reset_saes() |
| 24 | + |
| 25 | + |
def get_sae_config(model, act_name):
    """Build a HookedSAEConfig sized to match the activation at ``act_name``.

    Args:
        model: Model whose ``cfg`` supplies the per-site activation widths.
        act_name: Full hook name (e.g. ``"blocks.0.attn.hook_z"``); the final
            dotted component selects the activation site.

    Returns:
        HookedSAEConfig with ``d_in`` equal to the activation width and
        ``d_sae`` fixed at a 2x expansion for these tests.

    Raises:
        ValueError: If the hook site is not one this helper knows how to size.
    """
    site_to_size = {
        "hook_z": model.cfg.d_head * model.cfg.n_heads,
        "hook_mlp_out": model.cfg.d_model,
        "hook_resid_pre": model.cfg.d_model,
        "hook_post": model.cfg.d_mlp,
    }
    site = act_name.split(".")[-1]
    # Fail loudly with context instead of a bare KeyError on an unknown site.
    if site not in site_to_size:
        raise ValueError(f"Unsupported SAE hook site {site!r} in {act_name!r}")
    d_in = site_to_size[site]
    return HookedSAEConfig(d_in=d_in, d_sae=d_in * 2, hook_name=act_name)
| 36 | + |
| 37 | + |
@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_forward_reconstructs_input(model, act_name):
    """Verify that the HookedSAE returns an output with the same shape as the input activations."""
    sae = HookedSAE(get_sae_config(model, act_name))

    # Capture the activation the SAE will consume.
    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    acts = cache[act_name]

    assert sae(acts).shape == acts.shape
| 57 | + |
| 58 | + |
@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_run_with_cache(model, act_name):
    """Verifies that run_with_cache caches SAE activations"""
    sae = HookedSAE(get_sae_config(model, act_name))

    # Capture the model activation that feeds the SAE.
    _, model_cache = model.run_with_cache(prompt, names_filter=act_name)
    acts = model_cache[act_name]

    sae_out, sae_cache = sae.run_with_cache(acts)
    assert sae_out.shape == acts.shape

    # Every intermediate SAE activation should have been cached.
    for hook_name in (
        "hook_sae_input",
        "hook_sae_acts_pre",
        "hook_sae_acts_post",
        "hook_sae_recons",
        "hook_sae_output",
    ):
        assert hook_name in sae_cache
| 84 | + |
| 85 | + |
@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_run_with_hooks(model, act_name):
    """Verifies that run_with_hooks works with SAE activations"""
    counter = Counter()
    sae = HookedSAE(get_sae_config(model, act_name))

    # Capture the model activation that feeds the SAE.
    _, model_cache = model.run_with_cache(prompt, names_filter=act_name)
    acts = model_cache[act_name]

    sae_hooks = [
        "hook_sae_input",
        "hook_sae_acts_pre",
        "hook_sae_acts_post",
        "hook_sae_recons",
        "hook_sae_output",
    ]
    # Register the same counting hook on every SAE hook point.
    hooks = [(hook_name, counter.inc) for hook_name in sae_hooks]
    sae_out = sae.run_with_hooks(acts, fwd_hooks=hooks)
    assert sae_out.shape == acts.shape

    # Each registered hook should have fired exactly once.
    assert counter.count == len(sae_hooks)
| 118 | + |
| 119 | + |
@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_error_term(model, act_name):
    """Verifies that if we use error_terms, HookedSAE returns an output that is equal to the input activations."""
    cfg = get_sae_config(model, act_name)
    cfg.use_error_term = True
    sae = HookedSAE(cfg)

    # Capture the model activation that feeds the SAE.
    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    acts = cache[act_name]

    out = sae(acts)
    assert out.shape == acts.shape
    # With the error term added back, the SAE acts as an identity map.
    assert torch.allclose(out, acts, atol=1e-6)
| 141 | + |
| 142 | + |
@pytest.mark.parametrize(
    "act_name",
    [
        "blocks.0.attn.hook_z",
        "blocks.0.hook_mlp_out",
        "blocks.0.mlp.hook_post",
        "blocks.0.hook_resid_pre",
    ],
)
def test_feature_grads_with_error_term(model, act_name):
    """Verifies that pytorch backward computes the correct feature gradients when using error_terms. Motivated by the need to compute feature gradients for attribution patching."""

    # Load SAE with the error term enabled so the forward pass reproduces its input.
    sae_cfg = get_sae_config(model, act_name)
    sae_cfg.use_error_term = True
    hooked_sae = HookedSAE(sae_cfg)

    # Get input activations from the transformer at the hooked site.
    _, cache = model.run_with_cache(prompt, names_filter=act_name)
    x = cache[act_name]

    # Cache gradients with respect to feature acts via backward hooks.
    hooked_sae.reset_hooks()
    grad_cache = {}

    def backward_cache_hook(act, hook):
        # Store a detached copy of the gradient flowing through this hook point.
        grad_cache[hook.name] = act.detach()

    hooked_sae.add_hook("hook_sae_acts_post", backward_cache_hook, "bwd")
    hooked_sae.add_hook("hook_sae_output", backward_cache_hook, "bwd")

    sae_output = hooked_sae(x)
    # Sanity check: with use_error_term the output equals the input exactly.
    assert torch.allclose(sae_output, x, atol=1e-6)
    value = sae_output.sum()
    value.backward()
    hooked_sae.reset_hooks()

    # Compute gradient analytically: the feature gradient is the output
    # gradient mapped back through the decoder, i.e. grad_output @ W_dec.T.
    if act_name.endswith("hook_z"):
        # hook_z gradients are split per-head; flatten (n_heads, d_head) so the
        # shape matches W_dec's input dimension before the matmul.
        reshaped_output_grad = einops.rearrange(
            grad_cache["hook_sae_output"], "... n_heads d_head -> ... (n_heads d_head)"
        )
        analytic_grad = reshaped_output_grad @ hooked_sae.W_dec.T
    else:
        analytic_grad = grad_cache["hook_sae_output"] @ hooked_sae.W_dec.T

    # Compare analytic gradient with pytorch computed gradient
    assert torch.allclose(grad_cache["hook_sae_acts_post"], analytic_grad, atol=1e-6)
0 commit comments