@@ -805,7 +805,7 @@ def _intervention_setter(
805
805
keys ,
806
806
unit_locations_base ,
807
807
subspaces ,
808
- intervention_additional_kwargs : Optional [ Dict ] = None ,
808
+ intervention_additional_kwargs ,
809
809
) -> HandlerList :
810
810
"""
811
811
Create a list of setter tracer that will set activations
@@ -1528,7 +1528,7 @@ def _intervention_setter(
1528
1528
keys ,
1529
1529
unit_locations_base ,
1530
1530
subspaces ,
1531
- intervention_additional_kwargs : Optional [ Dict ] = None ,
1531
+ intervention_additional_kwargs ,
1532
1532
) -> HandlerList :
1533
1533
"""
1534
1534
Create a list of setter handlers that will set activations
@@ -1763,7 +1763,7 @@ def _wait_for_forward_with_parallel_intervention(
1763
1763
]
1764
1764
if subspaces is not None
1765
1765
else None ,
1766
- intervention_additional_kwargs ,
1766
+ intervention_additional_kwargs = intervention_additional_kwargs ,
1767
1767
)
1768
1768
# for setters, we don't remove them.
1769
1769
all_set_handlers .extend (set_handlers )
@@ -1775,6 +1775,7 @@ def _wait_for_forward_with_serial_intervention(
1775
1775
unit_locations ,
1776
1776
activations_sources : Optional [Dict ] = None ,
1777
1777
subspaces : Optional [List ] = None ,
1778
+ intervention_additional_kwargs : Optional [Dict ] = None ,
1778
1779
):
1779
1780
all_set_handlers = HandlerList ([])
1780
1781
for group_id , keys in self ._intervention_group .items ():
@@ -1831,7 +1832,7 @@ def _wait_for_forward_with_serial_intervention(
1831
1832
]
1832
1833
if subspaces is not None
1833
1834
else None ,
1834
- intervention_additional_kwargs ,
1835
+ intervention_additional_kwargs = intervention_additional_kwargs ,
1835
1836
)
1836
1837
# for setters, we don't remove them.
1837
1838
all_set_handlers .extend (set_handlers )
0 commit comments