Skip to content

Commit 10f35ba

Browse files
committed
now passing kwargs into interventions
1 parent 817bd49 commit 10f35ba

File tree

2 files changed

+23
-19
lines changed

2 files changed

+23
-19
lines changed

pyvene/models/interventions.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ class ZeroIntervention(ConstantSourceIntervention, LocalistRepresentationInterve
153153
def __init__(self, **kwargs):
154154
super().__init__(**kwargs)
155155

156-
def forward(self, base, source=None, subspaces=None):
156+
def forward(self, base, source=None, subspaces=None, **kwargs):
157157
return _do_intervention_by_swap(
158158
base,
159159
torch.zeros_like(base),
@@ -175,7 +175,7 @@ class CollectIntervention(ConstantSourceIntervention):
175175
def __init__(self, **kwargs):
176176
super().__init__(**kwargs)
177177

178-
def forward(self, base, source=None, subspaces=None):
178+
def forward(self, base, source=None, subspaces=None, **kwargs):
179179
return _do_intervention_by_swap(
180180
base,
181181
source,
@@ -197,7 +197,7 @@ class SkipIntervention(BasisAgnosticIntervention, LocalistRepresentationInterven
197197
def __init__(self, **kwargs):
198198
super().__init__(**kwargs)
199199

200-
def forward(self, base, source, subspaces=None):
200+
def forward(self, base, source, subspaces=None, **kwargs):
201201
# source here is the base example input to the hook
202202
return _do_intervention_by_swap(
203203
base,
@@ -220,7 +220,7 @@ class VanillaIntervention(Intervention, LocalistRepresentationIntervention):
220220
def __init__(self, **kwargs):
221221
super().__init__(**kwargs)
222222

223-
def forward(self, base, source, subspaces=None):
223+
def forward(self, base, source, subspaces=None, **kwargs):
224224
return _do_intervention_by_swap(
225225
base,
226226
source if self.source_representation is None else self.source_representation,
@@ -242,7 +242,7 @@ class AdditionIntervention(BasisAgnosticIntervention, LocalistRepresentationInte
242242
def __init__(self, **kwargs):
243243
super().__init__(**kwargs)
244244

245-
def forward(self, base, source, subspaces=None):
245+
def forward(self, base, source, subspaces=None, **kwargs):
246246
return _do_intervention_by_swap(
247247
base,
248248
source if self.source_representation is None else self.source_representation,
@@ -264,7 +264,7 @@ class SubtractionIntervention(BasisAgnosticIntervention, LocalistRepresentationI
264264
def __init__(self, **kwargs):
265265
super().__init__(**kwargs)
266266

267-
def forward(self, base, source, subspaces=None):
267+
def forward(self, base, source, subspaces=None, **kwargs):
268268

269269
return _do_intervention_by_swap(
270270
base,
@@ -289,7 +289,7 @@ def __init__(self, **kwargs):
289289
rotate_layer = RotateLayer(self.embed_dim)
290290
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
291291

292-
def forward(self, base, source, subspaces=None):
292+
def forward(self, base, source, subspaces=None, **kwargs):
293293
rotated_base = self.rotate_layer(base)
294294
rotated_source = self.rotate_layer(source)
295295
# interchange
@@ -340,7 +340,7 @@ def set_intervention_boundaries(self, intervention_boundaries):
340340
torch.tensor([intervention_boundaries]), requires_grad=True
341341
)
342342

343-
def forward(self, base, source, subspaces=None):
343+
def forward(self, base, source, subspaces=None, **kwargs):
344344
batch_size = base.shape[0]
345345
rotated_base = self.rotate_layer(base)
346346
rotated_source = self.rotate_layer(source)
@@ -391,7 +391,7 @@ def get_temperature(self):
391391
def set_temperature(self, temp: torch.Tensor):
392392
self.temperature.data = temp
393393

394-
def forward(self, base, source, subspaces=None):
394+
def forward(self, base, source, subspaces=None, **kwargs):
395395
batch_size = base.shape[0]
396396
rotated_base = self.rotate_layer(base)
397397
rotated_source = self.rotate_layer(source)
@@ -431,7 +431,7 @@ def get_temperature(self):
431431
def set_temperature(self, temp: torch.Tensor):
432432
self.temperature.data = temp
433433

434-
def forward(self, base, source, subspaces=None):
434+
def forward(self, base, source, subspaces=None, **kwargs):
435435
batch_size = base.shape[0]
436436
# get boundary mask between 0 and 1 from sigmoid
437437
mask_sigmoid = torch.sigmoid(self.mask / torch.tensor(self.temperature))
@@ -456,7 +456,7 @@ def __init__(self, **kwargs):
456456
rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
457457
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
458458

459-
def forward(self, base, source, subspaces=None):
459+
def forward(self, base, source, subspaces=None, **kwargs):
460460
rotated_base = self.rotate_layer(base)
461461
rotated_source = self.rotate_layer(source)
462462
if subspaces is not None:
@@ -529,7 +529,7 @@ def __init__(self, **kwargs):
529529
)
530530
self.trainable = False
531531

532-
def forward(self, base, source, subspaces=None):
532+
def forward(self, base, source, subspaces=None, **kwargs):
533533
base_norm = (base - self.pca_mean) / self.pca_std
534534
source_norm = (source - self.pca_mean) / self.pca_std
535535

@@ -565,7 +565,7 @@ def __init__(self, **kwargs):
565565
prng(1, 4, self.embed_dim)))
566566
self.register_buffer('noise_level', torch.tensor(noise_level))
567567

568-
def forward(self, base, source=None, subspaces=None):
568+
def forward(self, base, source=None, subspaces=None, **kwargs):
569569
base[..., : self.interchange_dim] += self.noise * self.noise_level
570570
return base
571571

@@ -585,7 +585,7 @@ def __init__(self, **kwargs):
585585
self.autoencoder = AutoencoderLayer(
586586
self.embed_dim, kwargs["latent_dim"])
587587

588-
def forward(self, base, source, subspaces=None):
588+
def forward(self, base, source, subspaces=None, **kwargs):
589589
base_dtype = base.dtype
590590
base = base.to(self.autoencoder.encoder[0].weight.dtype)
591591
base_latent = self.autoencoder.encode(base)
@@ -619,7 +619,7 @@ def encode(self, input_acts):
619619
def decode(self, acts):
620620
return acts @ self.W_dec + self.b_dec
621621

622-
def forward(self, base, source=None, subspaces=None):
622+
def forward(self, base, source=None, subspaces=None, **kwargs):
623623
# generate latents for base and source runs.
624624
base_latent = self.encode(base)
625625
source_latent = self.encode(source)

pyvene/models/modeling_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -446,15 +446,19 @@ def scatter_neurons(
446446

447447

448448
def do_intervention(
449-
base_representation, source_representation, intervention, subspaces
449+
base_representation,
450+
source_representation,
451+
intervention,
452+
subspaces,
453+
**kwargs
450454
):
451455
"""Do the actual intervention."""
452456

453457
if isinstance(intervention, LambdaIntervention):
454458
if subspaces is None:
455-
return intervention(base_representation, source_representation)
459+
return intervention(base_representation, source_representation, **kwargs)
456460
else:
457-
return intervention(base_representation, source_representation, subspaces)
461+
return intervention(base_representation, source_representation, subspaces, **kwargs)
458462

459463
num_unit = base_representation.shape[1]
460464

@@ -478,7 +482,7 @@ def do_intervention(
478482
assert False # what's going on?
479483

480484
intervention_output = intervention(
481-
base_representation_f, source_representation_f, subspaces
485+
base_representation_f, source_representation_f, subspaces, **kwargs
482486
)
483487
if isinstance(intervention_output, InterventionOutput):
484488
intervened_representation = intervention_output.output

0 commit comments

Comments (0)