task_vectors: Tools for extracting, training, and evaluating task vectors

This repository contains task_vectors, a Python package for task vectors.

About task vectors

Task vectors (also called function vectors in some papers) were discovered by a few groups (e.g. Todd et al. and Hendel et al.) in late 2023.

Generally, task vectors can be understood as structures within language models corresponding to particular "tasks" or functions. These structures are typically coaxed out of the model using in-context learning (ICL), then applied and tested in a zero-shot context.
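
For example, an antonym task might be demonstrated to the model with a short ICL prompt; the pairs and query below are made up purely for illustration:

# A few-shot (ICL) prompt from which a task vector can be extracted...
icl_prompt = (
    "hot -> cold\n"
    "big -> small\n"
    "fast -> slow\n"
    "light -> "     # the in-context examples define the "task"
)

# ...and a zero-shot prompt it can later be applied to, with no examples at all.
zero_shot_prompt = "light -> "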

In my experiments, I've found that a "soft" formulation of task vectors is generally the most performant. Rather than the REINFORCE-based replacement rule from Hojel et al., smoothly blending each attention head's output with its mean value for a given task performs much better across a wide variety of tasks.
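
As a minimal sketch in plain PyTorch (not the package's own API), the soft formulation keeps one learned weight per attention head and blends that head's current output with its mean output for the task:

import torch

def soft_blend(h: torch.Tensor, h_mu: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # h:    the head's output on the current input
    # h_mu: the head's mean output for the task (e.g. averaged over ICL prompts)
    # w:    a learned blend weight for this head, assumed to lie in [0, 1]
    return h * (1 - w) + h_mu * w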

Why are task vectors useful?

Task vectors can serve several purposes, including:

  • A constructive (rather than extractive) interpretability tool. As long as your task can be "learned" via ICL, task vectors can be used to identify the neurons and circuits responsible for it.
  • A very parameter-efficient fine-tuning tool. Although not as effective as more parameter-heavy methods like LoRA, task vectors can improve zero-shot performance on many classes of tasks while using at most a single parameter per attention head. For example, task vectors for Llama-3.1 405B can be formulated with 16,128 parameters ($< 4 \cdot 10^{-4}\%$ of the model's parameters); a rough count is sketched below.
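
As a sanity check on that count, assuming the commonly cited Llama-3.1 405B configuration of 126 layers with 128 attention heads each and one blend weight per head:

$$126 \times 128 = 16{,}128 \text{ parameters}, \qquad \frac{16{,}128}{405 \times 10^{9}} \approx 4 \times 10^{-8} \approx 4 \times 10^{-6}\,\%,$$

comfortably under the $4 \cdot 10^{-4}\%$ bound quoted above.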

Interestingly, when formulated correctly, task vectors can actually exceed the performance provided by direct ICL in some cases.

When are task vectors not useful?

Task vectors have some shortcomings:

  • They cannot be discovered in an unsupervised fashion the way sparse autoencoder features can; you must know in advance which "task" you are trying to extract a vector for.
  • They are not zero-cost as a fine-tuning tool. Unlike LoRA and related methods, task vector formulations generally require knowledge of both the current and the mean activation of a given attention head. Soft task vectors, for example, introduce an extra addition at inference time that prevents the task vector weights from being merged directly into the model weights, adding (slight) overhead to the process.

Installation

pip install git+https://github.com/jacknewsom/task_vectors

Usage

The main utility of this package is TaskVectorExtractor, which wraps most of the common operations you'll perform with task vectors. See experiments/train.py for an end-to-end example.

Interventions

To implement the soft task vector intervention (see above)

$$h' = h \cdot (1 - w) + h_\mu \cdot w,$$

task_vectors introduces the Intervention class, effectively a wrapper around PyTorch's register_forward_hook. Interventions allow you to edit the outputs of any nn.Module by specifying a Hook to be called on it. For example, the SaveInputsHook will save the inputs to all specified targets for as long as the Intervention context is active:

...
savehook = SaveInputsHook()

with Intervention(targets={"moduleA": moduleA, "moduleB": moduleB}, hook=savehook):
    model.generate(**inputs, max_new_tokens=4)
...
>>> len(savehook["moduleA"])
4
>>> len(savehook["moduleB"])
4

To implement a custom Intervention on some target module, subclass Hook and implement __call__. Intervention sets up your Hook to run similarly to a vanilla PyTorch forward hook, but also provides a name field, which is useful for deciding which intervention to perform on a given layer or for referencing shared global state.
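
For illustration only, a custom Hook implementing the soft blend above might look roughly like the sketch below. The exact __call__ signature and how the name is supplied are assumptions here (check the package source), as are the hypothetical means and weights dictionaries:

class SoftTaskVectorHook(Hook):
    # Blends each target's output with its task mean: h' = h * (1 - w) + h_mu * w.

    def __init__(self, means, weights):
        self.means = means      # target name -> mean activation (h_mu) for the task
        self.weights = weights  # target name -> scalar blend weight (w)

    def __call__(self, name, module, inputs, output):
        w = self.weights[name]
        h_mu = self.means[name]
        # Returning a value replaces the module's output, as with a PyTorch forward hook.
        return output * (1 - w) + h_mu * w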
