Skip to content

Commit 8138d93

Browse files
authored
Merge pull request #219 from stanfordnlp/zen/lora
[Minor] Fix intervention loading device to default to cuda:0
2 parents dad20d8 + 1b91ca1 commit 8138d93

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pyvene/models/intervenable_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1373,7 +1373,8 @@ def load_intervention(self, load_directory, include_model=True):
13731373
intervention = v
13741374
binary_filename = f"intkey_{k}.bin"
13751375
if isinstance(intervention, TrainableIntervention):
1376-
saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
1376+
saved_state_dict = torch.load(
1377+
os.path.join(load_directory, binary_filename), map_location='cuda:0')
13771378
intervention.load_state_dict(saved_state_dict)
13781379

13791380
# load model's trainable parameters as well

0 commit comments

Comments
 (0)