fix(analyze): misc things stopping analysis working #92

Merged: 3 commits, Feb 18, 2025

Changes from 2 commits
pyproject.toml: 4 changes (2 additions, 2 deletions)

@@ -24,12 +24,12 @@ dependencies = [
     "python-dotenv>=1.0.1",
     "jaxtyping>=0.2.34",
     "safetensors>=0.4.5",
-    "pydantic>=2.9.2",
+    "pydantic>=2.10.6",
     "argparse>=1.4.0",
     "pyyaml>=6.0.2",
     "tomlkit>=0.13.2",
     "torchvision>=0.20.1",
-    "pydantic-settings>=2.7.1",
+    "pydantic-settings>=2.7.1",
 ]
 requires-python = "==3.12.*"
 readme = "README.md"

server/app.py: 12 changes (11 additions, 1 deletion)

@@ -88,7 +88,7 @@ async def oom_error_handler(request, exc):

 @app.get("/dictionaries")
 def list_dictionaries():
-    return client.list_saes(sae_series=sae_series)
+    return client.list_saes(sae_series=sae_series, has_analyses=True)


 @app.get("/images/{dataset_name}/{context_idx}/{image_idx}")

@@ -154,6 +154,16 @@ def get_feature(name: str, feature_index: str | int):
                 del data[image_key]
             data["images"] = image_urls

+            token_origins = token_origins[: len(feature_acts)]
+            feature_acts = feature_acts[: len(token_origins)]
+
+            if "text" in data:
+                text_ranges = [
+                    origin["range"] for origin in token_origins if origin is not None and origin["key"] == "text"
+                ]
+                max_text_origin = max(text_ranges, key=lambda x: x[1])
+                data["text"] = data["text"][: max_text_origin[1]]
+
             samples.append(
                 {
                     **data,

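For context, a minimal sketch of the truncation added above, using hypothetical toy data (the real `token_origins`, `feature_acts`, and `data` come from the stored analysis records): origins and activations are first clipped to a common length, and the stored text is then cut at the furthest `range` end among text-typed origins, so trailing text that no analyzed token covers is not sent to the UI.

```python
# Hypothetical toy data, only to illustrate the truncation logic above.
token_origins = [
    {"key": "image", "range": [0, 0]},
    {"key": "text", "range": [0, 5]},
    {"key": "text", "range": [5, 11]},
    None,  # e.g. a padding token with no origin
]
feature_acts = [0.0, 1.2, 0.7]  # one entry shorter than token_origins

token_origins = token_origins[: len(feature_acts)]  # keeps the first 3 origins
feature_acts = feature_acts[: len(token_origins)]   # unchanged here

data = {"text": "Hello world! Trailing text that no token covers."}
text_ranges = [o["range"] for o in token_origins if o is not None and o["key"] == "text"]
max_text_origin = max(text_ranges, key=lambda x: x[1])  # [5, 11]
print(data["text"][: max_text_origin[1]])  # "Hello world"
```
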
src/lm_saes/activation/processors/cached_activation.py: 16 changes (14 additions, 2 deletions)

@@ -232,12 +232,24 @@ def process(self, data: None = None, **kwargs) -> Iterable[dict[str, Any]]:

         stream = self._process_chunks(hook_chunks, len(hook_chunks[self.hook_points[0]]))
         for chunk in stream:
-            activations = move_dict_of_tensor_to_device(
+            activations: dict[str, Any] = move_dict_of_tensor_to_device(
                 chunk,
                 device=self.device,
             )
             if self.dtype is not None:
                 activations = {k: v.to(self.dtype) for k, v in activations.items()}
+
+            # De-batch activations if tokens have 3 or more dimensions
+            while activations["tokens"].ndim >= 3:
+
+                def flatten(x: torch.Tensor | list[list[Any]]) -> torch.Tensor | list[Any]:
+                    if isinstance(x, torch.Tensor):
+                        return x.flatten(start_dim=0, end_dim=1)
+                    else:
+                        return [a for b in x for a in b]
+
+                activations = {k: flatten(v) for k, v in activations.items()}
+
             yield activations  # Use pin_memory to load data on cpu, then transfer them to cuda in the main process, as advised in https://discuss.pytorch.org/t/dataloader-multiprocessing-with-dataset-returning-a-cuda-tensor/151022/2.
             # I wrote this utils function as I noticed it is used multiple times in this repo. Do we need to apply it elsewhere?

@@ -259,4 +271,4 @@ def __getitem__(self, chunk_idx):
         return self.activation_loader.load_chunk_for_hooks(
             chunk_idx,
             self.hook_chunks,
-        )
+        )

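A rough, self-contained illustration of the de-batching step above (toy shapes and hypothetical keys, not the real cached-activation layout): each pass of the loop merges the two leading dimensions of every tensor and un-nests lists one level, until `tokens` is at most 2-D.

```python
import torch
from typing import Any


def flatten(x: torch.Tensor | list[list[Any]]) -> torch.Tensor | list[Any]:
    # Merge the two leading dimensions of a tensor, or un-nest a list one level.
    if isinstance(x, torch.Tensor):
        return x.flatten(start_dim=0, end_dim=1)
    return [a for b in x for a in b]


# Hypothetical chunk: 2 files x 4 sequences of 8 tokens, with a d_model of 16.
activations: dict[str, Any] = {
    "tokens": torch.zeros(2, 4, 8, dtype=torch.long),
    "blocks.0.hook_resid_post": torch.zeros(2, 4, 8, 16),
    "meta": [[{"context_idx": i} for i in range(4)] for _ in range(2)],
}

while activations["tokens"].ndim >= 3:
    activations = {k: flatten(v) for k, v in activations.items()}

print(activations["tokens"].shape)                    # torch.Size([8, 8])
print(activations["blocks.0.hook_resid_post"].shape)  # torch.Size([8, 8, 16])
print(len(activations["meta"]))                       # 8
```
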
src/lm_saes/analysis/feature_analyzer.py: 2 changes (1 addition, 1 deletion)

@@ -160,7 +160,7 @@ def analyze_chunk(
             meta = {k: [m[k] for m in batch["meta"]] for k in batch["meta"][0].keys()}

             # Get feature activations from SAE
-            feature_acts = sae.encode(batch[sae.cfg.hook_point_in])
+            feature_acts = sae.encode(batch[sae.cfg.hook_point_in], tokens=batch["tokens"])
             # Update activation statistics
             act_times += feature_acts.gt(0.0).sum(dim=[0, 1])
             max_feature_acts = torch.max(max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values)

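The extra `tokens=batch["tokens"]` keyword matters for the MixCoder case below: `MixCoder.encode` asserts that `tokens` is present in `kwargs` and uses it to build per-modality token masks, so the analyzer has to forward it. A minimal sketch of that contract, with a hypothetical stand-in class rather than the real implementation:

```python
import torch


class MixCoderLikeSAE:
    """Hypothetical stand-in showing why `tokens` must be forwarded to encode()."""

    def encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        assert "tokens" in kwargs, "MixCoder-style encoders need token ids to build modality masks"
        tokens = kwargs["tokens"]
        mask = (tokens > 0).unsqueeze(-1).to(x.dtype)  # toy modality mask
        return torch.relu(x) * mask  # placeholder for the real per-modality encoding


sae = MixCoderLikeSAE()
batch = {"hook_point": torch.randn(2, 8, 4), "tokens": torch.ones(2, 8, dtype=torch.long)}
feature_acts = sae.encode(batch["hook_point"], tokens=batch["tokens"])
print(feature_acts.shape)  # torch.Size([2, 8, 4])
```
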
src/lm_saes/database.py: 13 changes (11 additions, 2 deletions)

@@ -163,8 +163,17 @@ def list_analyses(self, sae_name: str, sae_series: str | None = None) -> list[str]:
             )
         ]

-    def list_saes(self, sae_series: str | None = None) -> list[str]:
-        return [d["name"] for d in self.sae_collection.find({"series": sae_series} if sae_series is not None else {})]
+    def list_saes(self, sae_series: str | None = None, has_analyses: bool = False) -> list[str]:
+        sae_names = [
+            d["name"] for d in self.sae_collection.find({"series": sae_series} if sae_series is not None else {})
+        ]
+        if has_analyses:
+            sae_names = [
+                d["sae_name"]
+                for d in self.analysis_collection.find({"sae_series": sae_series} if sae_series is not None else {})
+            ]
+        sae_names = list(set(sae_names))
+        return sae_names

     def get_dataset(self, name: str) -> Optional[DatasetRecord]:
         dataset = self.dataset_collection.find_one({"name": name})

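A small self-contained sketch of the new selection logic, with plain-Python stand-ins for the two Mongo collections (all record fields here are hypothetical examples): with `has_analyses=True`, names come from the analysis records rather than the SAE records, and `set()` collapses duplicates from SAEs that have several analyses.

```python
# Hypothetical in-memory stand-ins for sae_collection and analysis_collection.
sae_records = [
    {"name": "sae-layer3", "series": "demo"},
    {"name": "sae-layer6", "series": "demo"},
]
analysis_records = [
    {"sae_name": "sae-layer3", "sae_series": "demo", "name": "default"},
    {"sae_name": "sae-layer3", "sae_series": "demo", "name": "top_activations"},
]


def list_saes(sae_series: str | None = None, has_analyses: bool = False) -> list[str]:
    sae_names = [d["name"] for d in sae_records if sae_series is None or d["series"] == sae_series]
    if has_analyses:
        # Only SAEs with at least one analysis record are listed.
        sae_names = [
            d["sae_name"] for d in analysis_records if sae_series is None or d["sae_series"] == sae_series
        ]
    return list(set(sae_names))


print(sorted(list_saes("demo")))                     # ['sae-layer3', 'sae-layer6']
print(sorted(list_saes("demo", has_analyses=True)))  # ['sae-layer3']
```
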
src/lm_saes/mixcoder.py: 29 changes (16 additions, 13 deletions)

@@ -96,10 +96,11 @@ def _get_full_state_dict(self):

     @torch.no_grad()
     def _load_full_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
-        super()._load_full_state_dict(state_dict)
         modality_indices_keys = [k for k in state_dict.keys() if k.startswith("modality_indices.")]
         assert len(modality_indices_keys) == len(self.cfg.modalities) - 1  # shared modality is not included
         self.modality_indices = {key.split(".", 1)[1]: state_dict[key] for key in modality_indices_keys}
+        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("modality_indices.")}
+        super()._load_full_state_dict(state_dict)

     @torch.no_grad()
     def transform_to_unit_decoder_norm(self):

@@ -130,7 +131,7 @@ def standardize_parameters_of_dataset_norm(self, dataset_average_activation_norm
         self.cfg.norm_activation = "inference"

     @torch.no_grad()
-    def init_encoder_with_decoder_transpose(self, factor: float = 1.):
+    def init_encoder_with_decoder_transpose(self, factor: float = 1.0):
         for modality in self.cfg.modalities.keys():
             self._init_encoder_with_decoder_transpose(self.encoder[modality], self.decoder[modality], factor)

@@ -256,16 +257,16 @@ def encode(
         """
         assert "tokens" in kwargs
         tokens = kwargs["tokens"]
-        feature_acts = torch.zeros(x.shape[0], self.cfg.d_sae, device=x.device, dtype=x.dtype)
-        hidden_pre = torch.zeros(x.shape[0], self.cfg.d_sae, device=x.device, dtype=x.dtype)
+        feature_acts = torch.zeros(*x.shape[:-1], self.cfg.d_sae, device=x.device, dtype=x.dtype)
+        hidden_pre = torch.zeros(*x.shape[:-1], self.cfg.d_sae, device=x.device, dtype=x.dtype)
         input_norm_factor = self.compute_norm_factor(x, hook_point=self.cfg.hook_point_in)
         x = x * input_norm_factor
         for modality, (start, end) in self.modality_index.items():
             x_temp = x
             if modality == "shared":
                 # shared modality is not encoded directly but summed up during other modalities' encoding
                 continue
-            activation_mask = self.get_modality_token_mask(tokens, modality).unsqueeze(1)
+            activation_mask = self.get_modality_token_mask(tokens, modality).unsqueeze(-1)
             if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
                 modality_bias = (
                     self.decoder[modality].bias.to_local()  # TODO: check if this is correct  # type: ignore

@@ -296,20 +297,22 @@ def encode(
                 true_feature_acts_shared = hidden_pre_shared

             true_feature_acts_concat = (
-                torch.cat([true_feature_acts_modality, true_feature_acts_shared], dim=1) * activation_mask
+                torch.cat([true_feature_acts_modality, true_feature_acts_shared], dim=-1) * activation_mask
             )
             activation_mask_concat = self.activation_function(true_feature_acts_concat)
             feature_acts_concat = true_feature_acts_concat * activation_mask_concat
-            feature_acts_modality = feature_acts_concat[:, : self.cfg.modalities[modality]]
-            feature_acts_shared = feature_acts_concat[:, self.cfg.modalities[modality] :]
-            assert feature_acts_shared.shape[1] == self.cfg.modalities["shared"]
+            feature_acts_modality = feature_acts_concat[..., : self.cfg.modalities[modality]]
+            feature_acts_shared = feature_acts_concat[..., self.cfg.modalities[modality] :]
+            assert (
+                feature_acts_shared.shape[-1] == self.cfg.modalities["shared"]
+            ), f"{feature_acts_shared.shape} does not match {self.cfg.modalities['shared']}. {feature_acts_concat.shape[-1]} != {self.cfg.modalities['shared']}."

-            feature_acts[:, start:end] += feature_acts_modality
-            hidden_pre[:, start:end] += hidden_pre_modality
+            feature_acts[..., start:end] += feature_acts_modality
+            hidden_pre[..., start:end] += hidden_pre_modality

         shared_start, shared_end = self.modality_index["shared"]
-        feature_acts[:, shared_start:shared_end] += feature_acts_shared
-        hidden_pre[:, shared_start:shared_end] += hidden_pre_shared
+        feature_acts[..., shared_start:shared_end] += feature_acts_shared
+        hidden_pre[..., shared_start:shared_end] += hidden_pre_shared

         hidden_pre = self.hook_hidden_pre(hidden_pre)
         feature_acts = self.hook_feature_acts(feature_acts)

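A quick illustration (toy tensors and a made-up helper, not the real MixCoder) of why the indexing changes above matter: slicing with `[..., start:end]`, concatenating with `dim=-1`, and unsqueezing the mask at `-1` all address the feature dimension by position from the end, so the same code handles both flat `(batch, d_model)` inputs and the `(batch, seq_len, d_model)` activations that the analysis pipeline presumably feeds in.

```python
import torch

d_model, d_modality, d_shared = 4, 3, 2


def toy_encode(x: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Made-up example of the last-dimension indexing pattern used in MixCoder.encode."""
    # The feature buffer keeps all leading dims of x; only the last dim is the feature dim.
    feature_acts = torch.zeros(*x.shape[:-1], d_modality + d_shared)
    activation_mask = (tokens > 0).unsqueeze(-1)  # broadcasts over the feature dim
    concat = torch.cat([x[..., :d_modality], x[..., :d_shared]], dim=-1) * activation_mask
    feature_acts[..., :d_modality] += concat[..., :d_modality]
    feature_acts[..., d_modality:] += concat[..., d_modality:]
    return feature_acts


flat = toy_encode(torch.randn(8, d_model), torch.ones(8, dtype=torch.long))
seq = toy_encode(torch.randn(8, 16, d_model), torch.ones(8, 16, dtype=torch.long))
print(flat.shape, seq.shape)  # torch.Size([8, 5]) torch.Size([8, 16, 5])
```
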
src/lm_saes/runner.py: 11 changes (10 additions, 1 deletion)

@@ -18,16 +18,20 @@
     ActivationWriterConfig,
     BaseSAEConfig,
     BufferShuffleConfig,
+    CrossCoderConfig,
     DatasetConfig,
     FeatureAnalyzerConfig,
     InitializerConfig,
     LanguageModelConfig,
+    MixCoderConfig,
     MongoDBConfig,
     TrainerConfig,
     WandbConfig,
 )
+from lm_saes.crosscoder import CrossCoder
 from lm_saes.database import MongoClient
 from lm_saes.initializer import Initializer
+from lm_saes.mixcoder import MixCoder
 from lm_saes.resource_loaders import load_dataset, load_model
 from lm_saes.sae import SparseAutoEncoder
 from lm_saes.trainer import Trainer

@@ -406,7 +410,12 @@ def analyze_sae(settings: AnalyzeSAESettings) -> None:
     mongo_client = MongoClient(settings.mongo)
     activation_factory = ActivationFactory(settings.activation_factory)

-    sae = SparseAutoEncoder.from_config(settings.sae)
+    if isinstance(settings.sae, MixCoderConfig):
+        sae = MixCoder.from_config(settings.sae)
+    elif isinstance(settings.sae, CrossCoderConfig):
+        sae = CrossCoder.from_config(settings.sae)
+    else:
+        sae = SparseAutoEncoder.from_config(settings.sae)

     analyzer = FeatureAnalyzer(settings.analyzer)

ui/src/components/feature/sample.tsx: 2 changes (1 addition, 1 deletion)

@@ -131,7 +131,7 @@ export const FeatureActivationSample = ({ sample, sampleName, maxFeatureAct }: F
                   <span
                     className={cn("relative cursor-help", getAccentClassname(maxSegmentAct, maxFeatureAct, "bg"))}
                   >
-                    {segmentText}
+                    {segmentText.replaceAll("\n", "↵").replaceAll("\t", "→")}
                   </span>
                 </HoverCardTrigger>
                 <HoverCardContent>

ui/tsconfig.json: 4 changes (2 additions, 2 deletions)

@@ -1,8 +1,8 @@
 {
   "compilerOptions": {
-    "target": "ES2020",
+    "target": "ES2021",
     "useDefineForClassFields": true,
-    "lib": ["ES2020", "DOM", "DOM.Iterable"],
+    "lib": ["ES2021", "DOM", "DOM.Iterable"],
     "module": "ESNext",
     "skipLibCheck": true,