Sangmin Bae1* Yujin Kim1* Reza Bayat2* Sungnyun Kim1 Jiyoun Ha3 Tal Schuster4 Adam Fisch4 Hrayr Harutyunyan5 Ziwei Ji4 Aaron Courville2,6† Se-Young Yun1†.
1KAIST AI 2Mila 3Google Cloud 4Google DeepMind 5Google Research 6Université de Montréal
*Equal contribution. †Corresponding authors.
Early-exiting [CALM], an Adaptive Computation technique, helps LLMs run more efficiently by allowing them to skip unnecessary computations when they're confident in their predictions. This creates dynamic pathways through the model. However, putting early-exiting into practice faces two main bottlenecks:
- **Missing Key-Value (KV) cache problem:** When tokens exit early, they skip computing the KV pairs for the remaining deeper layers. But these missing values are essential for decoding future tokens, and trying to approximate them often hurts performance.
- **Inefficient batched inference:** Tokens that exit early end up waiting for others in the same batch to finish their full computation. This "idling" prevents efficient batching and wastes processing time.
Previous work tackled these challenges individually:
- **FREE Framework** addressed the missing KV cache problem using parallel decoding (an early form of self-speculative decoding), which efficiently computes the exact KV pairs for early-exited tokens. However, compatibility with batched inference remained poor.
- **Recursive Transformers** aimed to mitigate inefficient batched inference through parameter sharing, enabling tokens at different depths to be processed together. However, the two separate training stages needed to integrate parameter sharing with early-exiting degraded performance, and the model still had to handle the missing KV cache.
Our new research [MoR] introduces a unified framework that directly tackles both the missing KV cache and batched-inference issues. We achieve this with a ✨ routing mechanism, trained effectively end-to-end, that dynamically assigns the optimal recursion depth to each token. We further introduce a ✨ recursion-wise KV caching strategy that selectively stores KV pairs, resolving the missing-cache problem while reducing memory usage. The result is up to 2× higher inference throughput than standard Transformers at similar accuracy, along with lower total training FLOPs and memory requirements.
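As a rough, purely schematic illustration of these two ideas (our own sketch with hypothetical names, not code from the MoR release): each token receives a recursion depth from the router, and KV pairs are cached per recursion step only for the tokens still active at that step.

```python
# Schematic sketch only: a router assigns each token a recursion depth, and KV
# entries are cached per recursion step only for the tokens still active there.
import torch

seq_len, num_recursions = 8, 3
# Hypothetical router decision: a recursion depth (1..num_recursions) per token.
assigned_depth = torch.randint(1, num_recursions + 1, (seq_len,))

kv_cache = {}  # recursion step -> indices of tokens whose KV pairs are stored
for r in range(1, num_recursions + 1):
    active = assigned_depth >= r                  # tokens that recurse at least r times
    token_ids = active.nonzero(as_tuple=True)[0]
    # In the real model the shared block runs on these tokens and their KV pairs
    # are cached here; deeper steps store fewer entries, which saves memory.
    kv_cache[r] = token_ids

print({step: ids.tolist() for step, ids in kv_cache.items()})
```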
To get started, follow these steps to set up your development environment. We recommend using `conda` for dependency management.
1. **Create and Activate Conda Environment:**

   ```bash
   conda create -n mor python=3.12
   conda activate mor
   ```

2. **Install Required Packages:** First, ensure your `pip` and `setuptools` are up to date. Then, install `torch` and the dependencies listed in `requirements.txt`.

   **Note:** We specifically used `torch==2.6.0+cu124`, `flash_attn==2.7.4.post1`, and `transformers==4.52.4`. If you encounter issues, consider these exact versions.

   ```bash
   pip install --upgrade pip
   pip install --upgrade setuptools
   pip install torch
   pip install -r requirements.txt

   # If you experience issues with flash-attn, try:
   # pip install flash-attn --no-build-isolation
   ```
Our models are pretrained on a deduplicated subset of the FineWeb-Edu dataset, available as part of the SmolLM-Corpus.
Follow these steps to download and prepare the dataset:
1. **Create Data Directories:** Create the necessary directories (`hf_cache`, `hf_datasets`, `hf_models`, and `results`) under your designated data path. Replace `{your_data_path}` with your actual path.

   ```bash
   mkdir -p {your_data_path}/mixture_of_recursions/hf_cache
   mkdir -p {your_data_path}/mixture_of_recursions/hf_datasets
   mkdir -p {your_data_path}/mixture_of_recursions/hf_models
   mkdir -p {your_data_path}/mixture_of_recursions/results
   ```

2. **Create Symbolic Links:** Establish symbolic links from your data path to your project's current path. Replace `{your_data_path}` and `{your_project_path}` accordingly.

   ```bash
   ln -s {your_data_path}/mixture_of_recursions/* {your_project_path}/mixture_of_recursions/
   ```

3. **Download Pretraining Corpus:** Execute the provided script to download the `fineweb-edu-dedup` dataset.

   ```bash
   bash lm_dataset/download_scripts/download_fineweb-edu-dedup.sh
   ```

4. **Move Cached Dataset:** Move the cached dataset to the `DATA_DIR` specified in `lm_dataset/load_dataset.py` so that it sits in the correct directory for loading (see the hedged example after this list). You can perform steps 3 and 4 together by downloading the dataset with the script at `download_langauge_modeling_datasets.sh`.
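As an illustration of step 4, a hedged example follows; the cache layout and the `DATA_DIR` value depend on your setup (check `lm_dataset/load_dataset.py`), so both paths below are assumptions.

```bash
# Hypothetical paths: adjust the source to wherever the download script cached the
# dataset, and DATA_DIR to the value defined in lm_dataset/load_dataset.py.
DATA_DIR={your_data_path}/mixture_of_recursions/hf_datasets
mv {your_data_path}/mixture_of_recursions/hf_cache/fineweb-edu-dedup* "$DATA_DIR"/
```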
For details on the dataset used, please refer to the `lm_dataset` directory. We used a custom language modeling dataset class for this project. There are opportunities for optimization:

- Improving the loading speed of the data `state_dict` when resuming interrupted training.
- Our current input packing allows attention across different documents, which could be mitigated by integrating future advancements like FlexAttention (see the sketch below).
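As a rough illustration of that second point, here is a minimal sketch (not part of the repository) of how cross-document attention could be blocked with PyTorch's FlexAttention (available in PyTorch >= 2.5); the `document_ids` tensor marking which packed document each position belongs to is a hypothetical stand-in for what the data loader would need to provide.

```python
# Minimal sketch: forbid cross-document attention in a packed sequence with FlexAttention.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

seq_len, n_heads, head_dim = 2048, 8, 64
# Hypothetical per-position document ids for a packed sequence.
document_ids = torch.randint(0, 4, (seq_len,), device="cuda")

def document_causal_mask(b, h, q_idx, kv_idx):
    # Attend only to past positions within the same packed document.
    causal = q_idx >= kv_idx
    same_doc = document_ids[q_idx] == document_ids[kv_idx]
    return causal & same_doc

block_mask = create_block_mask(document_causal_mask, B=None, H=None,
                               Q_LEN=seq_len, KV_LEN=seq_len)

q = torch.randn(1, n_heads, seq_len, head_dim, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = flex_attention(q, k, v, block_mask=block_mask)
```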
Our models are built upon the Llama architecture, specifically by modifying the `LlamaForCausalLM` class. The expert-choice and token-choice versions of our MoR architecture can be found in `expert_choice_router.py` and `token_choice_router.py`, respectively.
The high-level routing process follows these steps:
1. **Obtain Indices:** Get indices from the top-k selections of the routers.
2. **Index Tokens:** Index the current tokens using these obtained indices. For expert-choice, we can avoid variable-length complexities since all samples have the same length. For token-choice, however, we handle variable lengths by adding padding with `rnn_utils.pad_sequence`.
3. **Compute Shared Blocks:** Perform computations on the shared blocks once.
4. **Scatter and Combine:** `scatter_add` the computed tokens back to their original shape using the previously obtained indices.
5. **Repeat:** Repeat this process for the desired number of recursions.

**Note:** We believe further optimizations are possible through the use of FlexAttention and `grouped_mm`.
Moreover, in the case of KV sharing, we computed all sequences against a shared KV cache for simplicity; the outputs for tokens participating in a specific recursion were then extracted by masking.
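The following is a simplified sketch of the five steps above for the expert-choice case. It is illustrative only, not the code in `expert_choice_router.py`: the `router` and `shared_block` callables, the capacity, and all shapes are hypothetical stand-ins.

```python
# Simplified sketch of the expert-choice routing loop (not the repository code).
import torch

def mor_expert_choice_sketch(hidden, router, shared_block, num_recursions, capacity):
    # hidden: [batch, seq_len, dim]
    batch, seq_len, dim = hidden.shape
    for _ in range(num_recursions):
        # 1. Obtain indices: top-k tokens per sample according to the router scores.
        scores = router(hidden).squeeze(-1)                     # [batch, seq_len]
        weights, indices = scores.topk(capacity, dim=-1)        # [batch, capacity]

        # 2. Index tokens: gather the selected tokens (every sample selects the same
        #    number of tokens, so no padding is needed in the expert-choice case).
        gather_idx = indices.unsqueeze(-1).expand(-1, -1, dim)  # [batch, capacity, dim]
        selected = hidden.gather(1, gather_idx)

        # 3. Compute shared blocks: run the shared block once on the selected tokens,
        #    scaling by the (sigmoid) router weights to keep routing differentiable.
        updated = shared_block(selected) * weights.sigmoid().unsqueeze(-1)

        # 4. Scatter and combine: add the computed tokens back at their original positions.
        hidden = hidden.scatter_add(1, gather_idx, updated)

    # 5. Repeat: the loop above runs for the desired number of recursions.
    return hidden

if __name__ == "__main__":
    dim = 64
    router = torch.nn.Linear(dim, 1)
    shared_block = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.GELU())
    x = torch.randn(2, 16, dim)
    out = mor_expert_choice_sketch(x, router, shared_block, num_recursions=3, capacity=8)
    print(out.shape)  # torch.Size([2, 16, 64])
```

The token-choice variant would differ mainly at step 2, where the number of selected tokens varies per sample and `rnn_utils.pad_sequence` is used to handle the variable lengths.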
To streamline the creation of training and evaluation configurations, we provide an automated Python utility: `generate_pretrain_eval_fewshot_configs.py`. Given your custom config name and arguments, the script automatically generates a YAML file for your configuration, based on the `example.yaml` files located under `conf/pretrain` or `conf/eval_fewshot`.
To generate a configuration, simply run:
```bash
python util/generate_pretrain_eval_fewshot_configs.py --name {config_name}

# Example:
python util/generate_pretrain_eval_fewshot_configs.py --name 250720_pretrain_smollm-360m_rec3_middle_cycle_random_lr3e-3_mor_expert_linear_alpha_0.1_sigmoid_aux_loss_0.001
```
We typically conducted training on 4 H100 or A100 GPUs. For distributed training, we leveraged either Accelerate or DeepSpeed ZeRO Stage 2. Exploring FSDP, Tensor Parallelism, and Pipeline Parallelism for MoR models is left for future work.
Here are the training commands for both versions:
```bash
# DeepSpeed
HYDRA_FULL_ERROR=1 deepspeed --include localhost:0,1,2,3 --no_local_rank --master_port 25720 pretrain.py --config-name example

# Accelerate
HYDRA_FULL_ERROR=1 CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file acc_configs/default_config.yaml --main_process_port 25720 pretrain.py --config-name example
```
Few-shot accuracy was measured using lm-evaluation-harness.
Here are the evaluation commands:
```bash
# DeepSpeed
HYDRA_FULL_ERROR=1 deepspeed --include localhost:0,1,2,3 --no_local_rank --master_port 23393 eval_fewshot.py --config-name example

# Accelerate
HYDRA_FULL_ERROR=1 CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file acc_configs/default_config.yaml --main_process_port 23393 eval_fewshot.py --config-name example
```
You can also evaluate the validation loss using `evaluate_fineweb_test.py`; we specifically measure validation loss for our scaling law analysis.
To run training and few-shot evaluation concurrently with a single command, execute the following shell script:
```bash
bash scripts/pretrain_eval_fewshot.sh {launcher} {wandb_mode} {gpu_indices} {exp1_config} {exp2_config} ...

# Example:
bash scripts/pretrain_eval_fewshot.sh deepspeed online 0,1,2,3 250720_pretrain_smollm-360m_rec3_middle_cycle_random_lr3e-3_mor_expert_linear_alpha_0.1_sigmoid_aux_loss_0.001
```
We share pretrained checkpoints for our 360M-parameter Vanilla, Recursive, and MoR models on Google Drive. Move the checkpoints under the `./checkpoints` folder.
Alternatively, you can use the following commands to download them, but please be aware of a potential bug:
```bash
pip install gdown
mkdir -p checkpoints
gdown --folder 'https://drive.google.com/drive/folders/1pYKJOu2aBGC-jgoWbfP6T_vqEYtUVxa4?usp=sharing' -O checkpoints
```
Additionally, you can explore the routing behavior of the expert-choice MoR model in the `notebooks/250727_get_mor_routing_decision.ipynb` notebook.
If you find our work useful, please cite it as:
```bibtex
@misc{bae2025mixtureofrecursionslearningdynamicrecursive,
    title={Mixture-of-Recursions: Learning Dynamic Recursive Depths for Adaptive Token-Level Computation},
    author={Sangmin Bae and Yujin Kim and Reza Bayat and Sungnyun Kim and Jiyoun Ha and Tal Schuster and Adam Fisch and Hrayr Harutyunyan and Ziwei Ji and Aaron Courville and Se-Young Yun},
    year={2025},
    eprint={2507.10524},
    archivePrefix={arXiv},
    primaryClass={cs.CL},
    url={https://arxiv.org/abs/2507.10524},
}
```