3 files changed, +12 -2 lines changed.

This change threads the gradient norm through to metrics logging: `train_step` captures the total norm returned by `clip_grad_norm_` and passes it to `metrics_processor.log`, which records it in the metrics dict and prints it in a new orange color in the console summary line.
File 1 of 3: the metrics processor's `log` method takes a new `grad_norm` argument, records it in the metrics dict, and prints it in the console summary line.

```diff
@@ -352,6 +352,7 @@ def log(
         step: int,
         global_avg_loss: float,
         global_max_loss: float,
+        grad_norm: float,
         extra_metrics: dict[str, Any] | None = None,
     ):
         assert self.num_flops_per_token > 0, "num_flops_per_token must be set"
@@ -377,6 +378,7 @@ def log(
         metrics = {
             "loss_metrics/global_avg_loss": global_avg_loss,
             "loss_metrics/global_max_loss": global_max_loss,
+            "grad_norm": grad_norm,
             "throughput(tps)": tps,
             "tflops": tflops,
             "mfu(%)": mfu,
@@ -400,6 +402,7 @@ def log(
         logger.info(
             f"{color.red}step: {step:2}  "
             f"{color.green}loss: {global_avg_loss:7.4f}  "
+            f"{color.orange}grad_norm: {grad_norm:7.4f}  "
             f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
             f"({device_mem_stats.max_reserved_pct:.2f}%)  "
             f"{color.blue}tps: {round(tps):,}  "
```
File 2 of 3: the `Color` palette (and its `NoColor` no-op counterpart) gains an `orange` entry for the new console field.

```diff
@@ -135,6 +135,7 @@ class Color:
     cyan = "\033[36m"
     white = "\033[37m"
     reset = "\033[39m"
+    orange = "\033[38;2;180;60;0m"


 @dataclass(frozen=True)
@@ -148,6 +149,7 @@ class NoColor:
     cyan = ""
     white = ""
     reset = ""
+    orange = ""


 def check_if_feature_in_pytorch(
```
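Unlike the other entries in `Color`, which use the basic 8-color SGR codes (30-37), the new `orange` uses the 24-bit "truecolor" escape `\033[38;2;R;G;Bm`, so it only renders correctly on terminals with truecolor support. A minimal sketch:

```python
# "\033[38;2;R;G;Bm" sets a 24-bit RGB foreground color (SGR 38;2).
# 180;60;0 is RGB(180, 60, 0), a dark orange; needs a truecolor terminal.
ORANGE = "\033[38;2;180;60;0m"
RESET = "\033[39m"  # reset foreground to the terminal default, as in Color.reset
print(f"{ORANGE}grad_norm:  1.0321{RESET}")
```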
File 3 of 3: `train_step` captures the total norm returned by `dist_utils.clip_grad_norm_` and forwards it to the metrics processor.

```diff
@@ -431,7 +431,7 @@ def train_step(
             loss = self.forward_backward_step(input_dict, labels)
             accumulated_losses.append(loss.detach())

-        dist_utils.clip_grad_norm_(
+        grad_norm = dist_utils.clip_grad_norm_(
             [p for m in self.model_parts for p in m.parameters()],
             self.job_config.training.max_norm,
             foreach=True,
@@ -463,7 +463,12 @@ def train_step(
         else:
             global_avg_loss = global_max_loss = loss.detach().item()

-        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
+        self.metrics_processor.log(
+            self.step,
+            global_avg_loss,
+            global_max_loss,
+            grad_norm.item(),
+        )

     @record
     def train(self):
```
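The key enabler here is that `clip_grad_norm_` returns the total gradient norm (a tensor) computed before clipping; capturing that return value is what makes `grad_norm.item()` available for logging. A standalone sketch using plain `torch.nn.utils.clip_grad_norm_` (the diff's `dist_utils` wrapper is assumed to return the same kind of total norm, reduced across ranks):

```python
import torch

# Toy model and backward pass to populate gradients.
model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

# Returns the pre-clip total norm as a tensor; gradients are scaled in place
# so that their total norm is at most max_norm.
grad_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0, foreach=True
)
print(f"grad_norm: {grad_norm.item():7.4f}")
```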