A proof-of-concept implementation of Titans, introduced by Behrouz et al. (2025). The performance is validated on a synthetic sine-wave dataset and a real-world Wind Power Forecasting dataset from Kaggle.
What exactly are Titans? At a high level, they seem to be Transformer models that additionally contain a Neural Memory Module (NMM), responsible for remembering information seen over time.
Its long-term memory resembles the general structure of recurrent neural networks. The difference is that instead of keeping the hidden recurrent state as a vector or a matrix, the NMM learns a key-value mapping, given key and value matrices computed from the input.
The authors of Titans introduce the Neural Memory Module as a form of recurrent state which, in contrast to vector or matrix representations, can capture non-linear dependencies between data. They aim to embed the long-term knowledge in a neural network, learning a map from a key space to the space of values, and retrieving the relevant knowledge by processing specific queries. Motivated by the self-attention mechanism, they consider three linear maps

$$ q_t = x_t W_Q, \qquad k_t = x_t W_K, \qquad v_t = x_t W_V, $$

and train the memory $\mathcal{M}$ to map $k_t$ to $v_t$, i.e. it learns the mapping from the space of keys to the space of values. The knowledge is then recovered by considering $\mathcal{M}(q_t)$ for a query $q_t$.
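As a toy illustration of these maps (plain PyTorch with randomly initialized projection matrices and a made-up dimension; not code from the repository):

```python
import torch

d = 16                                   # embedding dimension (illustrative)
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))
x_t = torch.randn(d)                     # a single input token

q_t, k_t, v_t = x_t @ w_q, x_t @ w_k, x_t @ w_v
# The NMM is trained so that memory(k_t) is close to v_t;
# the stored knowledge is later recovered as memory(q_t).
```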
The following two problems are the most confusing things for me in the whole paper:
- What exactly is the training objective?
- Where do the matrices $W_Q, W_K, W_V$ come from?
For the first problem, the paper states that the NMM is trained to optimize

$$ \ell(\mathcal{M}_{t-1}; x_t) = \lVert \mathcal{M}_{t-1}(k_t) - v_t \rVert_2^2 . $$
That is, from the meta-learning or online learning perspective (Yu Sun et al. 2024), using a matrix-valued memory $M = W \in \mathbb{R}^{d_{in} \times d_{in}}$ is equivalent to optimizing $\ell(W_{t-1}; x_t) = \lVert W_{t-1} k_t - v_t\rVert_2^2$, which is an online linear regression objective, and so the optimal solution assumes the underlying dependency of historical data is linear.
In other words, the stated objective assumes that the relation between keys and values is linear. At the same time, the paper states:
In our implementation, we use SiLU(.) activation (Elfwing, Uchibe, and Doya 2018) as the non-linear activation for computing query, key, and values and normalize queries and keys using the $l_2$-norm.
Given this statement, I assumed that the stated NMM training objective is not precise and devised the following:
Assumption 1: The Neural Memory Module $\mathcal{M}$ is trained to minimize $\lVert \mathcal{M}(k_t) - v_t \rVert_2^2$, where the keys $k_t$ and values $v_t$ are computed from the input with the SiLU-activated projections $W_K$ and $W_V$ (keys additionally $l_2$-normalized).
Returning to the second question, the origin of the matrices $W_Q$, $W_K$ and $W_V$: the paper only mentions that $W_K$ and $W_V$ are hyperparameters in the above [NMM] loss function.
For a moment, I believed that these matrices may refer to the weights of the co-existing attention module, but this also does not seem to be the case:
Although in the above, we discussed MAL as the combination of LMMs and attention in a sequential manner, one simple variant of MAL is to treat LMM as a sequence model without any attention.
As a result, seeing no clear way how to train $W_K$ and $W_V$, I settled on the following:
Assumption 2: The key and value matrices, whose relation the NMM is trained to approximate, are randomly sampled at the beginning of training and remain constant throughout the whole training process. The query projection $W_Q$ is discussed below as part of the MAC layer.
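Under Assumption 2, these projections can be kept as fixed, non-trainable buffers. A minimal sketch (class, attribute names and initialization scale are illustrative, not the repository's API):

```python
import torch
import torch.nn as nn

class FixedKeyValueProjections(nn.Module):
    """Assumption 2: W_K and W_V are sampled once and never updated."""
    def __init__(self, d_hid: int):
        super().__init__()
        # Buffers are saved with the model and moved across devices,
        # but are invisible to the optimizer, so they stay constant.
        self.register_buffer("w_k", torch.randn(d_hid, d_hid) / d_hid ** 0.5)
        self.register_buffer("w_v", torch.randn(d_hid, d_hid) / d_hid ** 0.5)

    def forward(self, x: torch.Tensor):
        # x: (..., d_hid) -> keys and values of the same shape.
        return x @ self.w_k, x @ self.w_v
```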
The objective for training the NMM is described above in Assumption 1. The authors introduce a novel methodology for updating the weights of the NMM, based on a "surprise principle". The update equations are as follows:

$$
S_t = \eta_t\, S_{t-1} - \theta_t\, \nabla \ell(\mathcal{M}_{t-1}; x_t), \qquad
\mathcal{M}_t = (1 - \alpha_t)\, \mathcal{M}_{t-1} + S_t,
$$

where:

- $\alpha_t$ - weight decay,
- $\eta_t$ - forgetting factor,
- $\theta_t$ - surprise power.
In the paper, these are trainable and depend on the current input $x_t$. In this proof of concept, we simplify:

Assumption 3: The NMM update parameters are constant, i.e. $\alpha_t = \alpha$, $\eta_t = \eta$ and $\theta_t = \theta$ for all $t$, chosen as fixed hyperparameters (tuned on the validation set, see the hyperparameter study below).
Under this assumption, the updates become Stochastic Gradient Descent with momentum and (decoupled) weight decay. Indeed, the surprise $S_t$ plays the role of the momentum buffer, $\eta$ acts as the momentum coefficient, $\theta$ scales the gradient step and $\alpha$ acts as weight decay.
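To spell out the correspondence (a quick sketch, not taken from the write-up, using the exponential-moving-average convention for the momentum buffer): substituting $b_t = -\frac{1-\eta}{\theta} S_t$ into the update equations gives

$$
b_t = \eta\, b_{t-1} + (1-\eta)\, \nabla \ell(\mathcal{M}_{t-1}; x_t), \qquad
\mathcal{M}_t = (1-\alpha)\, \mathcal{M}_{t-1} - \frac{\theta}{1-\eta}\, b_t,
$$

i.e. SGD with momentum $\eta$, an effective learning rate of $\theta/(1-\eta)$ and a decoupled weight-decay factor $\alpha$; these appear to be the "SGD momentum" and "SGD lr" values reported in the hyperparameter study below.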
Note that in such a scenario, to ensure stable training, we should select a high momentum $\eta$ together with a correspondingly small learning rate $\theta$.
A single layer of the Titan "Memory As Context" (MAC) architecture integrates the Neural Memory Module in a retrieval-update fashion. Its implementation is available in `titans.py` as the `TitanMACLayer` class. Given an input sequence $x \in \mathbb{R}^{n \times d_{hid}}$ consisting of $n$ tokens, the layer proceeds as follows (a code sketch is given after the list):
- We first retrieve embeddings $h$ from the NMM $\mathcal{M}$ by $h = \mathcal{M}(q) \in \mathbb{R}^{n \times d_{hid}}$, where $q = \text{norm}_{l_2}(\text{SiLU}(x W_Q))$.
- We create a new joint embedding by concatenating trainable, data-independent "persistent memory" weights $p \in \mathbb{R}^{N_p \times d_{hid}}$, the retrieved long-memory values $h$ and the original input $x$: $\tilde{x} = [\,p;\, h;\, x\,] \in \mathbb{R}^{(N_p + 2n) \times d_{hid}}$.
- We pass $\tilde{x}$ through the attention-based module, deriving $\tilde{y} = \text{Attn}(\tilde{x})$. Afterwards, to match the original dimension of the sequence, we run it through a SiLU-activated fully connected layer, obtaining $y \in \mathbb{R}^{n\times d_{hid}}$.
- We update the NMM by passing $q' = \text{norm}_{l_2}(\text{SiLU}(y \cdot W_Q))$ through its surprise-based update step.
- We again pass $q'$ through the NMM, this time with updated weights, receiving a long-term memory $z \in\mathbb{R}^{n\times d_{hid}}$.
- Finally, we gate the output $y$ of the attention module using $z$ and return the obtained sequence.
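A minimal sketch of this forward pass in code. The function and method names (e.g. `memory.retrieve`, `memory.update`) are illustrative and not the actual `TitanMACLayer` interface; the batch dimension is omitted, and the reduction back to $n$ tokens before the fully connected layer is assumed to take the last $n$ positions:

```python
import torch
import torch.nn.functional as F

def mac_layer_forward(x, memory, attn, proj, w_q, persistent):
    """x: (n, d_hid) input chunk; persistent: (N_p, d_hid) trainable weights."""
    n = x.shape[0]
    # 1. Retrieve from the NMM with l2-normalized, SiLU-activated queries.
    q = F.normalize(F.silu(x @ w_q), dim=-1)
    h = memory.retrieve(q)                            # (n, d_hid)
    # 2. Concatenate persistent memory, retrieved values and the input.
    x_tilde = torch.cat([persistent, h, x], dim=0)    # (N_p + 2n, d_hid)
    # 3. Short-term memory: attention over the joint sequence, then a
    #    SiLU-activated projection back to n tokens (assumed: last n positions).
    y = F.silu(proj(attn(x_tilde)[-n:]))              # (n, d_hid)
    # 4. Surprise-based NMM update driven by queries from the attention output.
    q2 = F.normalize(F.silu(y @ w_q), dim=-1)
    memory.update(q2)
    # 5. Retrieve again with the updated weights and gate the attention output.
    z = memory.retrieve(q2)
    return y * z
```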
Due to the addition of the long-term memory module, the authors of Titans claim their architecture can be applied to very long sequences. In particular, they suggest the following methodology for processing a sequence $x$ of length $N$ with an attention context window of length $C$:

- Divide the sequence into $\frac{N}{C}$ chunks $x_{[0, C)}, x_{[C, 2C)}, \ldots, x_{[N-C, N)}$.
- Run the first $\frac{N}{C}-1$ chunks through the network, discarding their outputs. This step lets the Neural Memory Module retrieve and remember valuable data from the context.
- Return the result of processing the final chunk $x_{[N-C, N)}$.
This way, within each chunk, the attention module acts as short-term memory, recovering relevant information within the chunk and updating the long-term memory (NMM) accordingly. Once we reach the final chunk, NMM should be reminded of all relevant context and assist the attention module in producing an accurate prediction for the next token.
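A sketch of this chunked evaluation loop (assuming the sequence length is a multiple of the chunk size $C$ and that the model updates its NMM internally on every call; the function name is illustrative):

```python
def predict_with_long_context(model, x, chunk_size):
    """x: (N, d_in) long input sequence; returns predictions for the final chunk."""
    n = x.shape[0]
    # Warm-up: feed all but the last chunk, keeping only the NMM side effects.
    for start in range(0, n - chunk_size, chunk_size):
        model(x[start:start + chunk_size])        # outputs are discarded
    # Only the final chunk's predictions are returned.
    return model(x[n - chunk_size:])
```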
Titans target the problem of processing sequences with long context, i.e. ones that exceed the length of the context window available to current LLMs. For this proof-of-concept project, such a scale is unavailable. To capture the studied problem, we can instead reduce the size of the context window of the attention mechanism. Given that sequences in our datasets are of length 500, we consider a context window of $C = 16$, i.e. short-term memory will only be able to view sub-sequences of length 16 at once.
As said above, for all experiments we select:

- $N_p = 4$: the size of the persistent memory weights,
- $d_{hid} = 16$: the dimension of hidden embeddings.
To model the non-linear dependencies between keys and values, we implement the Neural Memory Module as a 2-layer MLP with SiLU activation and an intermediate hidden dimension. The `NeuralMemory` class implementing this module is available in `neural_memory.py`.
To increase flexibility, we do not implement the update as an equivalent `torch.optim.SGD` step, instead modifying the parameters of the network directly. This opens the possibility of supporting trainable update parameters $\alpha_t$, $\eta_t$ and $\theta_t$ in the future.
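For illustration, a self-contained sketch of such a memory with the manual surprise-based update under Assumption 3. This is not the repository's `NeuralMemory` class: keys and values are passed in explicitly, and backpropagation through the update itself is ignored for simplicity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoLayerMemorySketch(nn.Module):
    def __init__(self, d_hid, d_inner, alpha, eta, theta):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_hid, d_inner), nn.SiLU(), nn.Linear(d_inner, d_hid)
        )
        self.alpha, self.eta, self.theta = alpha, eta, theta
        # One "surprise" (momentum) buffer per parameter of the MLP.
        self.surprise = [torch.zeros_like(p) for p in self.net.parameters()]

    def retrieve(self, q):
        return self.net(q)

    @torch.enable_grad()
    def update(self, k, v):
        # Momentary surprise = gradient of the associative loss ||M(k) - v||^2.
        loss = F.mse_loss(self.net(k), v)
        grads = torch.autograd.grad(loss, tuple(self.net.parameters()))
        with torch.no_grad():
            for p, s, g in zip(self.net.parameters(), self.surprise, grads):
                s.mul_(self.eta).add_(g, alpha=-self.theta)  # S_t = eta*S_{t-1} - theta*grad
                p.mul_(1.0 - self.alpha).add_(s)             # M_t = (1-alpha)*M_{t-1} + S_t
```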
Implemented as the `MACTitanLayer` class in `titans.py`. Its behaviour is already described in the section above. For the gating mechanism at the end, we perform an element-wise product of the attention output embedding $y$ and the retrieved long-term memory $z$.
The attention mechanism is implemented as two `torch.nn.TransformerEncoderLayer` layers, each with two heads and SiLU activation.
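For reference, the corresponding construction in PyTorch looks roughly like this (the feed-forward dimension is illustrative, as it is not specified above):

```python
import torch.nn as nn
import torch.nn.functional as F

encoder_layer = nn.TransformerEncoderLayer(
    d_model=16,            # d_hid
    nhead=2,               # two attention heads
    dim_feedforward=64,    # illustrative; not specified in the write-up
    activation=F.silu,     # SiLU activation
    batch_first=True,
)
attn = nn.TransformerEncoder(encoder_layer, num_layers=2)  # two stacked layers
```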
For the Titan MAC model (see `MACTitan` in `titans.py`), we stack two MAC layers. To convert the input into a sequence of $d_{hid}$-dimensional embeddings, we first pass it through a learned input projection.
The architecture of the baseline attention model resembles that of Titan MAC, the only difference being that the Neural Memory Module is not used.
We validate the effectiveness of small-scale Titans on the task of time series prediction. As mentioned before, we limit the context window size of the attention mechanisms to $C = 16$. Given an input sequence $(x_1, \ldots, x_N)$, the objective in the studied problem is to return a sequence of predictions of the next element. To be precise, for every position $t$ the model outputs $\hat{x}_{t+1}$, a prediction of $x_{t+1}$ based on the elements observed so far. The quality of prediction is measured by the mean squared error (MSE) against the true continuation. We train each model using MSE with the Adam optimizer and a fixed learning rate.
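A minimal sketch of this training setup (the learning rate and epoch count below are placeholders, not the values used in the experiments):

```python
import torch
import torch.nn as nn

def train(model, loader, epochs=10, lr=1e-3):   # placeholder hyperparameters
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    for _ in range(epochs):
        for x, target in loader:     # target: the input sequence shifted by one step
            optimizer.zero_grad()
            loss = loss_fn(model(x), target)
            loss.backward()
            optimizer.step()
    return model
```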
Titans are a composition of attention-based and recurrent models. For this reason, we compare them to:

- (Attn): a Transformer-like attention baseline, able to process only sequences of length $C$. The architecture of this model matches the used Titan MAC architecture with the NMM and persistent memory removed,
- (Attn-PM): the same architecture as above, but also including the persistent memory weights,
- (LSTM): an LSTM-based model, processing the input sequence recurrently from start to end. The hidden embedding dimension of this model is $16$.

All baselines are available in `models.py`; a sketch of the LSTM baseline is given below.
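A sketch of what such an LSTM baseline can look like (class name and input dimension are illustrative, not the exact contents of `models.py`):

```python
import torch.nn as nn

class LSTMBaselineSketch(nn.Module):
    def __init__(self, d_in=1, d_hid=16):
        super().__init__()
        self.lstm = nn.LSTM(d_in, d_hid, batch_first=True)
        self.head = nn.Linear(d_hid, d_in)

    def forward(self, x):            # x: (batch, seq_len, d_in)
        h, _ = self.lstm(x)          # hidden state at every position
        return self.head(h)          # next-element prediction at every position
```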
We validate Titans' performance on two datasets: a synthetic sine-wave prediction dataset and a wind power forecasting dataset from Kaggle. For more details about the generation and preparation, refer to `sinewave_gen.ipynb` and `weather_prep.ipynb`.
The test MSE results are presented below:
| Model | Sine wave | Wind power |
|---|---|---|
| Attn | 0.0262 | 0.0353 |
| Attn-PM | 0.0274 | 0.0357 |
| Titan MAC (1-layer NMM) | 0.0184 | 0.0373 |
| Titan MAC (2-layer NMM) | 0.0183 | 0.0357 |
| LSTM | 0.0156 | 0.0356 |
We can observe that scores on the wind power prediction dataset are similar for all the models, with the minor differences potentially coming from randomness in training. The sine-wave dataset, however, provides more insight into the predictive power of the studied methods. As expected, Attn, unable to capture long-term dependencies, performs much worse than the methods with recurrent components. Interestingly, adding the persistent memory weights did not improve the attention-style architecture. The other baseline, LSTM, performs the best among all considered models. This is not surprising, as the recurrent mechanism in an LSTM should be sufficient to capture the nature of sine waves.
Looking at Titan's results, we can see two important things:
- They perform significantly better than the base attention models, with performance closer to LSTM,
- Setting NMM to be a 2-layer MLP yields better results than using a single-layer module.
Both of these observations suggest that the Neural Memory Module is indeed able to capture long-term dependencies in data. Despite losing to LSTM, Titans significantly improve over the fixed-size context window models, showing the potential of including long-term memory in the Transformer architecture.
For all experiments with Titans, we fix the weight decay $\alpha$ and tune the remaining update parameters $\eta$ and $\theta$ on the validation set; the table below also lists the SGD momentum and learning rate implied by the equivalence above:
| $\eta_t$ | $\theta_t$ | SGD momentum | SGD lr | Validation MSE |
|---|---|---|---|---|
| 0.8 | 0.3 | 0.8 | 1.5 | 0.0225 |
| 0.8 | 0.02 | 0.8 | 0.1 | 0.0186 |
| 0.8 | 0.01 | 0.8 | 0.05 | 0.0185 |
| 0.9 | 0.01 | 0.9 | 0.1 | 0.0183 |
| 0.9 | 0.005 | 0.9 | 0.05 | 0.0186 |
In this proof-of-concept project, we validate the potential of Titans for processing long sequences, beyond the length of the available context window. We show that the Neural Memory Module allows the model to store a recurrent state and propagate information about the task to the short-term memory implemented by the attention modules. The performance of Titans on the studied time series prediction tasks is closer to that of the LSTM than to the attention baseline, suggesting that the NMM gives the model a form of recurrent structure.
Next steps could include:
- introducing $\alpha_t$, $\eta_t$ and $\theta_t$ as learnable, data-dependent parameters,
- scaling the architecture to larger context windows and more challenging tasks,
- with the scaled models, targeting tasks in question answering, NIAH or other natural language problems.