
Commit 3e341e9

k_proj bias is needed for qk norm

1 parent fbf34dc

2 files changed: +1 -4 lines

python/src/diffusionkit/mlx/mmdit.py

Lines changed: 1 addition & 1 deletion
@@ -693,7 +693,7 @@ def __init__(
         self.kv_proj_embed_dim = self.per_head_dim * n_heads

         # Note: key bias is redundant due to softmax invariance
-        self.k_proj = nn.Linear(embed_dim, self.kv_proj_embed_dim, bias=False)
+        self.k_proj = nn.Linear(embed_dim, self.kv_proj_embed_dim)
         self.q_proj = nn.Linear(embed_dim, embed_dim)
         self.v_proj = nn.Linear(embed_dim, self.kv_proj_embed_dim)
         self.o_proj = nn.Linear(embed_dim, embed_dim)
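Why the bias matters: in plain scaled dot-product attention a key bias really is redundant, since q · (k + b) = q · k + q · b and the q · b term is the same for every key given a fixed query, so it cancels in the softmax over keys. With qk norm, the keys are normalized after projection; the normalization is nonlinear in k, so the bias is no longer a no-op. A minimal standalone sketch of this (a hypothetical illustration in MLX, not code from the repo; rms_norm below is a hand-rolled stand-in for the model's qk norm):

import mlx.core as mx

def attn_weights(q, k):
    # Softmax over the key axis, as in dot-product attention
    # (the 1/sqrt(d) scale is omitted; it does not affect the argument).
    return mx.softmax(q @ k.T, axis=-1)

def rms_norm(x, eps=1e-6):
    # Hand-rolled RMSNorm without a learned scale, standing in for qk norm.
    return x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)

q = mx.random.normal((4, 8))  # 4 queries, head_dim = 8
k = mx.random.normal((6, 8))  # 6 keys
b = mx.random.normal((8,))    # a hypothetical k_proj bias

# Without qk norm the bias cancels: q @ b is constant across keys.
print(mx.allclose(attn_weights(q, k), attn_weights(q, k + b), atol=1e-5))  # True

# With qk norm it does not: the norm is nonlinear in k.
print(mx.allclose(attn_weights(q, rms_norm(k)),
                  attn_weights(q, rms_norm(k + b)), atol=1e-5))  # False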

python/src/diffusionkit/mlx/model_io.py

Lines changed: 0 additions & 3 deletions
@@ -248,9 +248,6 @@ def flux_state_dict_adjustments(state_dict, prefix="", hidden_size=3072, mlp_rat
         for k, v in state_dict.items()
     }

-    # Remove k_proj bias
-    state_dict = {k: v for k, v in state_dict.items() if "k_proj.bias" not in k}
-
     state_dict["x_embedder.proj.weight"] = mx.expand_dims(
         mx.expand_dims(state_dict["x_embedder.proj.weight"], axis=1), axis=1
     )
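With the filter removed, the converted state dict keeps its k_proj.bias entries, so the now-bias-carrying nn.Linear loads those values from the checkpoint instead of silently keeping its default init. A small sanity check along these lines (a hypothetical helper, not part of the commit) could guard the conversion:

def check_k_proj_bias(state_dict):
    # Hypothetical guard: every k_proj.weight should be accompanied by a
    # k_proj.bias, otherwise the re-enabled bias would stay at its init value.
    missing = [
        k for k in state_dict
        if k.endswith("k_proj.weight")
        and k.replace("k_proj.weight", "k_proj.bias") not in state_dict
    ]
    if missing:
        raise ValueError(f"k_proj.bias missing for: {missing}")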
