Commit 44bb55b

fix(cached_acts): re-implement changes mistakenly removed in 71ff9f9
1 parent: 59ce8ac

File tree

1 file changed: +9 −0


src/lm_saes/activation/processors/cached_activation.py

Lines changed: 9 additions & 0 deletions
@@ -240,6 +240,15 @@ def process(self, data: None = None, **kwargs) -> Iterable[dict[str, Any]]:
         for k, v in activations.items():
             if k in self.hook_points:
                 activations[k] = v.to(self.dtype)
+
+        while activations["tokens"].ndim >= 3:
+            def flatten(x: torch.Tensor | list[list[Any]]) -> torch.Tensor | list[Any]:
+                if isinstance(x, torch.Tensor):
+                    return x.flatten(start_dim=0, end_dim=1)
+                else:
+                    return [a for b in x for a in b]
+            activations = {k: flatten(v) for k, v in activations.items()}
+
         yield activations  # Use pin_memory to load data on cpu, then transfer them to cuda in the main process, as advised in https://discuss.pytorch.org/t/dataloader-multiprocessing-with-dataset-returning-a-cuda-tensor/151022/2.
         # I wrote this utility function as I noticed this pattern is used multiple times in this repo. Do we need to apply it elsewhere?
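A minimal sketch, not part of the commit, of what the added flattening loop does: it repeatedly merges the two leading dimensions of every value in the activations dict (tensors via flatten, nested lists via one level of unnesting) until the "tokens" entry is at most 2-D. The keys, shapes, and hook name below are hypothetical.

    from typing import Any

    import torch


    def flatten(x: torch.Tensor | list[list[Any]]) -> torch.Tensor | list[Any]:
        if isinstance(x, torch.Tensor):
            return x.flatten(start_dim=0, end_dim=1)  # merge dims 0 and 1
        else:
            return [a for b in x for a in b]  # remove one level of nesting

    # Hypothetical batched activations: (batch, seq, ...) leading dims.
    activations: dict[str, Any] = {
        "tokens": torch.zeros(2, 4, 16, dtype=torch.long),
        "blocks.0.hook_resid_post": torch.zeros(2, 4, 16, 128),
        "meta": [[{"doc": 0}] * 4, [{"doc": 1}] * 4],
    }

    while activations["tokens"].ndim >= 3:
        activations = {k: flatten(v) for k, v in activations.items()}

    print(activations["tokens"].shape)                 # torch.Size([8, 16])
    print(activations["blocks.0.hook_resid_post"].shape)  # torch.Size([8, 16, 128])
    print(len(activations["meta"]))                    # 8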

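The yield line's comment points at a standard PyTorch pattern; here is a minimal, hypothetical sketch of it (dataset, shapes, and batch size are made up): the DataLoader collates CPU tensors from worker processes into pinned-memory batches, and the main process moves them to CUDA, where non_blocking copies are only effective from pinned memory.

    import torch
    from torch.utils.data import DataLoader, Dataset


    class ToyActivations(Dataset):
        def __len__(self) -> int:
            return 8

        def __getitem__(self, i: int) -> torch.Tensor:
            return torch.randn(16, 128)  # created on CPU inside a worker process


    if __name__ == "__main__":
        loader = DataLoader(ToyActivations(), batch_size=4, num_workers=2, pin_memory=True)
        for batch in loader:
            if torch.cuda.is_available():
                # Transfer happens in the main process, as the linked thread advises.
                batch = batch.to("cuda", non_blocking=True)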