Skip to content

Commit 1b91ca1

Browse files
committed
tmp set the loading device to cuda:0
1 parent 2287486 commit 1b91ca1

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pyvene/models/intervenable_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1373,7 +1373,8 @@ def load_intervention(self, load_directory, include_model=True):
13731373
intervention = v
13741374
binary_filename = f"intkey_{k}.bin"
13751375
if isinstance(intervention, TrainableIntervention):
1376-
saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
1376+
saved_state_dict = torch.load(
1377+
os.path.join(load_directory, binary_filename), map_location='cuda:0')
13771378
intervention.load_state_dict(saved_state_dict)
13781379

13791380
# load model's trainable parameters as well

0 commit comments

Comments
 (0)