Commit d7753a4

fix(activation): preserve tokens type during dtype conversion

1 parent 83acd31

File tree: 1 file changed (+2 −2 lines)

src/lm_saes/activation/processors/cached_activation.py

Lines changed: 2 additions & 2 deletions
@@ -237,7 +237,7 @@ def process(self, data: None = None, **kwargs) -> Iterable[dict[str, Any]]:
                 device=self.device,
             )
         if self.dtype is not None:
-            activations = {k: v.to(self.dtype) for k, v in activations.items()}
+            activations = {k: v.to(self.dtype) if k != "tokens" else v for k, v in activations.items()}
         yield activations  # Use pin_memory to load data on cpu, then transfer them to cuda in the main process, as advised in https://discuss.pytorch.org/t/dataloader-multiprocessing-with-dataset-returning-a-cuda-tensor/151022/2.
     # I wrote this utils function as I notice it is used multiple times in this repo. Do we need to apply it elsewhere?

@@ -259,4 +259,4 @@ def __getitem__(self, chunk_idx):
         return self.activation_loader.load_chunk_for_hooks(
             chunk_idx,
             self.hook_chunks,
-        )
+        )

0 comments