
Commit a85e2a5

Log grad_norm (#1339)
Since `grad_norm` is always computed anyway and might provide useful insight into the training dynamics, I don't see a reason not to log it. `grad_norm` could also be logged to the output [here](https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/metrics.py#L400-L408); let me know if I should add that.

Terminal logging demo:

<img width="1192" alt="grad_norm_logging_demo" src="https://github.com/user-attachments/assets/f7c4721a-ec63-48e7-9d13-8ea0f4e26326" />
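For context on "always computed anyway": L2 gradient clipping has to compute the total norm before it can scale anything, so the logged value is a free byproduct. A minimal sketch of the idea (illustrative only, not torchtitan's distributed-aware `dist_utils` implementation; `clip_by_total_norm_` is a hypothetical name):

```python
import torch

def clip_by_total_norm_(grads: list[torch.Tensor], max_norm: float) -> torch.Tensor:
    # The total L2 norm must be computed before any clipping decision can be made...
    total_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g) for g in grads])
    )
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for g in grads:
            g.mul_(clip_coef)  # scale gradients in place so the total norm <= max_norm
    # ...so returning it, as torch.nn.utils.clip_grad_norm_ also does, costs nothing.
    return total_norm
```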
1 parent dc7fd23 · commit a85e2a5

File tree

3 files changed: 12 additions & 2 deletions


torchtitan/components/metrics.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -352,6 +352,7 @@ def log(
         step: int,
         global_avg_loss: float,
         global_max_loss: float,
+        grad_norm: float,
         extra_metrics: dict[str, Any] | None = None,
     ):
         assert self.num_flops_per_token > 0, "num_flops_per_token must be set"
@@ -377,6 +378,7 @@ def log(
         metrics = {
             "loss_metrics/global_avg_loss": global_avg_loss,
             "loss_metrics/global_max_loss": global_max_loss,
+            "grad_norm": grad_norm,
             "throughput(tps)": tps,
             "tflops": tflops,
             "mfu(%)": mfu,
@@ -400,6 +402,7 @@ def log(
         logger.info(
             f"{color.red}step: {step:2} "
             f"{color.green}loss: {global_avg_loss:7.4f} "
+            f"{color.orange}grad_norm: {grad_norm:7.4f} "
             f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
             f"({device_mem_stats.max_reserved_pct:.2f}%) "
             f"{color.blue}tps: {round(tps):,} "
```

torchtitan/tools/utils.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -135,6 +135,7 @@ class Color:
     cyan = "\033[36m"
     white = "\033[37m"
     reset = "\033[39m"
+    orange = "\033[38;2;180;60;0m"


 @dataclass(frozen=True)
@@ -148,6 +149,7 @@ class NoColor:
     cyan = ""
     white = ""
     reset = ""
+    orange = ""


 def check_if_feature_in_pytorch(
```
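The new `orange` entry is a 24-bit ("truecolor") ANSI escape: `\033[38;2;R;G;Bm` sets the foreground to an arbitrary RGB color (here 180, 60, 0), and the existing `reset` (`\033[39m`) restores the terminal's default foreground. A standalone sketch of how these escapes behave:

```python
ORANGE = "\033[38;2;180;60;0m"  # truecolor foreground escape: RGB(180, 60, 0)
RESET = "\033[39m"              # reset foreground to the terminal default

# Renders the text in orange on truecolor-capable terminals; others may ignore it.
print(f"{ORANGE}grad_norm:  1.2346{RESET}")
```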

torchtitan/train.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -431,7 +431,7 @@ def train_step(
             loss = self.forward_backward_step(input_dict, labels)
             accumulated_losses.append(loss.detach())

-        dist_utils.clip_grad_norm_(
+        grad_norm = dist_utils.clip_grad_norm_(
             [p for m in self.model_parts for p in m.parameters()],
             self.job_config.training.max_norm,
             foreach=True,
@@ -463,7 +463,12 @@ def train_step(
         else:
             global_avg_loss = global_max_loss = loss.detach().item()

-        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
+        self.metrics_processor.log(
+            self.step,
+            global_avg_loss,
+            global_max_loss,
+            grad_norm.item(),
+        )

     @record
     def train(self):
```
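For reference, the stock `torch.nn.utils.clip_grad_norm_` returns the total norm as a 0-dim tensor, which is presumably why the call site above needs `.item()` to turn `grad_norm` into a plain float before logging (assuming torchtitan's `dist_utils.clip_grad_norm_` follows the same convention). A minimal sketch with the plain PyTorch API:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# Clips gradients in place and returns the total (pre-clip) L2 norm as a 0-dim tensor.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
print(total_norm.item())  # .item() converts the tensor to a Python float for logging
```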
