Skip to content

Commit a928bba

Browse files
committed
Pass down TrainingArgs instance to iterate_batches function and TrainingCallback methods
Addresses ml-explore#1224
1 parent 7a83077 commit a928bba

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

llms/mlx_lm/tuner/trainer.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import time
66
from dataclasses import dataclass, field
77
from pathlib import Path
8-
from typing import Union
8+
from typing import Optional, Union
99

1010
import mlx.core as mx
1111
import mlx.nn as nn
@@ -76,7 +76,9 @@ def default_loss(model, inputs, targets, lengths):
7676
return ce, ntoks
7777

7878

79-
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
79+
def iterate_batches(
80+
dataset, tokenizer, batch_size, max_seq_length, train=False, args=None
81+
):
8082
# Sort by length:
8183
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
8284
if len(dataset) < batch_size:
@@ -167,11 +169,13 @@ def evaluate(
167169

168170
class TrainingCallback:
169171

170-
def on_train_loss_report(self, train_info: dict):
172+
def on_train_loss_report(
173+
self, train_info: dict, args: Optional[TrainingArgs] = None
174+
):
171175
"""Called to report training loss at specified intervals."""
172176
pass
173177

174-
def on_val_loss_report(self, val_info: dict):
178+
def on_val_loss_report(self, val_info: dict, args: Optional[TrainingArgs] = None):
175179
"""Called to report validation loss at specified intervals or the beginning."""
176180
pass
177181

@@ -227,6 +231,7 @@ def step(batch):
227231
batch_size=args.batch_size,
228232
max_seq_length=args.max_seq_length,
229233
train=True,
234+
args=args,
230235
),
231236
):
232237
# Report validation loss if needed, the first validation loss
@@ -258,7 +263,7 @@ def step(batch):
258263
"val_loss": val_loss,
259264
"val_time": val_time,
260265
}
261-
training_callback.on_val_loss_report(val_info)
266+
training_callback.on_val_loss_report(val_info, args=args)
262267

263268
start = time.perf_counter()
264269

@@ -301,7 +306,7 @@ def step(batch):
301306
"trained_tokens": trained_tokens,
302307
"peak_memory": peak_mem,
303308
}
304-
training_callback.on_train_loss_report(train_info)
309+
training_callback.on_train_loss_report(train_info, args=args)
305310

306311
losses = 0
307312
n_tokens = 0

0 commit comments

Comments (0)