Skip to content

Commit 5357128

Browse files
authored
Merge pull request #217 from stanfordnlp/zen/lora
[Minor] Making intervention takes input the same shape as the output
2 parents beebdc5 + 2287486 commit 5357128

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

pyvene/models/intervenable_base.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,8 +1574,17 @@ def hook_callback(model, args, kwargs, output=None):
15741574
selected_output = selected_output.clone()
15751575

15761576
if self.as_adaptor:
1577-
intervention_additional_kwargs["args"] = args
1578-
intervention_additional_kwargs["kwargs"] = kwargs
1577+
adaptor_input = None
1578+
if len(args) == 0: # kwargs based calls
1579+
# PR: https://github.com/frankaging/align-transformers/issues/11
1580+
# We cannot assume the dict only contain one element
1581+
adaptor_input = kwargs[list(kwargs.keys())[0]]
1582+
else:
1583+
adaptor_input = args
1584+
selected_input = self._gather_intervention_output(
1585+
adaptor_input, key, unit_locations_base[key_i]
1586+
)
1587+
intervention_additional_kwargs["args"] = selected_input
15791588

15801589
if isinstance(
15811590
intervention,

0 commit comments

Comments
 (0)