
Commit e117596

support DDP training with no_sync and fused_allreduce_gradients (#332)
* support DDP training with no_sync and fused_allreduce_gradients
* fix
* enhance run_check and sanity check
* remove autocast context manager in visu.py
* fix logger
1 parent 5ad1f19 commit e117596

File tree

ppsci/solver/eval.py
ppsci/solver/solver.py
ppsci/solver/train.py
ppsci/solver/visu.py
ppsci/utils/save_load.py

5 files changed: +149 −67 lines

ppsci/solver/eval.py

Lines changed: 7 additions & 4 deletions

@@ -55,12 +55,13 @@ def _eval_by_dataset(solver, epoch_id: int, log_freq: int) -> float:
         for key in solver.eval_time_info:
             solver.eval_time_info[key].reset()
         reader_cost = time.perf_counter() - reader_tic
-
         for v in input_dict.values():
             v.stop_gradient = False

         # forward
-        with solver.autocast_context_manager(), solver.no_grad_context_manager():
+        with solver.autocast_context_manager(
+            solver.use_amp, solver.amp_level
+        ), solver.no_grad_context_manager(solver.eval_with_no_grad):
             output_dict, validator_loss = solver.forward_helper.eval_forward(
                 _validator.output_expr,
                 input_dict,

@@ -179,11 +180,13 @@ def _eval_by_batch(solver, epoch_id: int, log_freq: int) -> float:
             solver.eval_time_info[key].reset()
         reader_cost = time.perf_counter() - reader_tic
         batch_size = next(iter(input_dict.values())).shape[0]
-
         for v in input_dict.values():
             v.stop_gradient = False
+
         # forward
-        with solver.autocast_context_manager(), solver.no_grad_context_manager():
+        with solver.autocast_context_manager(
+            solver.use_amp, solver.amp_level
+        ), solver.no_grad_context_manager(solver.eval_with_no_grad):
             output_dict, validator_loss = solver.forward_helper.eval_forward(
                 _validator.output_expr,
                 input_dict,
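
The evaluation paths above now pass the AMP and no_grad switches to the context-manager helpers explicitly instead of letting the helpers read solver state. A minimal sketch of that call-site pattern, using a SimpleNamespace stand-in for the solver (not an actual ppsci.solver.Solver):

# Sketch of the flag-driven call site used in eval; `solver` here is a stand-in
# namespace object, not a real ppsci Solver.
import contextlib
from types import SimpleNamespace

import paddle
from paddle import amp

solver = SimpleNamespace(
    model=paddle.nn.Linear(3, 1),
    use_amp=False,
    amp_level="O1",
    eval_with_no_grad=True,
)


def autocast_context_manager(enable: bool, level: str = "O1"):
    # enabled -> real AMP autocast; disabled -> do-nothing context
    return amp.auto_cast(level=level) if enable else contextlib.nullcontext()


def no_grad_context_manager(enable: bool):
    return paddle.no_grad() if enable else contextlib.nullcontext()


x = paddle.randn([8, 3])
with autocast_context_manager(solver.use_amp, solver.amp_level), no_grad_context_manager(
    solver.eval_with_no_grad
):
    y = solver.model(x)

print(y.stop_gradient)  # True: no autograd graph is recorded during evaluation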

ppsci/solver/solver.py

Lines changed: 65 additions & 11 deletions

@@ -16,6 +16,7 @@

 import contextlib
 import copy
+import itertools
 import os
 import sys
 from typing import Any

@@ -203,6 +204,16 @@ def __init__(

         # whether calculate metrics after each batch during evaluate
         self.compute_metric_by_batch = compute_metric_by_batch
+        if validator is not None:
+            for metric in itertools.chain(
+                *[_v.metric.values() for _v in self.validator.values()]
+            ):
+                if metric.keep_batch ^ compute_metric_by_batch:
+                    raise ValueError(
+                        f"{misc.typename(metric)}.keep_batch should be "
+                        f"{compute_metric_by_batch} when compute_metric_by_batch="
+                        f"{compute_metric_by_batch}."
+                    )
         # whether set `stop_gradient=True` for every Tensor if no differentiation involved during computation
         self.eval_with_no_grad = eval_with_no_grad
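
The added constructor check forces every validator metric's keep_batch flag to agree with compute_metric_by_batch, failing fast instead of producing inconsistent metrics later. A stripped-down sketch of the same XOR check with placeholder objects (DummyMetric and the validators dict are illustrative, not ppsci classes):

# Illustrative version of the keep_batch / compute_metric_by_batch sanity check.
import itertools
from dataclasses import dataclass


@dataclass
class DummyMetric:          # placeholder for a ppsci metric object
    keep_batch: bool


def check_metric_consistency(validators: dict, compute_metric_by_batch: bool) -> None:
    # walk every metric of every validator and compare its flag with the solver-level one
    for metric in itertools.chain(*[v.values() for v in validators.values()]):
        if metric.keep_batch ^ compute_metric_by_batch:  # XOR is true when the flags differ
            raise ValueError(
                f"{type(metric).__name__}.keep_batch should be "
                f"{compute_metric_by_batch} when compute_metric_by_batch="
                f"{compute_metric_by_batch}."
            )


validators = {"residual_validator": {"L2Rel": DummyMetric(keep_batch=True)}}
check_metric_consistency(validators, compute_metric_by_batch=True)   # passes silently
# check_metric_consistency(validators, compute_metric_by_batch=False)  # would raise ValueError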


@@ -247,6 +258,11 @@ def __init__(
             self.model = fleet.distributed_model(self.model)
             if self.optimizer is not None:
                 self.optimizer = fleet.distributed_optimizer(self.optimizer)
+            logger.warning(
+                f"Detected world_size({self.world_size}) > 1, it is recommended to "
+                "scale up the learning rate and reduce the epochs or "
+                "iters_per_epoch according to the world_size number both linearly."
+            )

         self.global_step = 0
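
The warning above points at the usual linear scaling rule for data-parallel training: with N ranks the effective batch size grows by N, so the learning rate is typically scaled up and the epoch/iteration budget scaled down by roughly the same factor. A rough sketch of that bookkeeping (the names below are illustrative, not a ppsci API):

# Illustrative linear-scaling arithmetic for DDP runs; not a ppsci API.
import math

base_lr = 1.0e-3            # learning rate tuned for a single GPU
base_iters_per_epoch = 1000
world_size = 4              # number of data-parallel ranks

scaled_lr = base_lr * world_size                              # 4.0e-3
scaled_iters = math.ceil(base_iters_per_epoch / world_size)   # 250

print(f"lr: {base_lr} -> {scaled_lr}, iters_per_epoch: {base_iters_per_epoch} -> {scaled_iters}")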


@@ -493,7 +509,7 @@ def predict(
                 batch_input_dict[key].stop_gradient = False

             # forward
-            with self.autocast_context_manager():
+            with self.autocast_context_manager(self.use_amp, self.amp_level):
                 batch_output_dict = self.model(batch_input_dict)

             # collect batch data

@@ -522,36 +538,74 @@ def export(self):
         jit.save(static_model, save_path)
         logger.info(f"The inference model has been exported to {export_dir}")

-    def autocast_context_manager(self) -> contextlib.AbstractContextManager:
-        """Autocast context manager for Auto Mix Precision.
+    def autocast_context_manager(
+        self, enable: bool, level: Literal["O0", "O1", "O2"] = "O1"
+    ) -> contextlib.AbstractContextManager:
+        """Smart autocast context manager for Auto Mix Precision.
+
+        Args:
+            enable (bool): Enable autocast.
+            level (Literal["O0", "O1", "O2"]): Autocast level.

         Returns:
-            Union[contextlib.AbstractContextManager]: Context manager.
+            contextlib.AbstractContextManager: Smart autocast context manager.
         """
-        if self.use_amp:
-            ctx_manager = amp.auto_cast(level=self.amp_level)
+        if enable:
+            ctx_manager = amp.auto_cast(level=level)
         else:
             ctx_manager = (
                 contextlib.nullcontext()
                 if sys.version_info >= (3, 7)
                 else contextlib.suppress()
             )
-
         return ctx_manager

-    def no_grad_context_manager(self) -> contextlib.AbstractContextManager:
-        """No grad manager.
+    def no_grad_context_manager(
+        self, enable: bool
+    ) -> contextlib.AbstractContextManager:
+        """Smart no_grad context manager.
+
+        Args:
+            enable (bool): Enable no_grad.

         Returns:
-            Union[contextlib.AbstractContextManager]: Context manager.
+            contextlib.AbstractContextManager: Smart no_grad context manager.
         """
-        if self.eval_with_no_grad:
+        if enable:
             ctx_manager = paddle.no_grad()
         else:
             ctx_manager = (
                 contextlib.nullcontext()
                 if sys.version_info >= (3, 7)
                 else contextlib.suppress()
             )
+        return ctx_manager
+
+    def no_sync_context_manager(
+        self,
+        enable: bool,
+        ddp_model: paddle.DataParallel,
+    ) -> contextlib.AbstractContextManager:
+        """Smart no_sync context manager for given model.
+        NOTE: Only `paddle.DataParallel` object has `no_sync` interface.
+
+        Args:
+            enable (bool): Enable no_sync.
+
+        Returns:
+            contextlib.AbstractContextManager: Smart no_sync context manager.
+        """
+        if enable:
+            if not isinstance(ddp_model, paddle.DataParallel):
+                raise TypeError(
+                    "no_sync interface is only for model with type paddle.DataParallel, "
+                    f"but got type {type(ddp_model)}"
+                )
+            ctx_manager = ddp_model.no_sync()
+        else:
+            ctx_manager = (
+                contextlib.nullcontext()
+                if sys.version_info >= (3, 7)
+                else contextlib.suppress()
+            )
         return ctx_manager
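
no_sync_context_manager is the piece that makes DDP gradient accumulation possible: while it is active, backward() only accumulates local gradients and paddle.DataParallel skips its automatic allreduce. A minimal re-implementation of the same guard-and-dispatch logic outside the Solver class (illustrative sketch, not ppsci code):

# Illustrative re-implementation of the no_sync dispatch; not ppsci code.
import contextlib

import paddle


def no_sync_context_manager(enable: bool, ddp_model) -> contextlib.AbstractContextManager:
    if enable:
        # only paddle.DataParallel exposes no_sync(); fail loudly for plain Layers
        if not isinstance(ddp_model, paddle.DataParallel):
            raise TypeError(
                "no_sync interface is only for model with type paddle.DataParallel, "
                f"but got type {type(ddp_model)}"
            )
        return ddp_model.no_sync()
    return contextlib.nullcontext()


plain_model = paddle.nn.Linear(4, 4)
with no_sync_context_manager(False, plain_model):  # single-card run: a do-nothing context
    pass
# no_sync_context_manager(True, plain_model) would raise TypeError, because in DDP runs
# the model is expected to have been wrapped by fleet.distributed_model / paddle.DataParallel.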

ppsci/solver/train.py

Lines changed: 71 additions & 51 deletions

@@ -14,13 +14,15 @@

 import time

+from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu
+
+from ppsci import solver
 from ppsci.solver import printer
-from ppsci.utils import expression
 from ppsci.utils import misc
 from ppsci.utils import profiler


-def train_epoch_func(solver, epoch_id: int, log_freq: int):
+def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
     """Train program for one epoch

     Args:

@@ -61,38 +63,48 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
         total_batch_size += next(iter(input_dict.values())).shape[0]
         reader_tic = time.perf_counter()

-        # forward for every constraint, including model and equation expression
-        with solver.autocast_context_manager():
-            constraint_losses = solver.forward_helper.train_forward(
-                [_constraint.output_expr for _constraint in solver.constraint.values()],
-                input_dicts,
-                solver.model,
-                solver.constraint,
-                label_dicts,
-                weight_dicts,
-            )
-
-        # compute loss for each constraint according to its' own output, label and weight
-        for i, _constraint in enumerate(solver.constraint.values()):
-            total_loss += constraint_losses[i]
-            loss_dict[_constraint.name] += float(constraint_losses[i])
-
-        if solver.update_freq > 1:
-            total_loss = total_loss / solver.update_freq
-        loss_dict["loss"] = float(total_loss)
-
-        # backward
-        if solver.use_amp:
-            total_loss_scaled = solver.scaler.scale(total_loss)
-            total_loss_scaled.backward()
-            if iter_id % solver.update_freq == 0:
+        with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
+            # forward for every constraint, including model and equation expression
+            with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
+                constraint_losses = solver.forward_helper.train_forward(
+                    [
+                        _constraint.output_expr
+                        for _constraint in solver.constraint.values()
+                    ],
+                    input_dicts,
+                    solver.model,
+                    solver.constraint,
+                    label_dicts,
+                    weight_dicts,
+                )
+            # accumulate all losses
+            for i, _constraint in enumerate(solver.constraint.values()):
+                total_loss += constraint_losses[i]
+                loss_dict[_constraint.name] += (
+                    float(constraint_losses[i]) / solver.update_freq
+                )
+            if solver.update_freq > 1:
+                total_loss = total_loss / solver.update_freq
+            loss_dict["loss"] = float(total_loss)
+
+            # backward
+            if solver.use_amp:
+                total_loss_scaled = solver.scaler.scale(total_loss)
+                total_loss_scaled.backward()
+            else:
+                total_loss.backward()
+
+        # update parameters
+        if iter_id % solver.update_freq == 0 or iter_id == solver.iters_per_epoch:
+            if solver.world_size > 1:
+                # fuse + allreduce manually before optimization if use DDP + no_sync
+                # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
+                hpu.fused_allreduce_gradients(list(solver.model.parameters()), None)
+            if solver.use_amp:
                 solver.scaler.minimize(solver.optimizer, total_loss_scaled)
-                solver.optimizer.clear_grad()
-        else:
-            total_loss.backward()
-            if iter_id % solver.update_freq == 0:
+            else:
                 solver.optimizer.step()
-                solver.optimizer.clear_grad()
+            solver.optimizer.clear_grad()

         # update learning rate by step
         if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
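
This hunk is the heart of the change: forward and backward run inside no_sync so each rank only accumulates local gradients, and a single fused allreduce is issued manually right before the optimizer update (see the linked Paddle issue). A minimal standalone sketch of the same accumulation pattern, assuming the script is started with paddle.distributed.launch; the Linear model and random data are toy placeholders:

# Toy sketch of DDP gradient accumulation with no_sync + fused_allreduce_gradients.
# Assumes launch via `python -m paddle.distributed.launch ...`; model/data are placeholders.
import contextlib

import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu

world_size = dist.get_world_size()
if world_size > 1:
    dist.init_parallel_env()

model = paddle.nn.Linear(8, 1)
if world_size > 1:
    model = paddle.DataParallel(model)
opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=1e-3)

update_freq = 4  # accumulate gradients over 4 micro-batches before each update
for step in range(1, 17):
    x, y = paddle.randn([16, 8]), paddle.randn([16, 1])

    # skip the automatic per-step allreduce while accumulating
    sync_ctx = model.no_sync() if world_size > 1 else contextlib.nullcontext()
    with sync_ctx:
        loss = paddle.nn.functional.mse_loss(model(x), y) / update_freq
        loss.backward()

    if step % update_freq == 0:
        if world_size > 1:
            # one fused allreduce over all accumulated gradients, then update
            hpu.fused_allreduce_gradients(list(model.parameters()), None)
        opt.step()
        opt.clear_grad()
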

@@ -111,7 +123,7 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
         batch_tic = time.perf_counter()


-def train_LBFGS_epoch_func(solver, epoch_id: int, log_freq: int):
+def train_LBFGS_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
     """Train function for one epoch with L-BFGS optimizer.

     Args:

@@ -152,30 +164,38 @@ def closure():
             Tensor: Computed loss.
         """
         total_loss = 0
-        for i, _constraint in enumerate(solver.constraint.values()):
-            evaluator = expression.ExpressionSolver(
-                _constraint.input_keys, _constraint.output_keys, solver.model
-            )
-            for output_name, output_formula in _constraint.output_expr.items():
-                if output_name in label_dict:
-                    evaluator.add_target_expr(output_formula, output_name)
-
-            # forward for every batched data dict
-            with solver.autocast_context_manager():
-                output_dict_i = evaluator(input_dicts[i])
-                constraint_loss = _constraint.loss(
-                    output_dict_i, label_dicts[i], weight_dicts[i]
+        with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
+            with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
+                # forward for every constraint, including model and equation expression
+                constraint_losses = solver.forward_helper.train_forward(
+                    [
+                        _constraint.output_expr
+                        for _constraint in solver.constraint.values()
+                    ],
+                    input_dicts,
+                    solver.model,
+                    solver.constraint,
+                    label_dicts,
+                    weight_dicts,
                 )
-            total_loss += constraint_loss
+            # accumulate all losses
+            for i, _constraint in enumerate(solver.constraint.values()):
+                total_loss += constraint_losses[i]
+                loss_dict[_constraint.name] = float(constraint_losses[i])
+            loss_dict["loss"] = float(total_loss)

-            loss_dict[_constraint.name] = float(constraint_loss)
+            # backward
+            solver.optimizer.clear_grad()
+            total_loss.backward()

-        solver.optimizer.clear_grad()
-        total_loss.backward()
-        loss_dict["loss"] = float(total_loss)
+        if solver.world_size > 1:
+            # fuse + allreduce manually before optimization if use DDP model
+            # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
+            hpu.fused_allreduce_gradients(list(solver.model.parameters()), None)

         return total_loss

+    # update parameters
     solver.optimizer.step(closure)

     # update learning rate by step
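
For the L-BFGS path the allreduce has to live inside the closure, because a line-search optimizer may re-evaluate the closure (and therefore recompute gradients) several times per step. The sketch below illustrates only that shape; fake_lbfgs_step is a toy stand-in for optimizer.step(closure), not a real Paddle API:

# Why backward + allreduce sit inside the closure: the optimizer may call it repeatedly.
import paddle

model = paddle.nn.Linear(4, 1)
x, y = paddle.randn([32, 4]), paddle.randn([32, 1])
world_size = 1  # would be > 1 only when launched with paddle.distributed.launch


def closure():
    model.clear_gradients()
    loss = paddle.nn.functional.mse_loss(model(x), y)
    loss.backward()
    if world_size > 1:
        # gradients must be re-synchronized on every evaluation of the closure
        from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu

        hpu.fused_allreduce_gradients(list(model.parameters()), None)
    return loss


def fake_lbfgs_step(closure_fn, n_evals=3):
    # mimic a line search that evaluates the closure several times
    return min(float(closure_fn()) for _ in range(n_evals))


print("best closure loss:", fake_lbfgs_step(closure))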

ppsci/solver/visu.py

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ def visualize_func(solver, epoch_id: int):
             batch_input_dict[key].stop_gradient = False

         # forward
-        with solver.no_grad_context_manager():
+        with solver.no_grad_context_manager(solver.eval_with_no_grad):
             batch_output_dict = solver.forward_helper.visu_forward(
                 _visualizer.output_expr, batch_input_dict, solver.model
             )

ppsci/utils/save_load.py

Lines changed: 5 additions & 0 deletions

@@ -127,6 +127,11 @@ def save_checkpoint(
     """
     if paddle.distributed.get_rank() != 0:
         return
+    if model_dir is None:
+        logger.warning(
+            f"model_dir({model_dir}) is set to None, skip save_checkpoint..."
+        )
+        return
     model_dir = os.path.join(model_dir, "checkpoints")
     os.makedirs(model_dir, exist_ok=True)
     model_path = os.path.join(model_dir, prefix)
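
The added guard turns save_checkpoint into a no-op (with a warning) when no output directory is configured, instead of failing inside os.path.join. A stripped-down sketch of the same guard pattern, standalone and using Python's logging in place of ppsci's logger:

# Minimal sketch of the save_checkpoint guards; not the ppsci implementation.
import logging
import os

import paddle
import paddle.distributed as dist

logger = logging.getLogger(__name__)


def save_checkpoint(layer: paddle.nn.Layer, model_dir, prefix: str = "latest"):
    if dist.get_rank() != 0:
        return  # in DDP runs only rank 0 writes checkpoints
    if model_dir is None:
        logger.warning(f"model_dir({model_dir}) is set to None, skip save_checkpoint...")
        return
    model_dir = os.path.join(model_dir, "checkpoints")
    os.makedirs(model_dir, exist_ok=True)
    paddle.save(layer.state_dict(), os.path.join(model_dir, prefix) + ".pdparams")


save_checkpoint(paddle.nn.Linear(2, 2), None)        # warns and returns early
save_checkpoint(paddle.nn.Linear(2, 2), "./output")  # writes ./output/checkpoints/latest.pdparams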
