Skip to content

Commit 743ff47

Browse files
authored
Merge pull request #90 from OpenMOSS/update-source-dtype-fix
fix(activation): preserve tokens type during dtype conversion
2 parents 83acd31 + 6b84c49 commit 743ff47

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/lm_saes/activation/processors/cached_activation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,9 @@ def process(self, data: None = None, **kwargs) -> Iterable[dict[str, Any]]:
237237
device=self.device,
238238
)
239239
if self.dtype is not None:
240-
activations = {k: v.to(self.dtype) for k, v in activations.items()}
240+
for k, v in activations.items():
241+
if k in self.hook_points:
242+
activations[k] = v.to(self.dtype)
241243
yield activations # Use pin_memory to load data on cpu, then transfer them to cuda in the main process, as advised in https://discuss.pytorch.org/t/dataloader-multiprocessing-with-dataset-returning-a-cuda-tensor/151022/2.
242244
# I wrote this utils function as I notice it is used multiple times in this repo. Do we need to apply it elsewhere?
243245

@@ -259,4 +261,4 @@ def __getitem__(self, chunk_idx):
259261
return self.activation_loader.load_chunk_for_hooks(
260262
chunk_idx,
261263
self.hook_chunks,
262-
)
264+
)

0 commit comments

Comments
 (0)