💥 v0.1.0
What's Changed
- Update README.md by @eltociear in #2
- fix simple gla backward by @sunyt32 in #6
- Adding RWKV-v4. by @ridgerchu in #8
- fixed hgrn.py paper link and title by @ridgerchu in #10
- Update recurrent_naive.py by @hypnopump in #12
- fix: calculate du on different batch by @uniartisan in #35
- fix: enhance state gradient when bf16 by @uniartisan in #37
- Add implementations of Mamba 2 into FLA by @DanFosing in #39
- Minor mamba-2 fixes by @DanFosing in #40
- [DeltaNet] Adds beta as a vector option by @hypnopump in #42
- [DRAFT] Beta gradient does not match by @hypnopump in #43
- [Attn] fix negative value of seqlen offset during sft by @ChaosCodes in #45
- [RWKV6] fix backward if h0 not passed by @hypnopump in #48
- Replace mamba2 `mamba_chunk_scan_combined` triton kernel by `simple_gla` triton kernel by @learning-chip in #49
- benchmark script for simple_gla vs mamba2 kernel by @learning-chip in #50
- Update amp custom_fwd, custom_bwd usage for torch 2.4.0 compatibility by @mirceamironenco in #54
- Fix syntax error by @JulienSiems in #55
- Add `__init__.py` in `fla/ops/common` for automatic package discovery by @zhixuan-lin in #56
- [Mamba2] Post Merge Fixes - `norm_before_gate` and generation with `inputs_embeds` by @vasqu in #57
- Correctly compute `max_seqlen` when `max_position_embeddings` is `None` by @zhixuan-lin in #59
- add chunked kl div by @ChaosCodes in #62
- Add fine-grained warning category for easier suppression by @mirceamironenco in #65
- Update fused_chunk.py by @hypnopump in #72
- [Mamba2] Fix slow path by @vasqu in #84
- Add BitNet by @DustinWang1 in #85
- Fix RWKV6 Cache Problems by @WorldEditors in #78
- Bugs in RWKV6 OP by @WorldEditors in #87
- fix mamba2 cache bug by @WorldEditors in #89
- fix dh0 is None breaking backward pass by @Sxela in #102
- support varlen training for conv1d by @LKJacky in #116
- blood for the torch.compile gods by @harrisonvanderbyl in #119
- Added forward pass for chunkwise ttt-linear, varlen is supported. by @Pan-Yuqi in #124
- Add scripts for converting pretrained RWKV7 models to fla format by @Triang-jyed-driung in #128
- [Mamba2] Fixes for caching and multiple other small issues by @vasqu in #129
- [LinAttn] Fix handling of None scale in chunk_linear_attn for output normalization by @HallerPatrick in #130
- Fix incorrect kwarg name in `fused_recurrent` by @fffffgggg54 in #134
- RWKV-7 conversion and evals by @Triang-jyed-driung in #135
- Fixed dtype mismatch of mamba & mamba2 under residual_in_fp32 setting by @chengshuang18 in #137
- [RWKV7] Fix masking before time shifting modules by @Triang-jyed-driung in #141
- [RWKV7, but applicable to all models] Update modeling_rwkv7.py: Fixing `base_model_prefix` by @Triang-jyed-driung in #143
- fix bitattn with latest attn implementation. by @ridgerchu in #146
- [RWKV7] Remove in-place operations and add gradient checkpointing for `v_first` by @Triang-jyed-driung in #145
- [BitNet] Fix bugs of model definitions by @ridgerchu in #147
- Fix #157 by @jannalulu in #167
- [Mamba, Samba] Add weight initializations and reset_parameters() in _init_weights() for compatibility in Flame by @zaydzuhri in #169
- fix lint errors by @jannalulu in #170
- [RWKV] Follow-up to fix cache management by @jannalulu in #168
- one liner by @seanxwzhang in #178
- [Attn] Fix cache update of swa by @Pan-Yuqi in #183
- [RWKV7] Fix conversion precision by @Triang-jyed-driung in #188
- [GRPO]: add grpo functions by @uniartisan in #189
- [RWKV] fix logits handling by @jannalulu in #192
- [Modules]: Enhance the precision of the fused LayerNorm OP. by @uniartisan in #200
- [MISC] fix delta_net logit handling by @jannalulu in #205
- [RWKV7] Keep compatibility with Torch Compiler. by @uniartisan in #208
- [Misc.] Update wrapper to support contiguous and guard custom device … by @uniartisan in #212
- [Models] Fix the error in the judgment of past_key_values when inputs… by @uniartisan in #213
- [Titans] Update Titans implementation by @rucnyz in #214
- [Mamba2] Fix initialization by @HanGuo97 in #225
- [TTT] Update fused chunk ops and state bias term by @Pan-Yuqi in #230
- Enable utils.py to be imported on CPU-only machines (#231) by @zhuzeyuan in #232
- [Utils] use fla.utils.device instead of cuda by @uniartisan in #163
- fix(GatedDeltaNet): Ensure integer dimensions when using `expand_v` by @vladislavalerievich in #234
New Contributors
- @eltociear made their first contribution in #2
- @sunyt32 made their first contribution in #6
- @ridgerchu made their first contribution in #8
- @hypnopump made their first contribution in #12
- @DanFosing made their first contribution in #39
- @ChaosCodes made their first contribution in #45
- @learning-chip made their first contribution in #49
- @mirceamironenco made their first contribution in #54
- @JulienSiems made their first contribution in #55
- @zhixuan-lin made their first contribution in #56
- @vasqu made their first contribution in #57
- @DustinWang1 made their first contribution in #85
- @WorldEditors made their first contribution in #78
- @Sxela made their first contribution in #102
- @LKJacky made their first contribution in #116
- @harrisonvanderbyl made their first contribution in #119
- @Triang-jyed-driung made their first contribution in #128
- @HallerPatrick made their first contribution in #130
- @fffffgggg54 made their first contribution in #134
- @chengshuang18 made their first contribution in #137
- @jannalulu made their first contribution in #167
- @zaydzuhri made their first contribution in #169
- @seanxwzhang made their first contribution in #178
- @zhuzeyuan made their first contribution in #232
- @vladislavalerievich made their first contribution in #234
Full Changelog: https://github.com/fla-org/flash-linear-attention/commits/v0.1.0