14
14
15
15
from lm_saes .config import MongoDBConfig , SAEConfig
16
16
from lm_saes .database import MongoClient
17
- from lm_saes .resource_loaders import load_dataset , load_model
17
+ from lm_saes .resource_loaders import load_dataset_shard , load_model
18
18
from lm_saes .sae import SparseAutoEncoder
19
19
20
20
device = "cuda" if torch .cuda .is_available () else "cpu"
28
28
29
29
sae_cache : dict [str , SparseAutoEncoder ] = {}
30
30
lm_cache : dict [str , HookedTransformer ] = {}
31
- dataset_cache : dict [str , Dataset ] = {}
31
+ dataset_cache : dict [tuple [ str , int , int ] , Dataset ] = {}
32
32
33
33
34
34
def get_model (name : str ) -> HookedTransformer :
@@ -41,12 +41,12 @@ def get_model(name: str) -> HookedTransformer:
41
41
return lm_cache [name ]
42
42
43
43
44
def get_dataset(name: str, shard_idx: int = 0, n_shards: int = 1) -> Dataset:
    """Return one shard of the named dataset, loading it on first use.

    The dataset configuration is looked up through ``client`` on every call.
    Loaded shards are memoized in the module-level ``dataset_cache`` under the
    key ``(name, shard_idx, n_shards)``, so different shardings of the same
    dataset are cached independently.

    Args:
        name: Name under which the dataset configuration is registered.
        shard_idx: Shard index forwarded to ``load_dataset_shard``.
        n_shards: Shard count forwarded to ``load_dataset_shard``.

    Returns:
        The cached (or freshly loaded) dataset shard.

    Raises:
        AssertionError: If ``client.get_dataset_cfg`` returns ``None`` for
            ``name``.
    """
    cfg = client.get_dataset_cfg(name)
    # Raise explicitly instead of using `assert`: assert statements are
    # removed under `python -O`, which would let a None config reach the
    # loader. The exception type and message are kept as before.
    if cfg is None:
        raise AssertionError(f"Dataset {name} not found")

    cache_key = (name, shard_idx, n_shards)
    if cache_key not in dataset_cache:
        dataset_cache[cache_key] = load_dataset_shard(cfg, shard_idx, n_shards)
    return dataset_cache[cache_key]
50
50
51
51
52
52
def get_sae (name : str ) -> SparseAutoEncoder :
@@ -144,7 +144,11 @@ def get_feature(name: str, feature_index: str | int):
144
144
dataset_name = sampling .dataset_name [i ]
145
145
model_name = sampling .model_name [i ]
146
146
model = get_model (model_name )
147
- data = get_dataset (dataset_name )[context_idx ]
147
+ data = get_dataset (
148
+ dataset_name ,
149
+ sampling .shard_idx [i ] if sampling .shard_idx is not None else 0 ,
150
+ sampling .n_shards [i ] if sampling .n_shards is not None else 1 ,
151
+ )[context_idx ]
148
152
_ , token_origins = model .to_tokens_with_origins (data )
149
153
150
154
# Replace image_key with image_url
0 commit comments