Commit e673a97

some nits in lora (#208)
1 parent 3be5153 commit e673a97

File tree

4 files changed (+8, -7 lines)


mlx_lm/examples/lora_config.yaml

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 # The path to the local model directory or Hugging Face repo.
-model: "mlx_model"
+model: "mlx-community/Llama-3.2-1B-Instruct"

 # Whether or not to train (boolean)
 train: true
@@ -17,7 +17,7 @@ optimizer: adamw
 # bias_correction: true

 # Directory with {train, valid, test}.jsonl files
-data: "/path/to/training/data"
+data: "mlx-community/WikiSQL"

 # The PRNG seed
 seed: 0
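Note: the example config now points at a Hugging Face model repo and dataset instead of local placeholder paths. As a quick sanity check (a sketch, not part of the commit; it assumes PyYAML is installed and that the script runs from the repo root), the new defaults can be read back and the model resolved with mlx_lm's load helper:

import yaml
from mlx_lm import load

# Read the example config shipped with the repo (path assumed relative to the repo root).
with open("mlx_lm/examples/lora_config.yaml") as f:
    config = yaml.safe_load(f)

print(config["model"])  # mlx-community/Llama-3.2-1B-Instruct
print(config["data"])   # mlx-community/WikiSQL

# load() accepts a Hugging Face repo id and downloads it on first use.
model, tokenizer = load(config["model"])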

mlx_lm/fuse.py

Lines changed: 3 additions & 1 deletion
@@ -75,7 +75,9 @@ def main() -> None:
     model = load_adapters(model, args.adapter_path)

     fused_linears = [
-        (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
+        (n, m.fuse(de_quantize=args.de_quantize))
+        for n, m in model.named_modules()
+        if hasattr(m, "fuse")
     ]

     if fused_linears:
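For context, this change forwards the parsed de_quantize argument to each adapter layer's fuse() call instead of silently dropping it. A standalone sketch of the pattern (fuse_adapters is a made-up helper name, not the actual fuse.py code):

import mlx.nn as nn

def fuse_adapters(model: nn.Module, de_quantize: bool = False):
    # Walk every submodule; any layer exposing a fuse() method (e.g. a LoRA
    # wrapper) gets fused, with the de_quantize option forwarded to it.
    return [
        (name, module.fuse(de_quantize=de_quantize))
        for name, module in model.named_modules()
        if hasattr(module, "fuse")
    ]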

mlx_lm/lora.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@
     "config": None,
     "grad_checkpoint": False,
     "lr_schedule": None,
-    "lora_parameters": {"rank": 8, "dropout": 0.0, "scale": 10.0},
+    "lora_parameters": {"rank": 8, "dropout": 0.0, "scale": 20.0},
     "mask_prompt": False,
     "wandb": None,
 }
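The only change here is the default LoRA scale, doubled from 10.0 to 20.0. In the usual LoRA formulation the low-rank product is multiplied by this scale before being added to (or fused into) the base weight, so the default adapter contribution is now twice as strong. A minimal illustration with made-up shapes (not the library's internal layout):

import mlx.core as mx

rank, in_dims, out_dims = 8, 64, 64
lora_a = mx.random.normal((rank, in_dims)) * 0.01   # low-rank factors
lora_b = mx.random.normal((out_dims, rank)) * 0.01

for scale in (10.0, 20.0):                 # old vs. new default
    delta = scale * (lora_b @ lora_a)      # (out_dims, in_dims) weight update
    print(scale, float(mx.abs(delta).max()))  # doubling scale doubles the update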

mlx_lm/tuner/lora.py

Lines changed: 2 additions & 3 deletions
@@ -52,9 +52,8 @@ def fuse(self, de_quantize: bool = False):
         output_dims, input_dims = weight.shape
         fused_linear = nn.Linear(input_dims, output_dims, bias=bias)

-        lora_b = (self.scale * self.lora_b.T).astype(dtype)
-        lora_a = self.lora_a.T.astype(dtype)
-        fused_linear.weight = weight + lora_b @ lora_a
+        delta = ((self.scale * self.lora_b.T) @ self.lora_a.T).astype(dtype)
+        fused_linear.weight = weight + delta
         if bias:
             fused_linear.bias = linear.bias

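This rewrite folds the scaled low-rank product into a single delta and casts it once at the end, instead of casting lora_b and lora_a to the weight dtype before multiplying, so the rank-sized matmul runs in the adapters' original precision. A standalone comparison of the two orderings with made-up shapes (a sketch, not the class's actual code):

import mlx.core as mx

scale, rank, in_dims, out_dims = 20.0, 8, 32, 32
lora_a = mx.random.normal((in_dims, rank))                    # float32 adapters
lora_b = mx.random.normal((rank, out_dims))
weight = mx.random.normal((out_dims, in_dims)).astype(mx.float16)
dtype = weight.dtype

# Old: cast each factor down, then multiply in the lower precision.
old = weight + (scale * lora_b.T).astype(dtype) @ lora_a.T.astype(dtype)
# New: multiply in float32, then cast the combined delta once.
new = weight + ((scale * lora_b.T) @ lora_a.T).astype(dtype)

print(float(mx.abs(old - new).max()))  # expected to be on the order of fp16 rounding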