
Conversation

@djsaunde
Collaborator

@djsaunde djsaunde commented Oct 29, 2025

This PR adds sample packing support. It uses TRL's SFTConfig packing=True and padding_free=True args to pack sequences, and we compute packed_seq_lengths metadata and thread it through the model forward pass. This metadata is used to build block-causal masks for SDPA and xformers attention, and is passed to the flash attention varlen API, which handles the block-causal masking itself under the hood. We have to handle this ourselves because of our custom forward pass, whereas TRL manages the sequence-length metadata internally in its trainer.
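
For reviewers who haven't touched the varlen path before, here's a minimal sketch (my own illustration, not the PR's actual code; helper names are invented) of how per-row packed sequence lengths translate into (a) a block-causal additive mask for SDPA and (b) the cu_seqlens layout that flash_attn_varlen_func consumes:

```python
import torch

def block_causal_mask(packed_seq_lengths, dtype=torch.float32):
    """Additive mask: token i may attend to token j only if j <= i and
    both positions fall inside the same packed sub-sequence."""
    total = int(sum(packed_seq_lengths))
    # segment id of every position in the packed row
    seg = torch.repeat_interleave(
        torch.arange(len(packed_seq_lengths)),
        torch.tensor(packed_seq_lengths),
    )
    pos = torch.arange(total)
    allowed = (seg[None, :] == seg[:, None]) & (pos[None, :] <= pos[:, None])
    mask = torch.zeros(total, total, dtype=dtype)
    mask.masked_fill_(~allowed, torch.finfo(dtype).min)
    return mask  # [total, total]; pass as attn_mask / add to attention scores

def to_cu_seqlens(packed_seq_lengths):
    """Cumulative lengths in the layout flash_attn_varlen_func expects."""
    lens = torch.tensor([0] + list(packed_seq_lengths))
    return torch.cumsum(lens, dim=0).to(torch.int32)

# e.g. three samples of lengths 3, 2, 2 packed into one row of 7 tokens
print(to_cu_seqlens([3, 2, 2]))          # tensor([0, 3, 5, 7], dtype=torch.int32)
print(block_causal_mask([3, 2, 2]).shape)  # torch.Size([7, 7])
```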

I added a few unit tests. I also wrote a quick bash script for smoke testing some common model architectures (see the linked gist).

Below is a comparison of short unsloth/qwen2.5-0.5b training runs. The losses don't match exactly because each step sees more (and different) samples with packing enabled, but the scale and trend match, which is the important part.

[image: training loss curves, packed vs. unpacked]

Commands:

No sample packing:

```bash
python unsloth-cli.py --model_name unsloth/qwen2.5-0.5b --dataset yahma/alpaca-cleaned --per_device_train_batch_size 8 --max_steps 50 --max_seq_length 2048
```

Sample packing:

```bash
python unsloth-cli.py --model_name unsloth/qwen2.5-0.5b --dataset yahma/alpaca-cleaned --per_device_train_batch_size 1 --max_steps 50 --max_seq_length 2048 --sample_packing
```

Note that we use --per_device_train_batch_size 1 in the latter case since we are packing multiple examples into a single [1, max_seq_length] tensor.
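
To make the shape concrete, here's a toy illustration (made-up token IDs and a simplified collator, not the real one) of three samples packed into a single [1, max_seq_length] row with the per-sample lengths kept as metadata:

```python
import torch

samples = [
    [101, 7, 8, 9, 102],         # sample 1, length 5
    [101, 4, 5, 102],            # sample 2, length 4
    [101, 11, 12, 13, 14, 102],  # sample 3, length 6
]
max_seq_length = 16
pad_id = 0

flat = [tok for s in samples for tok in s]
packed_seq_lengths = [len(s) for s in samples]   # [5, 4, 6] -> threaded to attention
input_ids = torch.tensor([flat + [pad_id] * (max_seq_length - len(flat))])

print(input_ids.shape)       # torch.Size([1, 16]) -> why per_device_train_batch_size is 1
print(packed_seq_lengths)    # [5, 4, 6]
```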

The benefit of this approach is that we discard a lot of zero padding and therefore get higher tokens/s training throughput. The plot below shows that we get through the dataset ~20% faster (a rough back-of-the-envelope sketch of this effect follows the plot). These gains depend on the dataset and the configured --max_seq_length; increasing it generally yields better packing efficiency and therefore higher throughput.

[image: tokens/s throughput comparison, packed vs. unpacked]
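
And the back-of-the-envelope sketch mentioned above (purely illustrative lengths; it ignores dynamic padding to the longest sample in a batch, so it overstates the absolute gain, but it shows why a larger --max_seq_length packs more efficiently):

```python
def padded_tokens(lengths, max_seq_length):
    # unpacked worst case: every sample padded out to max_seq_length
    return len(lengths) * max_seq_length

def packed_tokens(lengths, max_seq_length):
    # greedy next-fit over length-sorted samples: start a new row when full
    rows, used = 1, 0
    for n in sorted(lengths, reverse=True):
        if used + n > max_seq_length:
            rows, used = rows + 1, 0
        used += n
    return rows * max_seq_length

lengths = [180, 420, 95, 900, 310, 60, 770, 240]   # made-up token counts
for msl in (1024, 2048, 4096):
    ratio = packed_tokens(lengths, msl) / padded_tokens(lengths, msl)
    print(f"max_seq_length={msl}: packed uses {ratio:.0%} of the padded token budget")
```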

I manually tested SDPA and flash attention, but I still need to test xformers attention since I couldn't get it to build for Blackwell.

TODO

  • test xformers attention

@djsaunde djsaunde self-assigned this Oct 29, 2025
@djsaunde djsaunde changed the title from Packing to sample packing Oct 29, 2025
@djsaunde
Collaborator Author

Follow-up: DRY up the attention code. We re-implement a big if/else block for selecting and running attention in each modeling file. We can factor this out into a separate module and call a helper function (rough sketch below). CC @Datta0
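
Roughly the shape I have in mind (all names hypothetical; only the SDPA branch is fleshed out here):

```python
import torch
import torch.nn.functional as F

def run_attention(q, k, v, *, backend, attention_mask=None, packed_seq_lengths=None):
    """Single dispatch point that every modeling file would call.

    q, k, v: [batch, n_heads, seq_len, head_dim]
    """
    if backend == "flash":
        # flash-attn varlen path: build cu_seqlens from packed_seq_lengths and
        # call flash_attn_varlen_func; omitted to keep the sketch dependency-free.
        raise NotImplementedError
    if backend == "xformers":
        # xformers memory-efficient attention with a block-diagonal causal bias
        raise NotImplementedError
    # default: SDPA, with an explicit block-causal mask when packing is enabled
    return F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=attention_mask,
        is_causal=attention_mask is None,
    )
```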

@djsaunde
Collaborator Author

I added support for passing position IDs to RoPE (needed for correctness, just like for attention), plus a fused-QK Triton kernel for the RoPE embedding (similar to what currently exists for the non-packing case).
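
The correctness point is simply that position IDs must restart at 0 for every packed sub-sequence so RoPE rotates each sample as if it were unpacked. A tiny sketch (illustrative helper, not the Triton kernel itself):

```python
import torch

def packed_position_ids(packed_seq_lengths):
    # restart positions at 0 for every packed sub-sequence
    return torch.cat([torch.arange(n) for n in packed_seq_lengths]).unsqueeze(0)

print(packed_position_ids([3, 2, 4]))
# tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
```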

Benchmarks show the varlen kernel is competitive with the existing non-packing Triton kernel, numerically ~matches the reference, and significantly beats the torch slow path:

RoPE kernel benchmark sweep (microseconds per call):

| seqlen | varlen | dense | old | new | speedup | max abs Δ | mean abs Δ |
|-------:|:------:|------:|----:|----:|--------:|----------:|-----------:|
| 256  | False | 198.501  |           |          |       |           |           |
| 256  | True  |          | 429.066   | 223.670  | 1.918 | 4.768e-07 | 1.136e-08 |
| 512  | False | 413.377  |           |          |       |           |           |
| 512  | True  |          | 1149.956  | 566.851  | 2.029 | 4.768e-07 | 1.170e-08 |
| 1024 | False | 1113.990 |           |          |       |           |           |
| 1024 | True  |          | 2784.808  | 1140.053 | 2.443 | 4.768e-07 | 1.187e-08 |
| 2048 | False | 2341.204 |           |          |       |           |           |
| 2048 | True  |          | 5525.063  | 2372.505 | 2.329 | 4.768e-07 | 1.214e-08 |
| 4096 | False | 4675.885 |           |          |       |           |           |
| 4096 | True  |          | 11354.554 | 4681.061 | 2.426 | 4.768e-07 | 1.239e-08 |
| 8192 | False | 9285.158 |           |          |       |           |           |
| 8192 | True  |          | 21901.080 | 9323.563 | 2.349 | 4.768e-07 | 1.256e-08 |

@shimmyshimmer shimmyshimmer changed the title from sample packing to Uncontaminated packing Oct 30, 2025
@shimmyshimmer shimmyshimmer changed the title from Uncontaminated packing to Uncontaminated Sample Packing Oct 30, 2025
