Official PyTorch implementation for the paper: "Navigating Extremes: Dynamic Sparsity in Large Output Spaces" accepted at NeurIPS 2024.
👉 You can find the camera-ready paper here.
```bash
conda create -y -n spartex python=3.10
conda activate spartex
bash setup.sh
pip install torch_sparse-0.0.6.dev295+cu124.pt24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
```
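After installation, a quick sanity check can confirm that the CUDA build of PyTorch is active. This is a minimal sketch; the `torch_sparse` import name is an assumption based on the wheel filename above and may differ.

```python
# Minimal post-install sanity check (illustrative, not part of the repo).
import torch

print("torch:", torch.__version__)           # expect a 2.4.x / cu124 build
print("cuda available:", torch.cuda.is_available())
print("cuda version:", torch.version.cuda)   # expect "12.4"

try:
    import torch_sparse  # import name assumed from the wheel filename above
    print("torch_sparse:", getattr(torch_sparse, "__version__", "unknown"))
except ImportError as err:
    print("custom sparse kernels not importable:", err)
```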
Download the datasets from The Extreme Classification Repository [1], or follow the links below.
1. Set up the environment following the installation instructions above.
2. Settings and hyperparameters are managed by Hydra. The complete configuration layout for LF-AmazonTitles-131K is shown below; for more details, check the `config` folder.
3. (Optional, to avoid typing the data path on every run) Add a `running-env: <root path of the data folders>` entry to the `env2path` dictionary in `main.py`, and add that running-env to `environment` in the YAML file; a sketch of such an entry is shown after this list.
4. If step 3 was followed, run
   ```bash
   python src/main.py dataset=dataset log_fname=log_dataset
   ```
   otherwise, run
   ```bash
   python src/main.py dataset=dataset dataset_path=./data log_fname=log_dataset
   ```
   where `dataset_path` is the root data path and the valid `dataset` argument names are `wiki31k`, `amazon670k`, `wiki500k`, `amazon3m`, `lfamazontitles131k`, and `lfwikiseealso320k`.
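The sketch below illustrates the optional step 3. The dictionary shape and the example entries are assumptions based on the wording above; the actual `env2path` in `src/main.py` is authoritative.

```python
# Hypothetical sketch of env2path in src/main.py: map each running_env value
# from the config to the root path that contains the dataset folders.
env2path = {
    "guest": "./data",                         # default entry assumed for running_env: guest
    "my_cluster": "/scratch/me/xmc_datasets",  # add your own running-env here
}
```

After adding an entry, also add the same running-env name under `environment.running_env` in the YAML config so the path is picked up automatically.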
The pre-trained initialization for LF-AmazonTitles-131K can be found here.
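A hedged sketch of how such a pre-trained encoder initialization can be plugged in via the `use_ngame_encoder_weights` / `ngame_checkpoint` keys of the configuration below. The checkpoint path pattern follows `model.encoder.ngame_checkpoint`; the state-dict key layout is an assumption and may require remapping.

```python
# Illustrative only: load a pre-trained (NGAME) encoder state dict into the
# DistilBERT backbone named by model.encoder.encoder_model in the config below.
import torch
from transformers import AutoModel

encoder = AutoModel.from_pretrained("sentence-transformers/msmarco-distilbert-base-v4")

# Path follows model.encoder.ngame_checkpoint; the key layout is an assumption.
state_dict = torch.load("./NGAME_ENCODERS/lfamazontitles131k/state_dict.pt",
                        map_location="cpu")
missing, unexpected = encoder.load_state_dict(state_dict, strict=False)
print(f"loaded with {len(missing)} missing and {len(unexpected)} unexpected keys")
```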
LF-AmazonTitles-131K
```bash
python src/main.py dataset=lfamazontitles131k log_fname="log_AT131K"
```
Amazon670K
```bash
python src/main.py dataset=amazon670k log_fname="log_A670K"
```
If you find our work/code useful in your research, please cite the following:
```bibtex
@inproceedings{ullahnavigating,
  title={Navigating Extremes: Dynamic Sparsity in Large Output Spaces},
  author={Ullah, Nasib and Schultheis, Erik and Lasby, Mike and Ioannou, Yani and Babbar, Rohit},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024}
}
```
[1] The Extreme Classification Repository
[2] Dynamic Sparse Training with Structured Sparsity
[3] pyxclib
```yaml
environment:
  running_env: guest
  cuda_device_id: 0
  device: cuda
data:
  dataset: lfamazontitles131k
  is_lf_data: True
  augment_label_data: True
  use_filter_eval: True
  num_labels: 131073
  max_len: 32
  num_workers: 8
  batch_size: 512
  test_batch_size: 512
model:
  encoder:
    encoder_model: "sentence-transformers/msmarco-distilbert-base-v4"
    encoder_tokenizer: ${dataset.model.encoder.encoder_model}
    encoder_ftr_dim: 768
    pool_mode: 'last_hidden_avg'
    feature_layers: 1
    embed_dropout: 0.85
    use_torch_compile: False
    use_ngame_encoder_weights: False
    ngame_checkpoint: ./NGAME_ENCODERS/${dataset.data.dataset}/state_dict.pt
  penultimate:
    use_penultimate_layer: False
    penultimate_size: 4096
    penultimate_activation: relu
  ffi:
    use_sparse_layer: False
    fan_in: 128
    prune_mode: fraction
    rewire_threshold: 0.01
    rewire_fraction: 0.15
    growth_mode: random
    growth_init_mode: zero
    input_features: 768
    output_features: 131073
    rewire_interval: 1000
    use_rewire_scheduling: True
    rewire_end_epoch: 0.66  # depends on epochs
  auxiliary:
    use_meta_branch: False
    group_y_group: 0
    meta_cutoff_epoch: 20  # varies based on fan_in values
    auxloss_scaling: 0.4
training:
  seed: 42
  amp:
    enabled: False
    dtype: float16
  optimization:
    loss_fn: bce  # ['bce', 'squared_hinge']
    encoder_optimizer: adamw
    xmc_optimizer: sgd
    epochs: 100  # depends on dataset
    dense_epochs: 100
    grad_accum_step: 1
    encoder_lr: 1.0e-5
    penultimate_lr: 2.0e-4
    meta_lr: 5.0e-4
    lr: 0.05  # learning rate of the final layer
    wd_encoder: 0.01  # weight decay on the encoder
    wd: 1e-4  # weight decay of the final layer
    lr_scheduler: CosineScheduleWithWarmup
    lr_scheduler_xmc: CosineScheduleWithWarmup
    warmup_steps: 5000
    training_steps: 1  # selected at runtime based on batch size and dataloader
evaluation:
  train_evaluate: True
  train_evaluate_every: 10
  test_evaluate_every: 1
  A: 0.6  # for propensity calculation
  B: 2.6  # for propensity calculation
  eval_psp: True
verbose:
  show_iter: False  # print loss during training
  print_iter: 2000  # how often (iterations) to print
  use_wandb: False
  wandb_runname: none
logging: True
log_fname: log_amazontitles131k
use_checkpoint: False  # whether to use automatic checkpointing
checkpoint_file: PBCE_NoLF_NM1
best_p1: 0.2  # in case of automatic checkpointing
```
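For reference, the sketch below shows a minimal Hydra entry point consistent with the command-line overrides used above (e.g. `dataset=... log_fname=...`). The `config_path` and `config_name` values are assumptions; `src/main.py` in the repository is the actual implementation.

```python
# Minimal Hydra sketch (illustrative): compose the YAML config groups, apply
# command-line overrides, and hand the resulting config to the training code.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="../config", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # The composed config contains the dataset layout shown above plus any
    # overrides passed on the command line.
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()
```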