This repository contains the official implementation of the paper "Neural ODE Transformers: Analyzing Internal Dynamics and Adaptive Fine-tuning", accepted at the International Conference on Learning Representations (ICLR) 2025.
Recent advancements in large language models (LLMs) based on transformer architectures have sparked significant interest in understanding their inner workings. In this paper, we introduce a novel approach to modeling transformer architectures using highly flexible non-autonomous neural ordinary differential equations (ODEs). Our proposed model parameterizes all weights of attention and feed-forward blocks through neural networks, expressing these weights as functions of a continuous layer index. Through spectral analysis of the model's dynamics, we uncover an increase in eigenvalue magnitude that challenges the weight-sharing assumption prevalent in existing theoretical studies. We also leverage the Lyapunov exponent to examine token-level sensitivity, enhancing model interpretability. Our neural ODE transformer demonstrates performance comparable to or better than vanilla transformers across various configurations and datasets, while offering flexible fine-tuning capabilities that can adapt to different architectural constraints.
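For intuition about the Lyapunov-based, token-level sensitivity analysis mentioned above, the sketch below estimates how a small perturbation of a token's hidden state grows as it flows through an ODE-style stack of blocks. This is a minimal, self-contained JAX illustration rather than the repository's analysis code: `block`, the Euler integrator, and all shapes are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp

def block(h, t):
    # Hypothetical stand-in for a time-dependent transformer block f(h, t);
    # a fixed nonlinear map keeps the sketch self-contained and runnable.
    W = jnp.eye(h.shape[-1]) * (1.0 + 0.1 * t)
    return jnp.tanh(h @ W)

def euler_flow(h0, n_steps=12, dt=1.0 / 12):
    # Integrate dh/dt = f(h, t) with explicit Euler over the continuous layer index.
    h = h0
    for i in range(n_steps):
        h = h + dt * block(h, i * dt)
    return h

def lyapunov_estimate(h0, eps=1e-4):
    # Crude finite-difference proxy for the leading Lyapunov exponent:
    # log growth of a small random perturbation over the depth of the flow.
    delta = eps * jax.random.normal(jax.random.PRNGKey(0), h0.shape)
    d0 = jnp.linalg.norm(delta)
    d1 = jnp.linalg.norm(euler_flow(h0 + delta) - euler_flow(h0))
    return jnp.log(d1 / d0)  # larger values indicate higher sensitivity

tokens = jax.random.normal(jax.random.PRNGKey(1), (8, 64))  # (sequence, hidden)
print(jax.vmap(lyapunov_estimate)(tokens))                  # one score per token
```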
Our model formulates transformers as neural ODEs with highly flexible non-autonomous vector fields. Instead of sharing weights across layers, we parameterize all weights through neural networks that express them as functions of a continuous layer index (time). The model includes:
- Time-dependent weights for attention components (Q, K, V)
- Time-dependent weights for feed-forward networks
- Representation of weights using a time-dependent unit that embeds the layer index in the Fourier domain (a minimal sketch of this idea follows the list)
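As a rough illustration of this parameterization, the following Equinox/JAX sketch generates the weights of a single linear projection from a Fourier embedding of the layer index t. It is a minimal sketch under assumed names and sizes (`FourierTimeEmbedding`, `TimeDependentLinear`, the hypernetwork shape), not the repository's actual modules.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

class FourierTimeEmbedding(eqx.Module):
    freqs: jnp.ndarray

    def __init__(self, dim: int):
        # Log-spaced frequencies; the paper's exact embedding may differ.
        self.freqs = jnp.exp(jnp.linspace(0.0, 4.0, dim // 2))

    def __call__(self, t):
        angles = t * self.freqs
        return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)])

class TimeDependentLinear(eqx.Module):
    """Applies W(t) @ x, with W(t) produced by a small hypernetwork."""
    embed: FourierTimeEmbedding
    hyper: eqx.nn.Linear
    d_in: int = eqx.field(static=True)
    d_out: int = eqx.field(static=True)

    def __init__(self, d_in, d_out, t_dim=32, *, key):
        self.embed = FourierTimeEmbedding(t_dim)
        # Hypernetwork: Fourier time embedding -> flattened weight matrix.
        self.hyper = eqx.nn.Linear(t_dim, d_in * d_out, key=key)
        self.d_in = d_in
        self.d_out = d_out

    def __call__(self, x, t):
        W = self.hyper(self.embed(t)).reshape(self.d_out, self.d_in)
        return W @ x

# Usage: one module yields different weights at different layer indices t.
proj_q = TimeDependentLinear(64, 64, key=jax.random.PRNGKey(0))  # e.g. a Q projection
x = jax.random.normal(jax.random.PRNGKey(1), (64,))
print(proj_q(x, t=0.25).shape)  # (64,)
```

In the full model, analogous time-dependent units would supply the Q/K/V and feed-forward weights, and the block is integrated along t as an ODE; see the paper and source code for the actual design.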
Key results:
- Comparable or better performance than vanilla transformers across various configurations
- Significant improvements in downstream tasks, particularly in reading comprehension
- Flexible fine-tuning capabilities that can adapt to different architectural constraints
The implementation is built on JAX and uses Equinox, Haliax, and the Levanter training framework.
If you find this work useful, please consider citing:
@inproceedings{tong2025neural,
title={Neural {ODE} Transformers: Analyzing Internal Dynamics and Adaptive Fine-tuning},
author={Anh Tong and Thanh Nguyen-Tang and Dongeun Lee and Duc Nguyen and Toan Tran and David Leo Wright Hall and Cheongwoong Kang and Jaesik Choi},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=XnDyddPcBT}
}