This repository contains the code for our experiments. We contributed DeltaProduct to flash-linear-attention, which hosts the version we will keep updating.
Linear Recurrent Neural Networks (linear RNNs) have emerged as competitive alternatives to Transformers for sequence modeling, offering efficient training and linear-time inference. However, existing architectures face a fundamental trade-off between expressivity and efficiency, dictated by the structure of their state-transition matrices. Diagonal matrices, used in models such as Mamba, GLA, or mLSTM, yield fast runtime but have limited expressivity. To address this, recent architectures such as DeltaNet and RWKV-7 adopted a diagonal plus rank-1 structure, which allows simultaneous token and channel mixing, improving associative recall and, as recently shown, state-tracking when negative eigenvalues are allowed in the state-transition matrices. Building on the interpretation of DeltaNet's recurrence as performing one step of online gradient descent per token on an associative recall loss, we introduce DeltaProduct, which instead takes multiple ($n_h$) steps per token. This naturally leads to diagonal plus rank-$n_h$ state-transition matrices, formed as products of $n_h$ generalized Householder transformations, providing a tunable mechanism to balance expressivity and efficiency. We provide a detailed theoretical characterization of the state-tracking capability of DeltaProduct in finite precision and show how it improves as $n_h$ increases. Our extensive experiments demonstrate that DeltaProduct outperforms DeltaNet in both state-tracking and language modeling, while also showing significantly improved length extrapolation capabilities.
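To make the rank-$n_h$ structure concrete, here is a minimal NumPy sketch of such a state-transition matrix (illustrative only; the variable names and the per-token parameterization of keys and betas are assumptions, not the repository's API):

```python
import numpy as np

def generalized_householder(k, beta):
    # I - beta * k k^T for a unit vector k. beta in [0, 1] keeps eigenvalues
    # in [0, 1]; beta in (1, 2] introduces a negative eigenvalue
    # (beta = 2 gives an exact reflection).
    return np.eye(k.shape[0]) - beta * np.outer(k, k)

def transition_matrix(ks, betas):
    # One token's state-transition matrix: a product of n_h generalized
    # Householder factors, i.e. identity plus a rank-n_h correction.
    A = np.eye(ks.shape[1])
    for k, beta in zip(ks, betas):
        A = A @ generalized_householder(k, beta)
    return A

rng = np.random.default_rng(0)
n_h, d = 2, 8
ks = rng.standard_normal((n_h, d))
ks /= np.linalg.norm(ks, axis=1, keepdims=True)  # unit-norm keys
A = transition_matrix(ks, np.array([2.0, 1.0]))
print(np.linalg.matrix_rank(A - np.eye(d)))      # at most n_h (here 2)
```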
This repository builds upon two key sources:
- Illusion of State in State-Space Models. The primary modification in this version is the integration of the DeltaProduct mechanism.
- Flash-Linear-Attention. We leverage the Triton implementation of (Gated) DeltaNet to implement DeltaProduct.
Please install flash-linear-attention and the state-tracking folder in your Python environment.
For the language modeling experiments you need the following packages installed in your environment:
- causal-conv1d
- triton>3.1.0
- accelerate
- transformers
- datasets
- wandb
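These can be installed with pip, for example (version pins beyond the Triton requirement are left to you):

pip install causal-conv1d "triton>3.1.0" accelerate transformers datasets wandb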
This section demonstrates how to generate data and run experiments for learning sequences in the symmetric group $S_3$.
First, generate the training and testing datasets:
Training Data:
PYTHONPATH=$PWD python src/generate_data.py --group=S3 --k=128 --samples=100000
Test Data:
PYTHONPATH=$PWD python src/generate_data.py --group=S3 --k=512 --samples=100000
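For intuition, the task behind these datasets is the word problem on $S_3$: given a sequence of group elements, predict their running product at every position. A toy Python illustration of the task (not the repository's data format):

```python
from itertools import permutations

# The six elements of S3 as permutation tuples; composition (p*q)(i) = p[q[i]].
S3 = list(permutations(range(3)))

def compose(p, q):
    return tuple(p[q[i]] for i in range(3))

def prefix_products(seq):
    # Targets for the word problem: the product of the first t elements, for each t.
    acc, out = tuple(range(3)), []  # start from the identity
    for g in seq:
        acc = compose(g, acc)
        out.append(acc)
    return out

seq = [S3[1], S3[3], S3[4]]
print(prefix_products(seq))  # the model must track this product across the sequence
```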
The following examples illustrate the training process with different configurations of DeltaProduct. In the notation below, num_householder corresponds to $n_h$ (the number of Householder factors per token) and allow_neg_eigval controls whether the state-transition matrices may have negative eigenvalues.
PYTHONPATH=$PWD python src/main.py train --group=S3 --k=128 --k_test=512 --n_layers=1 --epochs=100 --allow_neg_eigval=True --num_householder=1 --batch_size=2048 --seed=666 --lr=1e-3 --n_heads=8 --use_scheduler=True
Outcome: This configuration is not expected to converge. For $n_h = 1$, a single generalized Householder factor is symmetric and therefore cannot represent the rotations (3-cycles) of $S_3$.
PYTHONPATH=$PWD python src/main.py train --group=S3 --k=128 --k_test=512 --n_layers=1 --epochs=100 --allow_neg_eigval=True --num_householder=2 --batch_size=2048 --seed=666 --lr=1e-3 --n_heads=8 --use_scheduler=True
Outcome: This configuration is expected to train successfully.
PYTHONPATH=$PWD python src/main.py train --group=S3 --k=128 --k_test=512 --n_layers=1 --epochs=100 --allow_neg_eigval=False --num_householder=1 --batch_size=2048 --seed=666 --lr=1e-3 --n_heads=8 --use_scheduler=True
Outcome: This configuration is not expected to converge. Restricting the eigenvalues to $[0, 1]$ (allow_neg_eigval=False) is unsuitable, as negative eigenvalues are required.
PYTHONPATH=$PWD python src/main.py train --group=S3 --k=128 --k_test=512 --n_layers=1 --epochs=100 --allow_neg_eigval=False --num_householder=2 --batch_size=2048 --seed=666 --lr=1e-3 --n_heads=8 --use_scheduler=True
Outcome: This configuration is not expected to converge. Restricting the eigenvalues to $[0, 1]$ (allow_neg_eigval=False) is unsuitable, as negative eigenvalues are required.
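One way to understand these outcomes (a sketch of the reflection argument; the notation is assumed from the paper): each element of $S_3$ acts as a $3 \times 3$ permutation matrix. A transposition is a reflection, i.e. exactly one generalized Householder factor with $\beta = 2$ (eigenvalue $-1$):

$$I - 2kk^\top = \begin{pmatrix} 0 & 1 & 0 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \end{pmatrix} \quad \text{for} \quad k = \tfrac{1}{\sqrt{2}}(1, -1, 0)^\top,$$

while a 3-cycle is a rotation, i.e. a product of two reflections. Hence $n_h = 2$ with negative eigenvalues allowed suffices, whereas $n_h = 1$ cannot produce rotations and $\beta \in [0, 1]$ cannot produce reflections.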
The checkpoints can be downloaded at this Dropbox link.
The checkpoint names are abbreviated following the pattern Model_eigenvalue-range_training-context-length. For example, DN_-1_1_4k is a DeltaNet with eigenvalue range $[-1, 1]$ (negative eigenvalues allowed) trained on a 4096-token context.
To reproduce the experiments, you will need access to a SLURM cluster to use the training and evaluation scripts. The training script is located at language_modeling/slurm_scripts/training.sh. You will have to adapt its parameters to your specific SLURM setup (partition, nodes, gres) and to the model you wish to train. For this, specify the model config (configs are in language_modeling/model_configs).
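A submission might then look like the following (hypothetical values; options passed to sbatch on the command line override the script's #SBATCH directives):

sbatch --partition=<your_partition> --nodes=1 --gres=gpu:<n> language_modeling/slurm_scripts/training.sh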
- lm-eval can be run by using the language_modeling/slurm_scripts/lm_eval_harness_static_evals.sh script.
- To collect length extrapolation losses and perplexities, you can use the language_modeling/slurm_scripts/length_extrapolation_eval.sh script.
- To collect data on the effective rank, first use language_modeling/custom_evals/collect_activations.sh to collect the activations and then language_modeling/slurm_scripts/run_effective_rank.sh to get the effective rank plots and data.