Skip to content

Commit b2bb4b5

Browse files
authored
Merge pull request #44 from andyrdt/andyrdt/bos_bugfix
safer remove_bos logic
2 parents d639166 + 59abc88 commit b2bb4b5

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

dictionary_learning/buffer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self,
5656
self.refresh_batch_size = refresh_batch_size
5757
self.out_batch_size = out_batch_size
5858
self.device = device
59-
self.remove_bos = remove_bos
59+
self.remove_bos = remove_bos and (self.model.tokenizer.bos_token_id is not None)
6060
self.add_special_tokens = add_special_tokens
6161

6262
def __iter__(self):
@@ -133,14 +133,15 @@ def refresh(self):
133133
input = self.model.inputs.save()
134134

135135
self.submodule.output.stop()
136-
attn_mask = input.value[1]["attention_mask"]
136+
137+
mask = (input.value[1]["attention_mask"] != 0)
137138
hidden_states = hidden_states.value
138139
if isinstance(hidden_states, tuple):
139140
hidden_states = hidden_states[0]
140141
if self.remove_bos:
141-
hidden_states = hidden_states[:, 1:, :]
142-
attn_mask = attn_mask[:, 1:]
143-
hidden_states = hidden_states[attn_mask != 0]
142+
bos_mask = (input.value[1]["input_ids"] == self.model.tokenizer.bos_token_id)
143+
mask = mask & ~bos_mask
144+
hidden_states = hidden_states[mask]
144145

145146
remaining_space = self.activation_buffer_size - current_idx
146147
assert remaining_space > 0

dictionary_learning/pytorch_buffer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@ def __init__(
117117
self.refresh_batch_size = refresh_batch_size
118118
self.out_batch_size = out_batch_size
119119
self.device = device
120-
self.remove_bos = remove_bos
121120
self.add_special_tokens = add_special_tokens
122121
self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path)
122+
self.remove_bos = remove_bos and (self.tokenizer.bos_token_id is not None)
123123

124124
if not self.tokenizer.pad_token:
125125
self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -192,11 +192,11 @@ def refresh(self):
192192
with t.no_grad():
193193
input = self.tokenized_batch()
194194
hidden_states = collect_activations(self.model, self.submodule, input)
195-
attn_mask = input["attention_mask"]
195+
mask = (input["attention_mask"] != 0)
196196
if self.remove_bos:
197-
hidden_states = hidden_states[:, 1:, :]
198-
attn_mask = attn_mask[:, 1:]
199-
hidden_states = hidden_states[attn_mask != 0]
197+
bos_mask = (input["input_ids"] == self.tokenizer.bos_token_id)
198+
mask = mask & ~bos_mask
199+
hidden_states = hidden_states[mask]
200200

201201
remaining_space = self.activation_buffer_size - current_idx
202202
assert remaining_space > 0

0 commit comments

Comments (0)