Skip to content

Commit 834f1e3

Browse files
authored
Merge pull request #93 from OpenMOSS/zf_fix
fix backend missing-shard-idx bug which causes incorrect context inde…
2 parents 56a80d8 + 2eba74a commit 834f1e3

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

server/app.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from lm_saes.config import MongoDBConfig, SAEConfig
1616
from lm_saes.database import MongoClient
17-
from lm_saes.resource_loaders import load_dataset, load_model
17+
from lm_saes.resource_loaders import load_dataset_shard, load_model
1818
from lm_saes.sae import SparseAutoEncoder
1919

2020
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -28,7 +28,7 @@
2828

2929
sae_cache: dict[str, SparseAutoEncoder] = {}
3030
lm_cache: dict[str, HookedTransformer] = {}
31-
dataset_cache: dict[str, Dataset] = {}
31+
dataset_cache: dict[tuple[str, int, int], Dataset] = {}
3232

3333

3434
def get_model(name: str) -> HookedTransformer:
@@ -41,12 +41,12 @@ def get_model(name: str) -> HookedTransformer:
4141
return lm_cache[name]
4242

4343

44-
def get_dataset(name: str) -> Dataset:
44+
def get_dataset(name: str, shard_idx: int = 0, n_shards: int = 1) -> Dataset:
4545
cfg = client.get_dataset_cfg(name)
4646
assert cfg is not None, f"Dataset {name} not found"
47-
if name not in dataset_cache:
48-
dataset_cache[name] = load_dataset(cfg)[0]
49-
return dataset_cache[name]
47+
if (name, shard_idx, n_shards) not in dataset_cache:
48+
dataset_cache[name, shard_idx, n_shards] = load_dataset_shard(cfg, shard_idx, n_shards)
49+
return dataset_cache[name, shard_idx, n_shards]
5050

5151

5252
def get_sae(name: str) -> SparseAutoEncoder:
@@ -144,7 +144,11 @@ def get_feature(name: str, feature_index: str | int):
144144
dataset_name = sampling.dataset_name[i]
145145
model_name = sampling.model_name[i]
146146
model = get_model(model_name)
147-
data = get_dataset(dataset_name)[context_idx]
147+
data = get_dataset(
148+
dataset_name,
149+
sampling.shard_idx[i] if sampling.shard_idx is not None else 0,
150+
sampling.n_shards[i] if sampling.n_shards is not None else 1,
151+
)[context_idx]
148152
_, token_origins = model.to_tokens_with_origins(data)
149153

150154
# Replace image_key with image_url

0 commit comments

Comments
 (0)