|
5 | 5 | import time
|
6 | 6 | from dataclasses import dataclass, field
|
7 | 7 | from pathlib import Path
|
8 |
| -from typing import Union |
| 8 | +from typing import Optional, Union |
9 | 9 |
|
10 | 10 | import mlx.core as mx
|
11 | 11 | import mlx.nn as nn
|
@@ -76,7 +76,9 @@ def default_loss(model, inputs, targets, lengths):
|
76 | 76 | return ce, ntoks
|
77 | 77 |
|
78 | 78 |
|
79 |
| -def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): |
| 79 | +def iterate_batches( |
| 80 | + dataset, tokenizer, batch_size, max_seq_length, train=False, args=None |
| 81 | +): |
80 | 82 | # Sort by length:
|
81 | 83 | idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
82 | 84 | if len(dataset) < batch_size:
|
@@ -167,11 +169,13 @@ def evaluate(
|
167 | 169 |
|
168 | 170 | class TrainingCallback:
|
169 | 171 |
|
170 |
| - def on_train_loss_report(self, train_info: dict): |
| 172 | + def on_train_loss_report( |
| 173 | + self, train_info: dict, args: Optional[TrainingArgs] = None |
| 174 | + ): |
171 | 175 | """Called to report training loss at specified intervals."""
|
172 | 176 | pass
|
173 | 177 |
|
174 |
| - def on_val_loss_report(self, val_info: dict): |
| 178 | + def on_val_loss_report(self, val_info: dict, args: Optional[TrainingArgs] = None): |
175 | 179 | """Called to report validation loss at specified intervals or the beginning."""
|
176 | 180 | pass
|
177 | 181 |
|
@@ -227,6 +231,7 @@ def step(batch):
|
227 | 231 | batch_size=args.batch_size,
|
228 | 232 | max_seq_length=args.max_seq_length,
|
229 | 233 | train=True,
|
| 234 | + args=args, |
230 | 235 | ),
|
231 | 236 | ):
|
232 | 237 | # Report validation loss if needed, the first validation loss
|
@@ -258,7 +263,7 @@ def step(batch):
|
258 | 263 | "val_loss": val_loss,
|
259 | 264 | "val_time": val_time,
|
260 | 265 | }
|
261 |
| - training_callback.on_val_loss_report(val_info) |
| 266 | + training_callback.on_val_loss_report(val_info, args=args) |
262 | 267 |
|
263 | 268 | start = time.perf_counter()
|
264 | 269 |
|
@@ -301,7 +306,7 @@ def step(batch):
|
301 | 306 | "trained_tokens": trained_tokens,
|
302 | 307 | "peak_memory": peak_mem,
|
303 | 308 | }
|
304 |
| - training_callback.on_train_loss_report(train_info) |
| 309 | + training_callback.on_train_loss_report(train_info, args=args) |
305 | 310 |
|
306 | 311 | losses = 0
|
307 | 312 | n_tokens = 0
|
|
0 commit comments