
Commit 83acd31

Merge pull request #89 from OpenMOSS/implement_jan_jumprelu
Anthropic Jan update JumpReLU; Triton spmm decoding; misc changes
2 parents bc57195 + 631dad7 commit 83acd31

13 files changed: +931, -97 lines


pyproject.toml

Lines changed: 4 additions & 1 deletion

@@ -29,7 +29,7 @@ dependencies = [
     "pyyaml>=6.0.2",
     "tomlkit>=0.13.2",
     "torchvision>=0.20.1",
-    "pydantic-settings>=2.7.1",
+    "pydantic-settings>=2.7.1",
 ]
 requires-python = "==3.12.*"
 readme = "README.md"
@@ -84,6 +84,9 @@ dev = [
     "pyfakefs>=5.7.3",
     "mongomock>=4.3.0",
 ]
+triton = [
+    "triton>=3.1.0",
+]

 [tool.ruff]
 exclude = [".bzr", ".direnv", ".eggs", ".git", ".git-rewrite", ".hg", ".ipynb_checkpoints", ".mypy_cache", ".nox", ".pants.d", ".pyenv", ".pytest_cache", ".pytype", ".ruff_cache", ".svn", ".tox", ".venv", ".vscode", "__pypackages__", "_build", "buck-out", "build", "dist", "node_modules", "site-packages", "venv", "TransformerLens", "ui"]
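The new group makes Triton an opt-in dependency alongside the existing dev group, so triton>=3.1.0 is only installed when that group is explicitly requested; whether it is exposed as a classic extra or a PEP 735 dependency group is not visible from this hunk. Code that honors the new use_triton_kernel switch therefore has to tolerate a missing triton install. A minimal sketch of that guard pattern (the flag and function names are illustrative, not taken from this commit):

try:
    import triton  # noqa: F401  # provided by the optional "triton" group
    TRITON_AVAILABLE = True
except ImportError:
    TRITON_AVAILABLE = False


def pick_decoding_path(use_triton_kernel: bool) -> str:
    # Fall back to the dense torch matmul when the optional kernel is unavailable.
    return "triton-spmm" if use_triton_kernel and TRITON_AVAILABLE else "dense-torch"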

src/lm_saes/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -29,12 +29,14 @@
     generate_activations,
     train_sae,
 )
+from .sae import SparseAutoEncoder

 __all__ = [
     "ActivationFactory",
     "ActivationWriter",
     "CrossCoderConfig",
     "CrossCoder",
+    "SparseAutoEncoder",
     "LanguageModelConfig",
     "DatasetConfig",
     "ActivationFactoryActivationsSource",

src/lm_saes/activation/factory.py

Lines changed: 1 addition & 4 deletions

@@ -22,7 +22,6 @@
     ActivationFactoryDatasetSource,
     ActivationFactoryTarget,
 )
-from lm_saes.utils.concurrent import BackgroundGenerator


 class ActivationFactory:
@@ -124,6 +123,7 @@ def _build_pre_aggregation_activations_source_processors(
             cache_dir=activations_source.path,
             hook_points=cfg.hook_points,
             device=activations_source.device,
+            dtype=activations_source.dtype,
             num_workers=activations_source.num_workers,
             prefetch_factor=activations_source.prefetch,
         )
@@ -149,9 +149,6 @@ def process_activations(**kwargs: Any):
         for processor in processors:
             stream = processor.process(stream, ignore_token_ids=cfg.ignore_token_ids, model=model)

-        if activations_source.prefetch is not None:
-            stream = BackgroundGenerator(stream, max_prefetch=activations_source.prefetch)
-
         return stream

     return process_activations

src/lm_saes/activation/processors/cached_activation.py

Lines changed: 8 additions & 3 deletions

@@ -2,7 +2,7 @@
 import re
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Iterable, Iterator, Sequence
+from typing import Any, Iterable, Iterator, Optional, Sequence

 import torch
 from safetensors.torch import load_file
@@ -84,12 +84,14 @@ def __init__(
         cache_dir: str | Path,
         hook_points: list[str],
         device: str = "cpu",
+        dtype: Optional[torch.dtype] = None,
         num_workers: int = 0,
         prefetch_factor: int | None = None,
     ):
         self.cache_dir = Path(cache_dir)
         self.hook_points = hook_points
         self.device = device
+        self.dtype = dtype
         self.num_workers = num_workers
         self.prefetch_factor = prefetch_factor

@@ -230,10 +232,13 @@ def process(self, data: None = None, **kwargs) -> Iterable[dict[str, Any]]:

         stream = self._process_chunks(hook_chunks, len(hook_chunks[self.hook_points[0]]))
         for chunk in stream:
-            yield move_dict_of_tensor_to_device(
+            activations = move_dict_of_tensor_to_device(
                 chunk,
                 device=self.device,
-            )  # Use pin_memory to load data on cpu, then transfer them to cuda in the main process, as advised in https://discuss.pytorch.org/t/dataloader-multiprocessing-with-dataset-returning-a-cuda-tensor/151022/2.
+            )
+            if self.dtype is not None:
+                activations = {k: v.to(self.dtype) for k, v in activations.items()}
+            yield activations  # Use pin_memory to load data on cpu, then transfer them to cuda in the main process, as advised in https://discuss.pytorch.org/t/dataloader-multiprocessing-with-dataset-returning-a-cuda-tensor/151022/2.
         # I wrote this utils function as I notice it is used multiple times in this repo. Do we need to apply it elsewhere?
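The new dtype argument applies a cast after each chunk has been moved off the loader workers, e.g. to upcast activations that were cached on disk in bfloat16 back to float32 for training. A self-contained sketch of that per-chunk flow (the function below is a stand-in for the loader internals, not the repo's move_dict_of_tensor_to_device helper, and the hook-point name is illustrative):

from typing import Optional

import torch


def load_chunk(
    chunk: dict[str, torch.Tensor],
    device: str = "cpu",
    dtype: Optional[torch.dtype] = None,
) -> dict[str, torch.Tensor]:
    # Workers yield pinned CPU tensors; the main process moves them to the target device...
    activations = {k: v.to(device) for k, v in chunk.items()}
    # ...and then optionally casts every tensor in the chunk, mirroring the loader change above.
    if dtype is not None:
        activations = {k: v.to(dtype) for k, v in activations.items()}
    return activations


chunk = {"blocks.3.hook_resid_post": torch.randn(8, 16, dtype=torch.bfloat16)}
out = load_chunk(chunk, device="cpu", dtype=torch.float32)
assert out["blocks.3.hook_resid_post"].dtype == torch.float32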

src/lm_saes/config.py

Lines changed: 29 additions & 2 deletions

@@ -60,9 +60,15 @@ class BaseSAEConfig(BaseModelConfig):
     apply_decoder_bias_to_pre_encoder: bool = False
     norm_activation: str = "dataset-wise"
     sparsity_include_decoder_norm: bool = True
+    force_unit_decoder_norm: bool = False
     top_k: int = 50
     sae_pretrained_name_or_path: Optional[str] = None
     strict_loading: bool = True
+    use_triton_kernel: bool = False
+    sparsity_threshold_for_triton_spmm_kernel: float = 0.99
+
+    # anthropic jumprelu
+    jumprelu_threshold_window: float = 2.0

     @property
     def d_sae(self) -> int:
@@ -113,18 +119,25 @@ def d_sae(self) -> int:

 class InitializerConfig(BaseConfig):
     bias_init_method: str = "all_zero"
+    const_times_for_init_b_e: int = 10000
     init_decoder_norm: float | None = None
+    decoder_uniform_bound: float = 1.
     init_encoder_norm: float | None = None
+    encoder_uniform_bound: float = 1.
     init_encoder_with_decoder_transpose: bool = True
-    init_search: bool = True
+    init_encoder_with_decoder_transpose_factor: float = 1.
+    init_log_jumprelu_threshold_value: float | None = None
+    init_search: bool = False
     state: Literal["training", "inference"] = "training"
     l1_coefficient: float | None = 0.00008


 class TrainerConfig(BaseConfig):
-    lp: int = 1
     l1_coefficient: float | None = 0.00008
     l1_coefficient_warmup_steps: int | float = 0.1
+    sparsity_loss_type: Literal["power", "tanh", None] = None
+    tanh_stretch_coefficient: float = 4.0
+    p: int = 1
     initial_k: int | float | None = None
     k_warmup_steps: int | float = 0.1
     use_batch_norm_mse: bool = True
@@ -194,11 +207,25 @@ class ActivationFactoryDatasetSource(ActivationFactorySource):


 class ActivationFactoryActivationsSource(ActivationFactorySource):
+    model_config = ConfigDict(arbitrary_types_allowed=True)  # allow parsing torch.dtype
+
     type: str = "activations"
     path: str
     """ The path to the cached activations. """
     device: str = "cpu"
     """ The device to load the activations on. """
+    dtype: Optional[Annotated[
+        torch.dtype,
+        BeforeValidator(lambda v: convert_str_to_torch_dtype(v) if isinstance(v, str) else v),
+        PlainSerializer(convert_torch_dtype_to_str),
+        WithJsonSchema(
+            {
+                "type": "string",
+            },
+            mode="serialization",
+        ),
+    ]] = None
+    """ We might want to convert presaved bf16 activations to fp32"""
     num_workers: int = 4
     """ The number of workers to use for loading the activations. """
     prefetch: Optional[int] = 8
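The dtype field leans on pydantic v2 hooks so a config file can carry a string ("bfloat16", "float32", ...) while the loaded model holds a real torch.dtype. A self-contained sketch of the same pattern, with simple stand-ins for the repo's convert_str_to_torch_dtype / convert_torch_dtype_to_str helpers and the WithJsonSchema annotation omitted:

from typing import Annotated, Optional

import torch
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer


def _str_to_dtype(v: object) -> object:
    # Stand-in for convert_str_to_torch_dtype: "float32" / "torch.float32" -> torch.float32.
    return getattr(torch, v.removeprefix("torch.")) if isinstance(v, str) else v


def _dtype_to_str(v: torch.dtype) -> str:
    # Stand-in for convert_torch_dtype_to_str: torch.float32 -> "torch.float32".
    return str(v)


TorchDtype = Annotated[torch.dtype, BeforeValidator(_str_to_dtype), PlainSerializer(_dtype_to_str)]


class ActivationsSourceSketch(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)  # torch.dtype is not a pydantic-native type
    dtype: Optional[TorchDtype] = None


src = ActivationsSourceSketch(dtype="bfloat16")
assert src.dtype is torch.bfloat16
print(src.model_dump())  # {'dtype': 'torch.bfloat16'}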

src/lm_saes/crosscoder.py

Lines changed: 83 additions & 20 deletions

@@ -1,4 +1,4 @@
-from typing import Literal, Union, overload
+from typing import Callable, Literal, Union, cast, overload

 import torch
 from jaxtyping import Float
@@ -32,6 +32,56 @@ def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False, local_o
         )
         return decoder_norm

+    def activation_function_factory(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        assert self.cfg.act_fn.lower() in [
+            "relu",
+            "topk",
+            "jumprelu",
+            "batchtopk",
+        ], f"Not implemented activation function {self.cfg.act_fn}"
+        if self.cfg.act_fn.lower() == "jumprelu":
+
+            class STEFunction(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input: torch.Tensor, log_jumprelu_threshold: torch.Tensor):
+                    jumprelu_threshold = log_jumprelu_threshold.exp()
+                    jumprelu_threshold = all_reduce_tensor(jumprelu_threshold, aggregate="sum")
+                    ctx.save_for_backward(input, jumprelu_threshold)
+                    return input.gt(jumprelu_threshold).to(input.dtype)
+
+                @staticmethod
+                def backward(ctx, *grad_outputs: torch.Tensor):
+                    assert len(grad_outputs) == 1
+                    grad_output = grad_outputs[0]
+
+                    input, jumprelu_threshold = ctx.saved_tensors
+                    grad_input = torch.zeros_like(input)
+                    grad_log_jumprelu_threshold_unscaled = torch.where(
+                        (input - jumprelu_threshold).abs() < self.cfg.jumprelu_threshold_window * 0.5,
+                        -jumprelu_threshold / self.cfg.jumprelu_threshold_window,
+                        0.0,
+                    )
+                    grad_log_jumprelu_threshold = (
+                        grad_log_jumprelu_threshold_unscaled
+                        / torch.where(
+                            ((input - jumprelu_threshold).abs() < self.cfg.jumprelu_threshold_window * 0.5)
+                            * (input != 0.0),
+                            input,
+                            1.0,
+                        )
+                        * grad_output
+                    )
+                    grad_log_jumprelu_threshold = grad_log_jumprelu_threshold.sum(
+                        dim=tuple(range(grad_log_jumprelu_threshold.ndim - 1))
+                    )
+
+                    return grad_input, grad_log_jumprelu_threshold
+
+            return lambda x: cast(torch.Tensor, STEFunction.apply(x, self.log_jumprelu_threshold))
+
+        else:
+            return super().activation_function_factory()
+
     @overload
     def encode(
         self,
@@ -109,14 +159,14 @@ def encode(
         hidden_pre = self.hook_hidden_pre(hidden_pre)

         if self.cfg.sparsity_include_decoder_norm:
-            true_feature_acts = hidden_pre * self._decoder_norm(
+            sparsity_scores = hidden_pre * self._decoder_norm(
                 decoder=self.decoder,
                 local_only=True,
             )
         else:
-            true_feature_acts = hidden_pre
+            sparsity_scores = hidden_pre

-        activation_mask = self.activation_function(true_feature_acts)
+        activation_mask = self.activation_function(sparsity_scores)
         feature_acts = hidden_pre * activation_mask

         feature_acts = self.hook_feature_acts(feature_acts)
@@ -131,7 +181,9 @@ def compute_loss(
         batch: dict[str, torch.Tensor],
         *,
         use_batch_norm_mse: bool = False,
-        lp: int = 1,
+        sparsity_loss_type: Literal["power", "tanh", None] = None,
+        tanh_stretch_coefficient: float = 4.0,
+        p: int = 1,
         return_aux_data: Literal[True] = True,
         **kwargs,
     ) -> tuple[
@@ -145,7 +197,9 @@ def compute_loss(
         batch: dict[str, torch.Tensor],
         *,
         use_batch_norm_mse: bool = False,
-        lp: int = 1,
+        sparsity_loss_type: Literal["power", "tanh", None] = None,
+        tanh_stretch_coefficient: float = 4.0,
+        p: int = 1,
         return_aux_data: Literal[False],
         **kwargs,
     ) -> Float[torch.Tensor, " batch"]: ...
@@ -162,7 +216,9 @@ def compute_loss(
         ) = None,
         *,
         use_batch_norm_mse: bool = False,
-        lp: int = 1,
+        sparsity_loss_type: Literal["power", "tanh", None] = None,
+        tanh_stretch_coefficient: float = 4.0,
+        p: int = 1,
         return_aux_data: bool = True,
         **kwargs,
     ) -> Union[
@@ -194,25 +250,31 @@ def compute_loss(
             .sqrt()
         )

-        l_rec = l_rec.mean()
-        l_rec = all_reduce_tensor(l_rec, aggregate="mean")
+        l_rec = l_rec.sum(dim=-1).mean()

         loss = l_rec
         loss_dict = {
             "l_rec": l_rec,
         }

-        # l_l1: (batch,)
-        feature_acts = feature_acts * self._decoder_norm(
-            decoder=self.decoder,
-            local_only=True,
-        )
-
-        if "topk" not in self.cfg.act_fn:
-            l_lp = torch.norm(feature_acts, p=lp, dim=-1)
-            loss_dict["l_lp"] = l_lp
+        if sparsity_loss_type == "power":
+            l_s = torch.norm(feature_acts * self._decoder_norm(decoder=self.decoder), p=p, dim=-1)
+            loss_dict["l_s"] = self.current_l1_coefficient * l_s.mean()
             assert self.current_l1_coefficient is not None
-            loss = loss + self.current_l1_coefficient * l_lp.mean()
+            loss = loss + self.current_l1_coefficient * l_s.mean()
+        elif sparsity_loss_type == "tanh":
+            l_s = torch.tanh(tanh_stretch_coefficient * feature_acts * self._decoder_norm(decoder=self.decoder)).sum(
+                dim=-1
+            )
+            loss_dict["l_s"] = self.current_l1_coefficient * l_s.mean()
+            assert self.current_l1_coefficient is not None
+            loss = loss + self.current_l1_coefficient * l_s.mean()
+        elif sparsity_loss_type is None:
+            pass
+        else:
+            raise ValueError(f"sparsity_loss_type f{sparsity_loss_type} not supported.")
+
+        loss = all_reduce_tensor(loss, aggregate="mean")

         if return_aux_data:
             aux_data = {
@@ -229,7 +291,8 @@ def compute_loss(

     @torch.no_grad()
     def log_statistics(self):
-        return {}
+        assert self.dataset_average_activation_norm is not None
+        return {f"info/{k}": v for k, v in self.dataset_average_activation_norm.items()}

     def initialize_with_same_weight_across_layers(self):
         self.encoder.weight.data = get_tensor_from_specific_rank(self.encoder.weight.data.clone(), src=0)