Skip to content

Commit ddd9edb

Browse files
authored
Merge pull request #6 from baconsaur/main
Reduce memory overhead when capturing tensors
2 parents abc2879 + 4aad8e3 commit ddd9edb

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

llm_steer.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -227,7 +227,8 @@ def _add_steer_vector(self, layer_idx: int, steerElem: SteerElement):
227 | 227 |
228 | 228 |       def _capture_tensor(self, layer_idx: int, tokens: Tensor):
229 | 229 |           self._set_forward_fn(ActivationMode.CAPTURE, layer_idx)
230 |     | -         self.model(tokens)
    | 230 | +         with torch.inference_mode():
    | 231 | +             self.model(tokens)
231 | 232 |           result = self.captured_tensor
232 | 233 |           print(f"captured tensor: {result}")
233 | 234 |           return result

0 commit comments

Comments
 (0)