This repository contains `task_vectors`, a Python package for task vectors.
Task vectors (also called function vectors in some papers) were independently discovered by several groups (e.g. Todd et al. and Hendel et al.) in late 2023.
Generally, task vectors can be understood as structures within language models corresponding to particular "tasks" or functions. These structures are often coaxed out of the model using in-context learning (ICL), then applied and tested in a zero-shot context.
In my experiments, I've found that a "soft" formulation of task vectors is generally the most performant: instead of the REINFORCE-based replacement rule from Hojel et al., smoothly blending each attention head's output with its mean value for a given task performs much better across a wide variety of tasks.
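Concretely, the soft formulation amounts to a convex blend per attention head. Here is a minimal sketch, assuming a sigmoid-gated scalar per head (the function name, the `alpha` parameterization, and the sigmoid gate are illustrative assumptions, not the package's exact implementation):

```python
import torch

def soft_task_vector(head_output: torch.Tensor,
                     task_mean: torch.Tensor,
                     alpha: torch.Tensor) -> torch.Tensor:
    """Blend one attention head's output with its task mean.

    head_output: the head's activation on the current input
    task_mean:   the head's mean activation over ICL examples of the task
    alpha:       a learnable scalar for this head (the "single parameter
                 per attention head" mentioned below)
    """
    gate = torch.sigmoid(alpha)  # constrain the blend weight to (0, 1)
    return (1.0 - gate) * head_output + gate * task_mean
```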
Task vectors can serve several purposes, including:
- A constructive (rather than extractive) interpretability tool. As long as your task can be "learned" via ICL, task vectors can be used to identify neurons and circuits responsible for it
- A very parameter-efficient fine-tuning tool. Although not as effective as more parameter-heavy methods like LoRA, task vectors can be used to improve zero-shot performance on many classes of tasks while using a single parameter (or fewer) per attention head. For example, task vectors for Llama-3.1 405B can be formulated with 16,128 parameters (126 layers $\times$ 128 heads; $< 4 \cdot 10^{-6}\%$ of the model's weights).
Interestingly, when formulated correctly, task vectors can in some cases actually exceed the performance of direct ICL.
Task vectors have some shortcomings:
- They cannot be discovered in an unsupervised fashion the way sparse autoencoder features can: you must know in advance which "task" you are trying to extract a task vector for
- They are not zero-cost as a fine-tuning tool. Unlike LoRA and related methods, task vector formulations generally require knowledge of both the current and mean activation of a given attention head. Soft task vectors, for example, introduce another addition operation that prevents the task vector weights from being directly merged into model weights, adding (slight) overhead to the process.
To install:

```
pip install git+https://github.com/jacknewsom/task_vectors
```
The main utility of this package is `TaskVectorExtractor`, which wraps most of the common operations you'll perform with task vectors. See `experiments/train.py` for an end-to-end example.
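As a rough sketch of the intended workflow (the constructor and method names below are illustrative assumptions, not the package's actual API; defer to `experiments/train.py` for real usage):

```python
# Hypothetical workflow sketch -- names and signatures here are assumptions,
# not the package's actual API; see experiments/train.py for real usage.
from task_vectors import TaskVectorExtractor

extractor = TaskVectorExtractor(model, tokenizer)     # assumed constructor
task_vector = extractor.extract(icl_prompts)          # assumed: collect mean head activations over ICL prompts
outputs = extractor.apply(task_vector, test_prompts)  # assumed: zero-shot generation with the vector applied
```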
To implement the soft task vector intervention (see above), `task_vectors` introduces the `Intervention`, effectively a wrapper around PyTorch's `register_forward_hook`. `Intervention`s allow you to edit the outputs of any `nn.Module` by specifying a `Hook` to be called on it. For example, the `SaveInputsHook` will save the inputs to all specified targets for as long as the `Intervention` context is active:
```python
from task_vectors import Intervention, SaveInputsHook  # assuming top-level exports

...
savehook = SaveInputsHook()
with Intervention(targets={"moduleA": moduleA, "moduleB": moduleB}, hook=savehook):
    model.generate(**inputs, max_new_tokens=4)
...
>>> len(savehook["moduleA"])
4
>>> len(savehook["moduleB"])
4
```
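Each generated token corresponds to one forward pass through the targets, so `max_new_tokens=4` yields four saved inputs per target. Because `Intervention` is a context manager, its hooks are removed automatically when the `with` block exits.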
To implement a custom `Intervention` on some target module, subclass `Hook` and implement `__call__`. `Intervention` sets up your `Hook` to run similarly to a vanilla PyTorch hook, but also provides a `name` field that can be useful for deciding which intervention to perform on a given layer, or for referencing and updating global state.
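For example, here is a minimal sketch of a custom hook, assuming `__call__` receives the usual PyTorch forward-hook arguments (`module`, `inputs`, `output`) and that the `name` field is populated on the hook by `Intervention`; the exact signature is defined by the package, so treat this as illustrative:

```python
from task_vectors import Hook  # Hook base class provided by the package

class ScaleOutputsHook(Hook):
    """Illustrative hook that rescales a target module's outputs."""

    def __init__(self, scale: float = 0.5):
        self.scale = scale

    def __call__(self, module, inputs, output):
        # `self.name` is assumed to be set by `Intervention`, identifying
        # which target this call corresponds to (e.g. "moduleA").
        if self.name == "moduleA":
            return output * self.scale  # returning a value replaces the output
        return output
```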