 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
 
 
@@ -84,9 +85,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             f" examples but only has {len(dataset)}."
         )
 
+    # If running in distributed mode (N machines) then each one should skip N-1
+    # samples
+    step = mx.distributed.init().size()
+    if batch_size % step != 0:
+        raise ValueError("The batch size must be divisible by the number of workers")
+
     # Make the batches:
     batch_idx = [
-        idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size)
+        idx[i : i + batch_size : step]
+        for i in range(0, len(idx) - batch_size + 1, batch_size)
     ]
 
     while True:
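As a quick illustration of the strided slicing above (a standalone sketch, not part of the patch, using made-up sizes): with batch_size = 4 and two workers (step = 2), each slice keeps every second index, so every worker builds local batches of batch_size // step examples, matching the np.zeros((batch_size // step, ...)) allocation in the next hunk.

# Standalone sketch with hypothetical sizes.
idx = list(range(10))          # pretend these are dataset indices sorted by length
batch_size, step = 4, 2        # step = number of distributed workers

batch_idx = [
    idx[i : i + batch_size : step]
    for i in range(0, len(idx) - batch_size + 1, batch_size)
]
print(batch_idx)               # [[0, 2], [4, 6]] -> batch_size // step indices per batch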
@@ -112,9 +120,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
             max_length_in_batch = min(max_length_in_batch, max_seq_length)
 
-            batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
+            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
 
-            for j in range(batch_size):
+            for j in range(batch_size // step):
                 truncated_length = min(lengths[j], max_seq_length)
                 batch_arr[j, :truncated_length] = batch[j][:truncated_length]
                 lengths[j] = (
@@ -138,7 +146,7 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = []
+    all_losses = 0
     ntokens = 0
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@@ -153,10 +161,14 @@ def evaluate(
         ),
     ):
         losses, toks = loss(model, *batch)
-        all_losses.append((losses * toks).item())
-        ntokens += toks.item()
+        all_losses += losses * toks
+        ntokens += toks
+        mx.eval(all_losses, ntokens)
+
+    all_losses = mx.distributed.all_sum(all_losses)
+    ntokens = mx.distributed.all_sum(ntokens)
 
-    return np.sum(all_losses) / ntokens
+    return (all_losses / ntokens).item()
 
 
 class TrainingCallback:
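The evaluation change keeps per-node sums of loss * tokens and token counts as MLX arrays and then combines them with mx.distributed.all_sum, so every node ends up with the same token-weighted average. A minimal sketch of that reduction with invented per-node values (on a single process, all_sum simply returns its input):

import mlx.core as mx

# Hypothetical per-node accumulators (each node holds its own values).
local_loss_sum = mx.array(12.5)   # sum of loss * tokens on this node
local_ntokens = mx.array(50)      # tokens seen on this node

total_loss = mx.distributed.all_sum(local_loss_sum)
total_tokens = mx.distributed.all_sum(local_ntokens)
print((total_loss / total_tokens).item())  # token-weighted average loss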
@@ -182,6 +194,11 @@ def train(
     training_callback: TrainingCallback = None,
 ):
     print(f"Starting training..., iters: {args.iters}")
+    world = mx.distributed.init()
+    world_size = world.size()
+    rank = world.rank()
+    if world_size > 1:
+        print(f"Node {rank} of {world_size}")
 
     if args.grad_checkpoint:
         grad_checkpoint(model.layers[0])
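mx.distributed.init() returns the global process group, and its size() and rank() identify this process; on an ordinary single-process run it reports size 1 and rank 0, so the rank-0 guards added below become no-ops. A tiny sketch of that behaviour:

import mlx.core as mx

world = mx.distributed.init()        # global process group
print(world.size(), world.rank())    # 1 0 when run as a single process

if world.rank() == 0:
    print("only the first node prints this")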
@@ -192,15 +209,19 @@ def step(batch):
         # Forward and backward pass
         (lvalue, toks), grad = loss_value_and_grad(model, *batch)
 
+        # All reduce the gradients if running in distributed mode
+        grad = average_gradients(grad)
+
         # Model update
         optimizer.update(model, grad)
 
         return lvalue, toks
 
     loss_value_and_grad = nn.value_and_grad(model, loss)
 
-    losses = []
+    losses = 0
     n_tokens = 0
+    steps = 0
     trained_tokens = 0
     # Main training loop
     start = time.perf_counter()
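Conceptually, average_gradients (imported from mlx.nn.utils above) averages every array in the gradient tree across the group, so each node applies the same update; the real helper batches the communication for efficiency. A rough stand-in for the idea, not the library's implementation, using a hypothetical gradient dict:

import mlx.core as mx
from mlx.utils import tree_map

def naive_average_gradients(grads):
    # Conceptual stand-in: all-sum each leaf, then divide by the number of workers.
    world_size = mx.distributed.init().size()
    return tree_map(lambda g: mx.distributed.all_sum(g) / world_size, grads)

grads = {"w": mx.ones((2, 2)), "b": mx.zeros((2,))}   # hypothetical gradient tree
print(naive_average_gradients(grads)["w"])            # unchanged on a single process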
@@ -229,9 +250,13 @@ def step(batch):
                 iterate_batches=iterate_batches,
             )
             val_time = time.perf_counter() - stop
-            print(
-                f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: "
+                    f"Val loss {val_loss:.3f}, "
+                    f"Val took {val_time:.3f}s",
+                    flush=True,
+                )
 
             if training_callback is not None:
                 val_info = {
@@ -244,30 +269,33 @@ def step(batch):
             start = time.perf_counter()
 
         lvalue, toks = step(batch)
-        mx.eval(state, lvalue, toks)
-
-        # Record loss
-        losses.append(lvalue.item())
-        n_tokens += toks.item()
+        losses += lvalue
+        n_tokens += toks
+        steps += 1
+        mx.eval(state, losses, n_tokens)
 
         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
 
-            train_loss = np.mean(losses)
+            train_loss = mx.distributed.all_sum(losses).item()
+            train_loss /= steps * mx.distributed.init().size()
+            n_tokens = mx.distributed.all_sum(n_tokens).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 2**30
-            print(
-                f"Iter {it}: Train loss {train_loss:.3f}, "
-                f"Learning Rate {learning_rate:.3e}, "
-                f"It/sec {it_sec:.3f}, "
-                f"Tokens/sec {tokens_sec:.3f}, "
-                f"Trained Tokens {trained_tokens}, "
-                f"Peak mem {peak_mem:.3f} GB"
-            )
+            if rank == 0:
+                print(
+                    f"Iter {it}: Train loss {train_loss:.3f}, "
+                    f"Learning Rate {learning_rate:.3e}, "
+                    f"It/sec {it_sec:.3f}, "
+                    f"Tokens/sec {tokens_sec:.3f}, "
+                    f"Trained Tokens {trained_tokens}, "
+                    f"Peak mem {peak_mem:.3f} GB",
+                    flush=True,
+                )
 
             if training_callback is not None:
                 train_info = {
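For the training report, losses accumulates per-step mean losses, so the distributed average divides the all-summed total by steps * world_size rather than by a token count. A hedged arithmetic check with invented numbers:

# Hypothetical: 2 nodes, 5 steps each, per-node sums of per-step mean losses.
node_sums = [10.0, 12.0]          # what losses holds on each node before the reduction
steps, world_size = 5, 2

total = sum(node_sums)            # what mx.distributed.all_sum would return on every node
train_loss = total / (steps * world_size)
print(train_loss)                 # 2.2 -> average per-step loss across both nodes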
@@ -281,8 +309,9 @@ def step(batch):
                 }
                 training_callback.on_train_loss_report(train_info)
 
-            losses = []
+            losses = 0
             n_tokens = 0
+            steps = 0
             start = time.perf_counter()
 
     # Save adapter weights