Skip to content

Commit 718de4e

Browse files
authored
Merge pull request #91 from OpenMOSS/implement_analyze_dispatch
Implement dynamic dispatch for SAE instantiation in analyze_sae (runner)
2 parents 743ff47 + 09ffee8 commit 718de4e

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

src/lm_saes/runner.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
TrainerConfig,
2727
WandbConfig,
2828
)
29+
from lm_saes.crosscoder import CrossCoder
2930
from lm_saes.database import MongoClient
3031
from lm_saes.initializer import Initializer
32+
from lm_saes.mixcoder import MixCoder
3133
from lm_saes.resource_loaders import load_dataset, load_model
3234
from lm_saes.sae import SparseAutoEncoder
3335
from lm_saes.trainer import Trainer
@@ -406,7 +408,15 @@ def analyze_sae(settings: AnalyzeSAESettings) -> None:
406408
mongo_client = MongoClient(settings.mongo)
407409
activation_factory = ActivationFactory(settings.activation_factory)
408410

409-
sae = SparseAutoEncoder.from_config(settings.sae)
411+
if settings.sae.sae_type == "sae":
412+
sae = SparseAutoEncoder.from_config(settings.sae)
413+
elif settings.sae.sae_type == "crosscoder":
414+
sae = CrossCoder.from_config(settings.sae)
415+
elif settings.sae.sae_type == "mixcoder":
416+
sae = MixCoder.from_config(settings.sae)
417+
else:
418+
# TODO: add support for different SAE config types, e.g. MixCoderConfig, CrossCoderConfig, etc.
419+
raise ValueError(f"SAE type {settings.sae.sae_type} not supported.")
410420

411421
analyzer = FeatureAnalyzer(settings.analyzer)
412422

src/lm_saes/trainer.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,16 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
137137
"sparsity/below_1e-5": (feature_sparsity < 1e-5).sum().item(),
138138
"sparsity/below_1e-6": (feature_sparsity < 1e-6).sum().item(),
139139
}
140-
if sae.cfg.sae_type == 'crosscoder':
141-
overall_act_freq_scores = all_reduce_tensor(feature_sparsity, aggregate='max')
142-
wandb_log_dict.update({
143-
"sparsity/overall_above_1e-1": (overall_act_freq_scores > 1e-1).sum().item(),
144-
"sparsity/overall_above_1e-2": (overall_act_freq_scores > 1e-2).sum().item(),
145-
"sparsity/overall_below_1e-5": (overall_act_freq_scores < 1e-5).sum().item(),
146-
"sparsity/overall_below_1e-6": (overall_act_freq_scores < 1e-6).sum().item(),
147-
})
140+
if sae.cfg.sae_type == "crosscoder":
141+
overall_act_freq_scores = all_reduce_tensor(feature_sparsity, aggregate="max")
142+
wandb_log_dict.update(
143+
{
144+
"sparsity/overall_above_1e-1": (overall_act_freq_scores > 1e-1).sum().item(),
145+
"sparsity/overall_above_1e-2": (overall_act_freq_scores > 1e-2).sum().item(),
146+
"sparsity/overall_below_1e-5": (overall_act_freq_scores < 1e-5).sum().item(),
147+
"sparsity/overall_below_1e-6": (overall_act_freq_scores < 1e-6).sum().item(),
148+
}
149+
)
148150

149151
self.wandb_logger.log(wandb_log_dict, step=self.cur_step + 1)
150152
log_info["act_freq_scores"] = torch.zeros_like(log_info["act_freq_scores"])
@@ -161,7 +163,11 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
161163
wandb_log_dict = {
162164
# losses
163165
"losses/mse_loss": l_rec.item(),
164-
**({"losses/sparsity_loss": log_info["l_s"].mean().item()} if log_info.get("l_s", None) is not None else {}),
166+
**(
167+
{"losses/sparsity_loss": log_info["l_s"].mean().item()}
168+
if log_info.get("l_s", None) is not None
169+
else {}
170+
),
165171
"losses/overall_loss": log_info["loss"].item(),
166172
# variance explained
167173
"metrics/explained_variance": explained_variance.mean().item(),
@@ -179,10 +185,16 @@ def _log(self, sae: SparseAutoEncoder, log_info: dict, batch: dict[str, Tensor])
179185
"details/n_training_tokens": self.cur_tokens,
180186
}
181187
wandb_log_dict.update(sae.log_statistics())
182-
if sae.cfg.sae_type == 'crosscoder':
183-
wandb_log_dict.update({
184-
"metrics/overall_l0": all_reduce_tensor(log_info["feature_acts"], aggregate='max').gt(0).float().sum(-1).mean()
185-
})
188+
if sae.cfg.sae_type == "crosscoder":
189+
wandb_log_dict.update(
190+
{
191+
"metrics/overall_l0": all_reduce_tensor(log_info["feature_acts"], aggregate="max")
192+
.gt(0)
193+
.float()
194+
.sum(-1)
195+
.mean()
196+
}
197+
)
186198
elif sae.cfg.sae_type == "mixcoder":
187199
assert isinstance(sae, MixCoder)
188200
for modality, (start, end) in sae.modality_index.items():

0 commit comments

Comments
 (0)