Releases: fla-org/flash-linear-attention
v0.4.0
🧠 New Models
- KDA (Kimi Delta Attention) by @yzhangcs
- DeltaFormer by @Nathancgy
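As with the existing architectures, the new models register with the transformers Auto classes once fla is imported. Below is a minimal sketch of instantiating one of them; the export name `DeltaFormerConfig` is assumed from the library's `<Model>Config` naming convention and is not verified against this release:

```python
# Minimal sketch, assuming `DeltaFormerConfig` is exported from
# `fla.models` per the library's <Model>Config convention (unverified).
from transformers import AutoModelForCausalLM

from fla.models import DeltaFormerConfig

config = DeltaFormerConfig()  # default-sized config
model = AutoModelForCausalLM.from_config(config)
print(model)
```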
What's Changed
- [GDN] Fix tiling bugs once gv applied by @yzhangcs in #589
- [Conv] Add comprehensive docstring and change default backend to triton by @zhiyuan1i in #592
- Update cumprod_householder_bwd.py by @SeepingFragranceLock in #593
- [FIX] Correct cumsum dimension in normalize_output by @sirluk in #594
- [Triton] Add autotune caching support for Triton kernels by @zhiyuan1i in #598
- [DeltaFormer] Add Model by @Nathancgy in #585
- [DeltaFormer] Replace GenerationMixin with FLAGenerationMixin and upd… by @zhiyuan1i in #600
- [Cache] Fix from_legacy_cache by @zhiyuan1i in #605
- [Deps] Make pytest an optional dependency by @wedaly in #610
- [DeltaFormer] Fixed testing ops error by @Nathancgy in #602
- [Conv] Fix potential OOB problems by @yzhangcs in #615
- [Deps] Minimize deps by @zhiyuan1i in #617
- Determine the chunk size at the kernel entry by @yzhangcs in #619
- Add KDA by @yzhangcs in #621
- [Lint] Migrate from flake8/isort to ruff for faster linting by @zhiyuan1i in #613
New Contributors
- @SeepingFragranceLock made their first contribution in #593
- @sirluk made their first contribution in #594
- @Nathancgy made their first contribution in #585
- @wedaly made their first contribution in #610
Full Changelog: v0.3.2...v0.4.0
v0.3.2
📣 Highlights
Starting with this release, every time we ship a new version of flash-linear-attention, we will simultaneously publish fla-core: a minimal-dependency subset of the main repo that contains only the essentials.
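In practice, projects that only need the kernels can depend on the lighter package. A minimal sketch, assuming fla-core installs the same `fla` namespace as the main package:

```python
# After `pip install fla-core` (package name as announced above), the core
# kernels should import as usual; we assume fla-core shares the `fla`
# namespace with the main package.
import fla
from fla.ops.gla import chunk_gla  # a representative core kernel

print(fla.__version__)
```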
🧠 New Models
- Log-Linear Attention by @2022tgoel
What's Changed
- [Conv] Provide fn interface for `causal_conv1d` by @yzhangcs in #578
- [Log Linear Attention] add backward pass by @2022tgoel in #577
- [PaTH] Fix q init & dq masking by @yzhangcs in #581
- [TokenShift] Fix a bug in decoding by @zhiyuan1i in #583
- [Deps] Lock `transformers<4.56.0` by @zhiyuan1i in #582
- [Log-Linear Attention] add models by @2022tgoel in #579
- [Deps] Upgrade to transformers 4.56.x by @zhiyuan1i in #587
- [Build] Split package distribution[skip test] by @zhiyuan1i in #588
Full Changelog: v0.3.1...v0.3.2
v0.3.1
What's Changed
- [Misc] Change grid to support long ctx by @zhiyuan1i in #528
- [RWKV7] Reduce CPU overhead by @zhiyuan1i in #529
- [Tokenshift] Support SP and cache by @zhiyuan1i in #531
- [RWKV7] Use tokenshift to save cache by @zhiyuan1i in #532
- [RWKV7] Fix the issue of RWKV7 initialization with BFloat16 data type on CPU. by @zhiyuan1i in #538
- [CI] Add compatibility check by @zhiyuan1i in #536
- [ShortConv] Support cache in prefill by @zhiyuan1i in #535
- [WIP] Add Log-Linear Attention by @2022tgoel in #524
- [Cache] Upgrade to `transformers>=4.48` [skip test] by @zhiyuan1i in #541
- [Misc.] Set env var `TRITON_F32_DEFAULT` to `ieee` when tf32 is not supported on NVIDIA by @KevlarKanou in #544
- [CI] Fix mirror for building triton by @zhiyuan1i in #543
- Log-Linear Attention Tests by @2022tgoel in #542
- [CI] Add proxy config for git by @zhiyuan1i in #548
- [Conv] Fix warning issue by @zhiyuan1i in #549
- [Misc.] Eliminate recompilation in layer-norm kernels caused by dynam… by @zhiyuan1i in #545
- [Misc.] Add activations for non-cuda Backends by @zhiyuan1i in #174
- [TMA] Accelerate solve_tril with TMA descriptors[skip test] by @zhiyuan1i in #550
- [CI] Upgrade to latest causal-conv1d and fix triton build for 3.4.x by @zhiyuan1i in #551
- [CI] Fix support for Intel GPU by @zhiyuan1i in #554
- [Fix] Fix Triton Error for HeadDim < 16[skip test] by @zhiyuan1i in #556
- [GLA] Fix simple_gla Test by @zhiyuan1i in #558
- [CI] Fix CI script errors[skip test] by @zhiyuan1i in #566
- require transformers <= 4.53.3 by @richardodliu in #570
- [Deps] Adopt transformers>4.53.3 by @zhiyuan1i in #571
- [Misc.] Clean codes and make mypy happy by @zhiyuan1i in #572
- [Models]: Add MoM by @WKX933 in #442
- [MoM] Fix lint by @JusenD in #573
- [Refactor] Apply GradientCheckpointingLayer to all model layers by @yzhangcs in #575
- [Mamba] Fix errors in Triton backend by @zhiyuan1i in #576
New Contributors
- @2022tgoel made their first contribution in #524
- @KevlarKanou made their first contribution in #544
- @richardodliu made their first contribution in #570
- @WKX933 made their first contribution in #442
Full Changelog: v0.3.0...v0.3.1
v0.3.0
Highlights
🧠 New Models
We are excited to expand our model library with the addition of four powerful new architectures.
- 🎉 MesaNet by @sustcsonglin
- 🛣️ PaTH by @sustcsonglin
- 🐍 Comba by @AwesomeSeq @yzhangcs
- 🐳 MLA by @toothacher17 @yzhangcs
What's Changed
- [MesaNet] add kernel impl. by @sustcsonglin in #419
- [GDN] Add support for inference with GVA by @yzhangcs in #429
- [HGRN] remove unused q_conv1d by @yibozhong in #430
- Update mesa_net.py by @jovoswald in #434
- [Gated DeltaNet] Refactor the kernel to remove one matrix inversion by @sustcsonglin in #433
- [Modules] Add `L2Warp` to maintain bf16 precision by @zhiyuan1i in #438
- [RWKV]: Set default scale to None by @zhiyuan1i in #445
- [Typos] Change scale docs to (Optional[float]) [skip test] by @zhiyuan1i in #446
- [Modules] Enhance Testing of `l2warp` by @zhiyuan1i in #448
- [CI] Upgrade CI envs to torch~=2.7.0 by @zhiyuan1i in #450
- [Mesa] misc. fix by @sustcsonglin in #449
- [Models]: Add Comba Implementation by @AwesomeSeq in #444
- [Test] Work around the bug of `causal_conv1d` by @zhiyuan1i in #453
- [Utils] Add deprecation handling for kwargs with `deprecate_kwarg` decorator by @yzhangcs in #455
- [ShortConv] Replace `use_fast_conv1d` with `backend` parameter by @yzhangcs in #456
- [Docs] Update tensor shape descriptions and deprecate `head_first` argument by @yzhangcs in #457
- [Simple GLA] Support dg when dht passed by @yzhangcs in #459
- [Mesa] Improve precision by @sustcsonglin in #460
- [Comba] Remove problematic `safe_exp` by @yzhangcs in #466
- [TokenShift] Fix invalid argument on AMD GPUs by @zhiyuan1i in #464
- [Test] Refactor model testing [skip test] by @zhiyuan1i in #467
- [Testing] Enhance generation testing by @sustcsonglin in #468
- [Simple GLA] Remove unnecessary dg for data-independent decay by @yzhangcs in #469
- [CI] Update workflow by @zhiyuan1i in #473
- [Misc.] Enhance support for some platforms by @zhiyuan1i in #470
- [Gated Delta Product] Optimize kernels by @sustcsonglin in #472
- [README] Add support for aarch64 by @zhiyuan1i in #475
- [Cache] Fix bad `seen_tokens` update by @yzhangcs in #478
- [CI] Revert causal-conv1d to `2a288a1` by @zhiyuan1i in #480
- [Parallel] Fix all tokens offsets by @yzhangcs in #479
- Use `tl.exp2` for all gating operations by @yzhangcs in #361
- Refactor modeling tests by @yzhangcs in #482
- Add L2_norm for p in Recurrent ops to fix generation error by @AwesomeSeq in #483
- Refactor benchmark: adapt to latest FLA benchmark interface by @yuweih205 in #488
- [GLA] Remove all `safe_exp` ops by @yzhangcs in #489
- [MesaNet] Remove all `safe_exp` ops & Refactor tests by @yzhangcs in #490
- [Misc.] Support PT2.5 by @zhiyuan1i in #491
- [Misc.] Fast testing & Autotune by @sustcsonglin in #476
- fix: update import path for causal_conv1d by @yuweih205 in #492
- Make RWKV-7 init match official RWKV-LM by @johanwind in #493
- Modernize the `fused_chunk` impls by @yzhangcs in #437
- [ShortConv] Fix bad conv weight input shape during inference by @yzhangcs in #495
- [DeltaProduct] chore: remove unused functions by @timurcarstensen in #496
- [CI] Fix pipeline in GPU CIs by @zhiyuan1i in #497
- [RWKV] Make `torch.compile` decorator compatible with python3.10 by @zhiyuan1i in #498
- [GDN] Fuse 64x64 matrix inverse kernel by @yzhangcs in #501
- [L2Norm] Speedup by saving rstd by @yzhangcs in #506
- [Norm] Move eps out of sqrt by @yzhangcs in #508
- Correct types of constructor arguments with issues for configuration classes by @V0XNIHILI in #509
- Fix typo: suppoerted -> supported by @zxytim in #510
- [RWKV7] Increase Lora shape for headdim>64 by @zhiyuan1i in #512
- [Delta Rule] Support gk for WY reprs by @yzhangcs in #514
- [PaTH attention] Support headdim 128 & refactor kernel for better stability by @sustcsonglin in #503
- [Rotary] Fix `max_seqlen` under varlen mode by @yzhangcs in #516
- [Misc] Skip testing models on Nvidia 4090 CI by @zhiyuan1i in #517
- [GDP] Delete duplicated code by @yzhangcs in #518
- [WIP] Add MLA layers into fla by @toothacher17 in #395
- [Mamba] Add triton conv1d backend and fix mamba2 test by @zhiyuan1i in #520
- [Typo] Fix types in all configuration files[skip test] by @V0XNIHILI in #513
- [GSA] Fix memory boundary conditions by @JusenD in #527
New Contributors
- @jovoswald made their first contribution in #434
- @AwesomeSeq made their first contribution in #444
- @yuweih205 made their first contribution in #488
- @V0XNIHILI made their first contribution in #509
- @zxytim made their first contribution in #510
- @toothacher17 made their first contribution in #395
- @JusenD made their first contribution in #527
Full Changelog: v0.2.2...v0.3.0
v0.2.2
What's Changed
- [TokenShift] support `fused_token_shift` with `varlen` by @zhiyuan1i in #373
- [Mamba] Use official init strategies by @yzhangcs in #374
- [Mamba2] Create attn layer by @yzhangcs in #375
- [Mamba] Add attn layer & fix configs by @yzhangcs in #376
- [RWKV7] Update `fused_addcmul` impls by @zhiyuan1i in #378
- [RWKV7]: Rewrite docs to match Triton codes by @zhiyuan1i in #381
- [RWKV7] Fix convert script by @zhiyuan1i in #383
- [Misc.] Update triton-nightly.yml by @zhiyuan1i in #382
- [PaTH] Add PaTH attention model and kernel by @sustcsonglin in #384
- [Tests] Enable tests with `causal_conv1d` on H100 CIs by @zhiyuan1i in #385
- [GDN]: initializing `A_log` and `dt_bias` in `_init_weights` by @HanGuo97 in #380
- [Utils] Add fused pack/unpack fns by @yzhangcs in #386
- [RWKV7] Strictly initialize rwkv7 according to RWKV-LM by @zhiyuan1i in #387
- [chore] switched to `processing_class` kwarg inside Trainer invocation by @timurcarstensen in #391
- [RWKV7] Update initialization to sync with latest RWKV-LM by @zhiyuan1i in #393
- [Token Shift]: Fix potential cuda kernel parameter error for varlen by @zhiyuan1i in #397
- [DeltaProduct] fix query conv cache, remove extraneous query convs by @timurcarstensen in #396
- [Misc.] Log warnings when Triton is older than 3.2.0 by @zhiyuan1i in #394
- [RWKV7]: clean `fused_addcmul_rwkv7` impls by @zhiyuan1i in #404
- [README] Update FoX venue info by @zhixuan-lin in #406
- Added details to some formulas, fixed the display error of the `L2 Loss` formula by @Beortext in #407
- [RWKV7] Change fp32 errors to warnings by @zhiyuan1i in #412
- [Misc.] Add `exist_ok=True` to all models by @zhiyuan1i in #413
- Add Rodimus impl into fla by @ziHoHe in #416
- Align RWKV7 LoRA Rank Initialization with official Implementation by @WuTianyi321 in #418
- [Canon] Add triton impls by @yzhangcs in #388
- [GDN] Support Gated Value Attention (GVA) by @Rafa-zy in #421
- [RWKV7]: clean some imps by @zhiyuan1i in #420
- [RoPE] Fix out-of-boundary bugs by @yzhangcs in #423
- [RWKV] Fix `cu_seqlens` with gradient checkpointing by @zhiyuan1i in #422
New Contributors
- @timurcarstensen made their first contribution in #391
- @ziHoHe made their first contribution in #416
- @WuTianyi321 made their first contribution in #418
- @Rafa-zy made their first contribution in #421
Full Changelog: v0.2.1...v0.2.2
v0.2.1
Highlights
🚀 Performance Boost for DeltaNet
We've achieved a notable performance improvement for (Gated) DeltaNet models. The optimization focused on the fused `LayerNormGated` layer, particularly for small head dims, and yields a 1.1x speedup.
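For reference, here is a plain-PyTorch sketch of what the gated norm computes (normalize, then modulate by a sigmoid/swish output gate, cf. the `LayerNormGated` entry in v0.1.1 below); the shipped Triton kernel fuses these steps into one pass, which matters most at small head dims:

```python
import torch
import torch.nn.functional as F

def layer_norm_gated_ref(x: torch.Tensor, g: torch.Tensor,
                         weight: torch.Tensor, bias: torch.Tensor,
                         eps: float = 1e-5) -> torch.Tensor:
    # Unfused reference: LayerNorm over the head dim, then a swish gate.
    # The fla kernel computes this in a single fused pass instead of two.
    y = F.layer_norm(x, x.shape[-1:], weight, bias, eps)
    return y * F.silu(g)  # swish(g) = g * sigmoid(g)
```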
Below are the benchmarks for 1B-parameter models, tested on 4K sequences in varlen mode on a single H100 GPU:
| Model | TPS (K tokens/s) |
|---|---|
| Transformer++ | 53.8 |
| DeltaNet (before) | 48.6 |
| DeltaNet (after) | 54.0 |
by running:

```sh
python -m benchmarks.benchmark_training_throughput \
  --name delta_net \
  --batch_size 1 \
  --seq_len 32768 \
  --context_len 4096 \
  --varlen \
  --steps 512
```

What's Changed
- [Gated DeltaNet] optimize UT transform by @sustcsonglin in #349
- [RWKV] remove duplicate params from autotune key list by @jihaoh98 in #359
- Fix some arg passing by @yibozhong in #358
- [RWKV7] Update RWKV7 to follow official initialization by @zhiyuan1i in #365
- Remove all `NT: constexpr` by @sustcsonglin in #364
- [Misc.] Use `logger.info` instead of `print` in `fla.utils.py` by @zhiyuan1i in #366
- [RWKV]: Prevent initialization when loading pretrained weights by @zhiyuan1i in #369
- [Norm] Optimize speed for small headdim by @yzhangcs in #368
- [GroupNorm] Optimized speed for small headdims by @yzhangcs in #371
- [LayerNormGated] Fix arg bugs during autotuning by @yzhangcs in #372
New Contributors
- @jihaoh98 made their first contribution in #359
- @yibozhong made their first contribution in #358
Full Changelog: v0.2.0...v0.2.1
v0.2.0
What's Changed
- [Attn] Delete V reduction & Enable 256 headdim tests by @yzhangcs in #273
- [RWKV7] Add more elementwise kernels by @zhiyuan1i in #271
- [CI] Remove cache and disable full test on Arc GPU by @zhiyuan1i in #274
- [Fox] Add model/layer/kernel impls w/ varlen support by @yzhangcs in #275
- [FoX] Simplify some tests and enhance tiling by @zhiyuan1i in #277
- [Test] Remove some warnings and correct condition checks by @zhiyuan1i in #278
- [CI] auto-cancel workflows on PR merge via concurrency group by @zhiyuan1i in #280
- [Test] use `tl.float16` instead of `tl.bfloat16` by @zhiyuan1i in #281
- [OP] replace `tl.exp`, `tl.log`, `tl.log2` with fast ops when `FLA_USE_FAST_OPS=1` by @zhiyuan1i in #276
- [FoX] Rename `fox` to `forgetting_attn` by @yzhangcs in #282
- [DeltaNet] WY repr speedup by @yzhangcs in #279
- [README] Add `--no-use-pep517` flag for faster installation by @zhiyuan1i in #286
- [FoX] Skip test `D>128` on RTX4090 by @zhiyuan1i in #287
- [FoX] Test different forget gate initialization ranges by @zhixuan-lin in #291
- [FoX] Fix class inheritance for ForgettingTransformerForCausalLM by @zhixuan-lin in #293
- [CI] use latest stable `triton` by @zhiyuan1i in #294
- [Triton] use `tl.gather` to enhance performance by @zhiyuan1i in #270
- [WY representation] Faster lower triangle inverse by @sustcsonglin in #289
- [GroupNorm] Add argument `is_rms_norm` to GroupNorm by @zhixuan-lin in #295
- [GroupNorm] Return correct residual in reference implementation by @zhixuan-lin in #297
- [CI] Don't show `Triton` autotune logs in CI by @zhiyuan1i in #298
- [FoX] Use GroupNorm for QK-norm implementation in FoX by @zhixuan-lin in #299
- [Utils] Update H100 and A100 configs by @zhiyuan1i in #306
- Pass shifted labels and add a warning to RWKV-7 initialization. by @Triang-jyed-driung in #304
- [Misc.] Update imports for `GatedDeltaProduct` by @yzhangcs in #309
- [FAQ] Rewrite the nightly installation instructions by @zhiyuan1i in #305
- Add unit tests for model forward and variable-length checks by @yzhangcs in #310
- [Test] Improve path handling and test file detection by @zhiyuan1i in #311
- [ShortConv] Adjust input shape according to `cu_seqlens` by @yzhangcs in #316
- [Tests] Add unit tests for generation with padding by @yzhangcs in #312
- [Testing] Update testing.py by @zhiyuan1i in #320
- [DeltaNet] optimize `chunk_delta_h` by @sustcsonglin in #315
- [CI] Only cancel in-progress CI for pull requests by @zhiyuan1i in #321
- [Test] Skip some tests on arcA770 by @zhiyuan1i in #322
- [API] Update `head_first` parameter default to `False` by @yzhangcs in #324
- [Rotary] Remove max_seqlen parameter and adjust related logic by @yzhangcs in #326
- [DeltaProduct] Remove unnecessary config parameter. by @JulienSiems in #325
- fix the training problem of GatedDeltaProduct by @ridgerchu in #327
- [Linear Attn] Fix head_first tests by @yzhangcs in #330
- [Deprecated] Remove `head_first` option in gla variants by @yzhangcs in #337
- [Test] Ensure most tests on Triton 3.2.0 and add `4096` seq_length in tests [skip test] by @zhiyuan1i in #300
- [FoX] Merge code to FlashAttention | support batch inference by @sustcsonglin in #333
- [DeltaNet] Delete `head_first` option for all by @yzhangcs in #338
- [WIP] Remove head_first option by @yzhangcs in #339
- [RWKV7] add `input_precision` param [skip test] by @zhiyuan1i in #335
- [Testing] Add recursive dependency finding for test discovery by @zhiyuan1i in #341
- [WIP] Delete `head_first` option for cumsum by @yzhangcs in #342
- [WIP] Delete head_first tests for DeltaNet/GLA by @yzhangcs in #344
- [Attn] Remove `head_first` & rename `offsets` to `cu_seqlens` by @yzhangcs in #345
- [RWKV7] Drop some kernels to enhance speed by @zhiyuan1i in #346
- Remove the `head_first` arg from several token mixing layer fns by @yzhangcs in #347
New Contributors
- @sustcsonglin made their first contribution in #289
Full Changelog: v0.1.2...v0.2.0
v0.1.2
What's Changed
- [RWKV7] fix `RWKV7Attention.__init__` by @exhyy in #238
- [Fix]: reshape o before o_proj in linear_attn layer. by @Luther-Sparks in #243
- [CI] Separate tests into compile, normal and varlen by @zhiyuan1i in #247
- [ABC] Add `use_rope` parameter to ABCAttention and ABCConfig & Fix compiler bugs in kernels by @yzhangcs in #248
- Create test_linearatten.py by @kangyiyang in #250
- [CI] Fix all errors and enable testing for PRs by @zhiyuan1i in #251
- [CI] add H100 GPU by @zhiyuan1i in #254
- [Gated DeltaNet] fix gdn kernel bugs on h100 when vdim=64 by @kugwzk in #256
- [Test] Enhance support for NVIDIA Hopper GPU by @zhiyuan1i in #257
- [FAQ] Update triton-nightly links by @yzhangcs in #259
- [Attn] Add triton impls for MHA/GQA by @yzhangcs in #260
- [Attn] Use larger block size for hopper devices by @yzhangcs in #261
- [Attn] Enable test for attn by @zhiyuan1i in #262
- [CI] fix a syntax error in triton-nightly by @zhiyuan1i in #263
- Bump `fla` to v0.1.2 by @yzhangcs in #264
New Contributors
- @exhyy made their first contribution in #238
- @kugwzk made their first contribution in #240
- @Luther-Sparks made their first contribution in #243
- @yzhangcs made their first contribution in #248
- @kangyiyang made their first contribution in #250
Full Changelog: v0.1.1...v0.1.2
v0.1.1
What's Changed
- [README] Fix HGRN2 bibs in a43b525
- [LightNet] Use fused norm, output gate and proj in bc86b59
- [LayerNormGated] Support combined sigmoid/swish output gate in 27de88d
- [NSA] Fix missing Cache class in a519920
- [BUG][RWKV7] Fix value head dim mismatch in 8fd1d3d
- [DeltaNet] Improved kernel speed by finegrained autotuning in 3b9bba8
Full Changelog: v0.1.0...v0.1.1
💥 v0.1.0
What's Changed
- Update README.md by @eltociear in #2
- fix simple gla backward by @sunyt32 in #6
- Adding RWKV-v4. by @ridgerchu in #8
- fixed hgrn.py paper link and title by @ridgerchu in #10
- Update recurrent_naive.py by @hypnopump in #12
- fix: calculate du on different batch by @uniartisan in #35
- fix: enhance state gradient when bf16 by @uniartisan in #37
- Add implementations of Mamba 2 into FLA by @DanFosing in #39
- Minor mamba-2 fixes by @DanFosing in #40
- [DeltaNet] Adds beta as a vector option by @hypnopump in #42
- [DRAFT] Beta gradient does not match by @hypnopump in #43
- [Attn] fix negative value of seqlen offset during sft by @ChaosCodes in #45
- [RWKV6] fix backward if h0 not passed by @hypnopump in #48
- Replace mamba2 `mamba_chunk_scan_combined` triton kernel by `simple_gla` triton kernel by @learning-chip in #49
- benchmark script for simple_gla vs mamba2 kernel by @learning-chip in #50
- Update amp custom_fwd, custom_bwd usage for torch 2.4.0 compatibility by @mirceamironenco in #54
- Fix syntax error by @JulienSiems in #55
- Add `__init__.py` in `fla/ops/common` for automatic package discovery by @zhixuan-lin in #56
- [`Mamba2`] Post Merge Fixes - `norm_before_gate` and generation with `inputs_embeds` by @vasqu in #57
- Correctly compute `max_seqlen` when `max_position_embeddings` is `None` by @zhixuan-lin in #59
- add chunked kl div by @ChaosCodes in #62
- Add fine-grained warning category for easier suppression by @mirceamironenco in #65
- Update fused_chunk.py by @hypnopump in #72
- [`Mamba2`] Fix slow path by @vasqu in #84
- Add BitNet by @DustinWang1 in #85
- Fix RWKV6 Cache Problems by @WorldEditors in #78
- Bugs in RWKV6 OP by @WorldEditors in #87
- fix mamba2 cache bug by @WorldEditors in #89
- fix dh0 is None breaking backward pass by @Sxela in #102
- support varlen training for conv1d by @LKJacky in #116
- blood for the torch.compile gods by @harrisonvanderbyl in #119
- Added forward pass for chunkwise ttt-linear, varlen is supported. by @Pan-Yuqi in #124
- Add scripts for converting pretrained RWKV7 models to fla format by @Triang-jyed-driung in #128
- [`Mamba2`] Fixes for caching and multiple other small issues by @vasqu in #129
- [LinAttn] Fix handling of None scale in chunk_linear_attn for output normalization by @HallerPatrick in #130
- Fix incorrect kwarg name in `fused_recurrent` by @fffffgggg54 in #134
- RWKV-7 conversion and evals by @Triang-jyed-driung in #135
- Fixed dtype mismatch of mamba & mamba2 under residual_in_fp32 setting by @chengshuang18 in #137
- [RWKV7] Fix masking before time shifting modules by @Triang-jyed-driung in #141
- [RWKV7, but applicable to all models] Update modeling_rwkv7.py: Fixing `base_model_prefix` by @Triang-jyed-driung in #143
- fix bitattn with latest attn implementation. by @ridgerchu in #146
- [RWKV7] Remove in-place operations and add gradient checkpointing for `v_first` by @Triang-jyed-driung in #145
- [BitNet] Fix bugs of model definitions by @ridgerchu in #147
- Fix #157 by @jannalulu in #167
- [Mamba, Samba] Add weight initializations and reset_parameters() in _init_weights() for compatibility in Flame by @zaydzuhri in #169
- fix lint errors by @jannalulu in #170
- [RWKV] Follow-up to fix cache management by @jannalulu in #168
- one liner by @seanxwzhang in #178
- [Attn] Fix cache update of swa by @Pan-Yuqi in #183
- [RWKV7] Fix conversion precision by @Triang-jyed-driung in #188
- [GRPO]: add grpo functions by @uniartisan in #189
- [RWKV] fix logits handling by @jannalulu in #192
- [Modules]: Enhance the precision of the fused LayerNorm OP. by @uniartisan in #200
- [MISC] fix delta_net logit handling by @jannalulu in #205
- [RWKV7] Keep compatibility with Torch Compiler. by @uniartisan in #208
- [Misc.] Update wrapper to support contiguous and guard custom device … by @uniartisan in #212
- [Models] Fix the error in the judgment of past_key_values when inputs… by @uniartisan in #213
- [Titans] Update Titans implementation by @rucnyz in #214
- [Mamba2] Fix initialization by @HanGuo97 in #225
- [TTT] Update fused chunk ops and state bias term by @Pan-Yuqi in #230
- Enable utils.py to be imported on CPU-only machines (#231) by @zhuzeyuan in #232
- [Utils] use fla.utils.device instead of cuda by @uniartisan in #163
- fix(GatedDeltaNet): Ensure integer dimensions when using `expand_v` by @vladislavalerievich in #234
New Contributors
- @eltociear made their first contribution in #2
- @sunyt32 made their first contribution in #6
- @ridgerchu made their first contribution in #8
- @hypnopump made their first contribution in #12
- @DanFosing made their first contribution in #39
- @ChaosCodes made their first contribution in #45
- @learning-chip made their first contribution in #49
- @mirceamironenco made their first contribution in #54
- @JulienSiems made their first contribution in #55
- @zhixuan-lin made their first contribution in #56
- @vasqu made their first contribution in #57
- @DustinWang1 made their first contribution in #85
- @WorldEditors made their first contribution in #78
- @Sxela made their first contribution in #102
- @LKJacky made their first contribution in #116
- @harrisonvanderbyl made their first contribution in #119
- @Triang-jyed-driung made their first contribution in #128
- @HallerPatrick made their first contribution in #130
- @fffffgggg54 made their first contribution in #134
- @chengshuang18 made their first contribution in https://github.com/fla-org/flash-linear-attenti...