
Commit 56a80d8

Merge pull request #92 from OpenMOSS/fix-analysis
fix(analyze): misc things stopping analysis working
2 parents: 718de4e + b1e8d1d

File tree: 10 files changed (+72, −51 lines)

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -24,12 +24,12 @@ dependencies = [
     "python-dotenv>=1.0.1",
     "jaxtyping>=0.2.34",
     "safetensors>=0.4.5",
-    "pydantic>=2.9.2",
+    "pydantic>=2.10.6",
     "argparse>=1.4.0",
     "pyyaml>=6.0.2",
     "tomlkit>=0.13.2",
     "torchvision>=0.20.1",
-    "pydantic-settings>=2.7.1",
+    "pydantic-settings>=2.7.1",
 ]
 requires-python = "==3.12.*"
 readme = "README.md"

server/app.py

Lines changed: 11 additions & 1 deletion
@@ -88,7 +88,7 @@ async def oom_error_handler(request, exc):
 
 @app.get("/dictionaries")
 def list_dictionaries():
-    return client.list_saes(sae_series=sae_series)
+    return client.list_saes(sae_series=sae_series, has_analyses=True)
 
 
 @app.get("/images/{dataset_name}/{context_idx}/{image_idx}")
@@ -154,6 +154,16 @@ def get_feature(name: str, feature_index: str | int):
             del data[image_key]
         data["images"] = image_urls
 
+        token_origins = token_origins[: len(feature_acts)]
+        feature_acts = feature_acts[: len(token_origins)]
+
+        if "text" in data:
+            text_ranges = [
+                origin["range"] for origin in token_origins if origin is not None and origin["key"] == "text"
+            ]
+            max_text_origin = max(text_ranges, key=lambda x: x[1])
+            data["text"] = data["text"][: max_text_origin[1]]
+
         samples.append(
             {
                 **data,
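
The added block keeps feature_acts and token_origins index-aligned before a sample is serialized, and trims the raw text to the span actually covered by the surviving tokens. A minimal standalone sketch of the same logic; the function name and data layout are illustrative, assuming token_origins entries look like {"key": "text", "range": (start, end)} or None:

# A minimal sketch of the alignment fix above, not the server's exact code.
def align_sample(feature_acts, token_origins, text=None):
    # Truncate each sequence to the other's length so they stay index-aligned.
    n = min(len(feature_acts), len(token_origins))
    feature_acts, token_origins = feature_acts[:n], token_origins[:n]
    if text is not None:
        # Trim the text to the furthest character covered by a surviving token.
        text_ranges = [o["range"] for o in token_origins if o is not None and o["key"] == "text"]
        if text_ranges:  # max() on an empty sequence would raise ValueError
            text = text[: max(r[1] for r in text_ranges)]
    return feature_acts, token_origins, text

acts, origins, text = align_sample(
    [0.0, 1.2, 0.7],
    [{"key": "text", "range": (0, 5)}, {"key": "text", "range": (5, 11)}],
    "hello world and some trailing padding",
)
assert len(acts) == len(origins) == 2 and text == "hello world"

Unlike the endpoint code, the sketch guards against an empty text_ranges list, where max() would raise a ValueError.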

src/lm_saes/activation/processors/cached_activation.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def process(self, data: None = None, **kwargs) -> Iterable[dict[str, Any]]:
 
         stream = self._process_chunks(hook_chunks, len(hook_chunks[self.hook_points[0]]))
         for chunk in stream:
-            activations = move_dict_of_tensor_to_device(
+            activations: dict[str, Any] = move_dict_of_tensor_to_device(
                 chunk,
                 device=self.device,
             )

src/lm_saes/analysis/feature_analyzer.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ def analyze_chunk(
         meta = {k: [m[k] for m in batch["meta"]] for k in batch["meta"][0].keys()}
 
         # Get feature activations from SAE
-        feature_acts = sae.encode(batch[sae.cfg.hook_point_in])
+        feature_acts = sae.encode(batch[sae.cfg.hook_point_in], tokens=batch["tokens"])
         # Update activation statistics
         act_times += feature_acts.gt(0.0).sum(dim=[0, 1])
         max_feature_acts = torch.max(max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values)
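
The tokens keyword is now passed unconditionally because MixCoder's encode routes each token through a modality-specific encoder (it asserts "tokens" in kwargs, per the mixcoder.py diff below), while token-agnostic SAEs can simply ignore it. A toy sketch of that calling convention, with illustrative classes rather than the library's actual ones:

import torch

class PlainSAE:
    """Illustrative only: a token-agnostic encoder that tolerates the extra kwarg."""
    def encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        return torch.relu(x)  # tokens may be present in kwargs but is ignored

class TokenRoutedSAE:
    """Illustrative only: a MixCoder-like encoder that needs tokens for masks."""
    def encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        assert "tokens" in kwargs, "token-routed encoding requires tokens"
        mask = (kwargs["tokens"] > 0).unsqueeze(-1).to(x.dtype)  # toy modality mask
        return torch.relu(x) * mask

x, tokens = torch.randn(2, 4, 8), torch.tensor([[1, 0, 1, 1], [0, 0, 1, 0]])
for sae in (PlainSAE(), TokenRoutedSAE()):
    assert sae.encode(x, tokens=tokens).shape == x.shape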

src/lm_saes/database.py

Lines changed: 11 additions & 2 deletions
@@ -163,8 +163,17 @@ def list_analyses(self, sae_name: str, sae_series: str | None = None) -> list[str]:
             )
         ]
 
-    def list_saes(self, sae_series: str | None = None) -> list[str]:
-        return [d["name"] for d in self.sae_collection.find({"series": sae_series} if sae_series is not None else {})]
+    def list_saes(self, sae_series: str | None = None, has_analyses: bool = False) -> list[str]:
+        sae_names = [
+            d["name"] for d in self.sae_collection.find({"series": sae_series} if sae_series is not None else {})
+        ]
+        if has_analyses:
+            sae_names = [
+                d["sae_name"]
+                for d in self.analysis_collection.find({"sae_series": sae_series} if sae_series is not None else {})
+            ]
+            sae_names = list(set(sae_names))
+        return sae_names
 
     def get_dataset(self, name: str) -> Optional[DatasetRecord]:
         dataset = self.dataset_collection.find_one({"name": name})
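
A behavioural note on the change: with has_analyses=True the names are drawn from the analysis collection instead of the SAE collection, and deduplicated through set(), so an SAE with several analyses appears once and the result order is unspecified. A small sketch with plain dicts standing in for MongoDB documents:

# Plain dicts stand in for documents in analysis_collection; data is illustrative.
analysis_docs = [
    {"sae_name": "sae-a", "sae_series": "demo"},
    {"sae_name": "sae-a", "sae_series": "demo"},  # one SAE with two analyses
    {"sae_name": "sae-b", "sae_series": "demo"},
]
sae_names = list({d["sae_name"] for d in analysis_docs if d["sae_series"] == "demo"})
assert sorted(sae_names) == ["sae-a", "sae-b"]  # deduplicated, order unspecified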

src/lm_saes/mixcoder.py

Lines changed: 16 additions & 13 deletions
@@ -96,10 +96,11 @@ def _get_full_state_dict(self):
 
     @torch.no_grad()
     def _load_full_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
-        super()._load_full_state_dict(state_dict)
         modality_indices_keys = [k for k in state_dict.keys() if k.startswith("modality_indices.")]
         assert len(modality_indices_keys) == len(self.cfg.modalities) - 1  # shared modality is not included
         self.modality_indices = {key.split(".", 1)[1]: state_dict[key] for key in modality_indices_keys}
+        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("modality_indices.")}
+        super()._load_full_state_dict(state_dict)
 
     @torch.no_grad()
     def transform_to_unit_decoder_norm(self):
@@ -130,7 +131,7 @@ def standardize_parameters_of_dataset_norm(self, dataset_average_activation_norm
         self.cfg.norm_activation = "inference"
 
     @torch.no_grad()
-    def init_encoder_with_decoder_transpose(self, factor: float = 1.):
+    def init_encoder_with_decoder_transpose(self, factor: float = 1.0):
         for modality in self.cfg.modalities.keys():
             self._init_encoder_with_decoder_transpose(self.encoder[modality], self.decoder[modality], factor)
 
@@ -256,16 +257,16 @@ def encode(
         """
         assert "tokens" in kwargs
         tokens = kwargs["tokens"]
-        feature_acts = torch.zeros(x.shape[0], self.cfg.d_sae, device=x.device, dtype=x.dtype)
-        hidden_pre = torch.zeros(x.shape[0], self.cfg.d_sae, device=x.device, dtype=x.dtype)
+        feature_acts = torch.zeros(*x.shape[:-1], self.cfg.d_sae, device=x.device, dtype=x.dtype)
+        hidden_pre = torch.zeros(*x.shape[:-1], self.cfg.d_sae, device=x.device, dtype=x.dtype)
         input_norm_factor = self.compute_norm_factor(x, hook_point=self.cfg.hook_point_in)
         x = x * input_norm_factor
         for modality, (start, end) in self.modality_index.items():
             x_temp = x
             if modality == "shared":
                 # shared modality is not encoded directly but summed up during other modalities' encoding
                 continue
-            activation_mask = self.get_modality_token_mask(tokens, modality).unsqueeze(1)
+            activation_mask = self.get_modality_token_mask(tokens, modality).unsqueeze(-1)
             if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
                 modality_bias = (
                     self.decoder[modality].bias.to_local()  # TODO: check if this is correct  # type: ignore
@@ -296,20 +297,22 @@ def encode(
                 true_feature_acts_shared = hidden_pre_shared
 
             true_feature_acts_concat = (
-                torch.cat([true_feature_acts_modality, true_feature_acts_shared], dim=1) * activation_mask
+                torch.cat([true_feature_acts_modality, true_feature_acts_shared], dim=-1) * activation_mask
             )
             activation_mask_concat = self.activation_function(true_feature_acts_concat)
             feature_acts_concat = true_feature_acts_concat * activation_mask_concat
-            feature_acts_modality = feature_acts_concat[:, : self.cfg.modalities[modality]]
-            feature_acts_shared = feature_acts_concat[:, self.cfg.modalities[modality] :]
-            assert feature_acts_shared.shape[1] == self.cfg.modalities["shared"]
+            feature_acts_modality = feature_acts_concat[..., : self.cfg.modalities[modality]]
+            feature_acts_shared = feature_acts_concat[..., self.cfg.modalities[modality] :]
+            assert (
+                feature_acts_shared.shape[-1] == self.cfg.modalities["shared"]
+            ), f"{feature_acts_shared.shape} does not match {self.cfg.modalities['shared']}. {feature_acts_concat.shape[-1]} != {self.cfg.modalities['shared']}."
 
-            feature_acts[:, start:end] += feature_acts_modality
-            hidden_pre[:, start:end] += hidden_pre_modality
+            feature_acts[..., start:end] += feature_acts_modality
+            hidden_pre[..., start:end] += hidden_pre_modality
 
         shared_start, shared_end = self.modality_index["shared"]
-        feature_acts[:, shared_start:shared_end] += feature_acts_shared
-        hidden_pre[:, shared_start:shared_end] += hidden_pre_shared
+        feature_acts[..., shared_start:shared_end] += feature_acts_shared
+        hidden_pre[..., shared_start:shared_end] += hidden_pre_shared
 
         hidden_pre = self.hook_hidden_pre(hidden_pre)
         feature_acts = self.hook_feature_acts(feature_acts)
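
Most of this diff is one recurring substitution: dim=1 becomes dim=-1, [:, …] becomes [..., …], and unsqueeze(1) becomes unsqueeze(-1). That generalizes encode from strictly 2-D (batch, d_model) inputs to activations with any number of leading dimensions, such as (batch, seq_len, d_model). A quick illustration of why ellipsis indexing is rank-agnostic; the shapes are made up:

import torch

# Ellipsis indexing and unsqueeze(-1) work for any number of leading dims,
# which is why the same encode body now handles (batch, d) and (batch, seq, d).
for shape in [(4, 16), (4, 8, 16)]:
    acts = torch.randn(*shape)
    token_mask = torch.ones(shape[:-1])              # one flag per token, any rank
    masked = acts * token_mask.unsqueeze(-1)         # broadcasts over the feature dim
    head, tail = masked[..., :10], masked[..., 10:]  # feature-dim slices, any rank
    assert head.shape[-1] == 10 and tail.shape[-1] == 6

The _load_full_state_dict reorder at the top of the diff is a separate fix: the modality_indices.* entries are extracted and stripped from the state dict before the parent loader runs, so the parent never sees keys it does not own.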

src/lm_saes/runner.py

Lines changed: 6 additions & 7 deletions
@@ -18,10 +18,12 @@
     ActivationWriterConfig,
     BaseSAEConfig,
     BufferShuffleConfig,
+    CrossCoderConfig,
     DatasetConfig,
     FeatureAnalyzerConfig,
     InitializerConfig,
     LanguageModelConfig,
+    MixCoderConfig,
     MongoDBConfig,
     TrainerConfig,
     WandbConfig,
@@ -408,15 +410,12 @@ def analyze_sae(settings: AnalyzeSAESettings) -> None:
     mongo_client = MongoClient(settings.mongo)
     activation_factory = ActivationFactory(settings.activation_factory)
 
-    if settings.sae.sae_type == "sae":
-        sae = SparseAutoEncoder.from_config(settings.sae)
-    elif settings.sae.sae_type == "crosscoder":
-        sae = CrossCoder.from_config(settings.sae)
-    elif settings.sae.sae_type == "mixcoder":
+    if isinstance(settings.sae, MixCoderConfig):
         sae = MixCoder.from_config(settings.sae)
+    elif isinstance(settings.sae, CrossCoderConfig):
+        sae = CrossCoder.from_config(settings.sae)
     else:
-        # TODO: add support for different SAE config types, e.g. MixCoderConfig, CrossCoderConfig, etc.
-        raise ValueError(f"SAE type {settings.sae.sae_type} not supported.")
+        sae = SparseAutoEncoder.from_config(settings.sae)
 
     analyzer = FeatureAnalyzer(settings.analyzer)

ui/src/components/feature/sample.tsx

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ export const FeatureActivationSample = ({ sample, sampleName, maxFeatureAct }: F
                 <span
                   className={cn("relative cursor-help", getAccentClassname(maxSegmentAct, maxFeatureAct, "bg"))}
                 >
-                  {segmentText}
+                  {segmentText.replaceAll("\n", "↵").replaceAll("\t", "→")}
                 </span>
               </HoverCardTrigger>
               <HoverCardContent>

ui/tsconfig.json

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 {
   "compilerOptions": {
-    "target": "ES2020",
+    "target": "ES2021",
     "useDefineForClassFields": true,
-    "lib": ["ES2020", "DOM", "DOM.Iterable"],
+    "lib": ["ES2021", "DOM", "DOM.Iterable"],
     "module": "ESNext",
     "skipLibCheck": true,
