Official codebase for the ICML 2025 paper ELMO: Efficiency via Low-precision and Peak Memory Optimization in Large Output Spaces.
This repository implements end-to-end float8 (FP8) training for Extreme Multi-label Classification (XMC) classifiers. By leveraging techniques such as gradient fusion and memory-efficient chunking, ELMO achieves up to a 10x reduction in peak memory usage on datasets with up to 8.6 million labels, enabling scalable training in large output spaces.
👉 You can find the camera-ready paper [here](https://openreview.net/forum?id=d6CTIPrTTC).
- ✅ Pure Low Precision (FP8 and BF16) training throughout.
- ✅ Skip-Loss.
- ✅ Split operations and Chunking for Peak memory optimization (see the sketch after this list).
- ✅ SGD optimizer with Stochastic Rounding for XMC layer.
- ✅ Fused Gradient for Peak Memory Optimization.
- ✅ FP8 Encoder with torch.ao.
- ✅ FP8 for XMC layer.
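To make the chunking and fused-gradient ideas concrete, here is a minimal, hypothetical sketch (not the repository's actual implementation; the class name and chunk size are made up): logits for the large label space are materialized one chunk at a time, and the loss gradient is folded directly into the feature and weight gradients, so the full `(batch, num_labels)` logit matrix never has to exist in memory.

```python
import torch
import torch.nn.functional as F

class ChunkedBCEXMC(torch.autograd.Function):
    """Hypothetical sketch of a chunked XMC layer with a fused gradient.
    The backward pass recomputes each chunk of logits and folds its gradient
    straight into the feature/weight gradients, so the full logit matrix is
    never stored for backprop."""

    @staticmethod
    def forward(ctx, feats, weight, targets, chunk_size):
        ctx.save_for_backward(feats, weight, targets)
        ctx.chunk_size = chunk_size
        loss = feats.new_zeros((), dtype=torch.float32)
        for s in range(0, weight.shape[0], chunk_size):
            logits = feats @ weight[s:s + chunk_size].T            # (batch, chunk) only
            y = targets[:, s:s + chunk_size].to(logits.dtype)
            loss += F.binary_cross_entropy_with_logits(logits, y, reduction="sum").float()
        return loss / targets.numel()

    @staticmethod
    def backward(ctx, grad_out):
        feats, weight, targets = ctx.saved_tensors
        chunk_size, scale = ctx.chunk_size, grad_out / targets.numel()
        grad_feats, grad_weight = torch.zeros_like(feats), torch.zeros_like(weight)
        for s in range(0, weight.shape[0], chunk_size):
            w = weight[s:s + chunk_size]
            logits = feats @ w.T                                   # recomputed, then discarded
            g = (torch.sigmoid(logits) - targets[:, s:s + chunk_size].to(logits.dtype)) * scale
            grad_feats += g @ w                                    # fused into d(loss)/d(features)
            grad_weight[s:s + chunk_size] = g.T @ feats            # fused into d(loss)/d(weight)
        return grad_feats, grad_weight, None, None

# Usage (shapes illustrative): loss = ChunkedBCEXMC.apply(features, xmc_weight, labels, 65536)
```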
```bash
conda create -y -n elmo python=3.11
conda activate elmo
git clone https://github.com/xmc-aalto/elmo.git
cd elmo
bash setup_env.sh
```
Download the datasets from the Extreme Classification Repository:
- AmazonTitles-670k
- Amazon-670K
- Wiki-500K
- Amazon-3M
- LF-AmazonTitles-131K
- LF-WikiSeeAlso-320K
- LF-AmazonTitles-1.3M
- LF-Paper2keywords-8.6M
- Set up the environment following the installation instructions above.
- Run the main script with Hydra-style arguments:
```bash
python src/main.py dataset=<dataset_name> dataset_path=<path_to_datasets> log_fname=<log_file_name>
```
- `dataset_path`: the root folder where all datasets are stored.
- `dataset`: the name of the dataset to use. Options include: `amazon670k`, `amazontitles670k`, `wiki500k`, `amazon3m`, `lfamazontitles131k`, `lfamazontitles1.3m`, `lfwikiseealso320k`, `lfpaper2keywords`.
All training is managed through `src/main.py` and configured using Hydra. The base configuration for each dataset is located in `config/dataset/`. Parameters can be set in the config files or overridden on the command line; training performance is reported to wandb and to the log files.
Pure BFloat16 training on AmazonTitles-670K
```bash
python src/main.py dataset=amazontitles670k log_fname=log_bf16_at670k dataset_path=Datasets dataset.model.xmc.implementation=chunked
```
Pure FP8 training on AmazonTitles-670K (requires Hopper, Ada, or Blackwell GPUs)
```bash
python src/main.py dataset=amazontitles670k log_fname=log_fp8_at670k dataset_path=Datasets dataset.model.xmc.implementation=fp8chunked dataset.training.FP8.use_fp8_encoder=True
```
Simulated FP8 training on AmazonTitles-670K (any GPU that supports BFloat16)
This runs FP8 training in which parameters are stored in FP8 but the matrix multiplications are performed in BFloat16, keeping the memory benefits of FP8 training (sketched below).
```bash
python src/main.py dataset=amazontitles670k log_fname=log_sfp8_at670k dataset_path=Datasets dataset.model.xmc.implementation=fp8chunked dataset.training.FP8.use_fp8_encoder=False dataset.training.xmc.simulated_fp8=True
```
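The snippet below illustrates the simulated mode described above. It is a minimal sketch under the assumption of per-tensor E4M3 scaling; the function names are hypothetical and not the repository's API. The weight is kept in FP8 storage and dequantized to BFloat16 only at the point of the matmul.

```python
import torch

def quantize_fp8(weight):
    """Hypothetical helper: per-tensor FP8 (E4M3) quantization of a weight matrix."""
    scale = weight.abs().max().clamp(min=1e-12) / torch.finfo(torch.float8_e4m3fn).max
    return (weight / scale).to(torch.float8_e4m3fn), scale

def simulated_fp8_matmul(feats_bf16, weight_fp8, scale):
    """Weights stay in FP8 (memory savings); the matmul itself runs in BFloat16."""
    w_bf16 = weight_fp8.to(torch.bfloat16) * scale.to(torch.bfloat16)  # dequantize just-in-time
    return feats_bf16 @ w_bf16.T

# Usage (illustrative):
#   w8, s = quantize_fp8(xmc_weight)          # store the classifier weight in FP8
#   logits = simulated_fp8_matmul(h, w8, s)   # compute in BF16 on any BF16-capable GPU
```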
Torchao and NVIDIA RTX 40-series Ada cards (e.g., 4060 Ti)
Torchao does not support the FP8 encoder on some Ada cards, so the encoder can be set to BF16; its memory footprint is negligible compared to the classifier.
```bash
python src/main.py dataset=amazontitles670k log_fname=log_fp8_fp8enc_at670k dataset_path=Datasets dataset.model.xmc.implementation=fp8chunked dataset.training.FP8.use_fp8_encoder=False
```
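For reference, this is roughly how an encoder can be converted to FP8 training with torchao, with a BF16 fallback when FP8 is unavailable. It is a hedged sketch rather than the repository's integration: the toy encoder and the size threshold are assumptions, and the capability check is a simplification of the hardware caveat above.

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# Toy stand-in for the text encoder (the repository uses a transformer encoder).
encoder = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768)).cuda()

def use_fp8(module, fqn):
    # Convert only Linear layers large enough to benefit from FP8 matmuls.
    return isinstance(module, nn.Linear) and min(module.in_features, module.out_features) >= 256

if torch.cuda.get_device_capability() >= (8, 9):   # Ada/Hopper/Blackwell expose FP8 tensor cores
    convert_to_float8_training(encoder, module_filter_fn=use_fp8)
else:
    encoder = encoder.to(torch.bfloat16)            # BF16 fallback, i.e. use_fp8_encoder=False
```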
- Navigating Extremes: Dynamic Sparsity in Large Output Spaces (NeurIPS 2024)
- Towards Memory-Efficient Training for Extremely Large Output Spaces – Learning with 500k Labels on a Single Commodity GPU (ECML 2023)
If you find our work, code, or the LF-Paper2keywords-8.6M
dataset useful in your research, please cite the following:
```bibtex
@inproceedings{
  zhang2025elmo,
  title={{ELMO}: Efficiency via Low-precision and Peak Memory Optimization in Large Output Spaces},
  author={Jinbin Zhang and Nasib Ullah and Erik Schultheis and Rohit Babbar},
  booktitle={Forty-second International Conference on Machine Learning},
  year={2025},
  url={https://openreview.net/forum?id=d6CTIPrTTC}
}
```