 import time
 
+from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu
+
+from ppsci import solver
 from ppsci.solver import printer
-from ppsci.utils import expression
 from ppsci.utils import misc
 from ppsci.utils import profiler
 
 
-def train_epoch_func(solver, epoch_id: int, log_freq: int):
+def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
     """Train program for one epoch
 
     Args:
@@ -61,38 +63,48 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
             total_batch_size += next(iter(input_dict.values())).shape[0]
             reader_tic = time.perf_counter()
 
-        # forward for every constraint, including model and equation expression
-        with solver.autocast_context_manager():
-            constraint_losses = solver.forward_helper.train_forward(
-                [_constraint.output_expr for _constraint in solver.constraint.values()],
-                input_dicts,
-                solver.model,
-                solver.constraint,
-                label_dicts,
-                weight_dicts,
-            )
-
-        # compute loss for each constraint according to its' own output, label and weight
-        for i, _constraint in enumerate(solver.constraint.values()):
-            total_loss += constraint_losses[i]
-            loss_dict[_constraint.name] += float(constraint_losses[i])
-
-        if solver.update_freq > 1:
-            total_loss = total_loss / solver.update_freq
-        loss_dict["loss"] = float(total_loss)
-
-        # backward
-        if solver.use_amp:
-            total_loss_scaled = solver.scaler.scale(total_loss)
-            total_loss_scaled.backward()
-            if iter_id % solver.update_freq == 0:
+        with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
+            # forward for every constraint, including model and equation expression
+            with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
+                constraint_losses = solver.forward_helper.train_forward(
+                    [
+                        _constraint.output_expr
+                        for _constraint in solver.constraint.values()
+                    ],
+                    input_dicts,
+                    solver.model,
+                    solver.constraint,
+                    label_dicts,
+                    weight_dicts,
+                )
+            # accumulate all losses
+            for i, _constraint in enumerate(solver.constraint.values()):
+                total_loss += constraint_losses[i]
+                loss_dict[_constraint.name] += (
+                    float(constraint_losses[i]) / solver.update_freq
+                )
+            if solver.update_freq > 1:
+                total_loss = total_loss / solver.update_freq
+            loss_dict["loss"] = float(total_loss)
+
+            # backward
+            if solver.use_amp:
+                total_loss_scaled = solver.scaler.scale(total_loss)
+                total_loss_scaled.backward()
+            else:
+                total_loss.backward()
+
+        # update parameters
+        if iter_id % solver.update_freq == 0 or iter_id == solver.iters_per_epoch:
+            if solver.world_size > 1:
+                # fuse + allreduce manually before optimization if use DDP + no_sync
+                # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
+                hpu.fused_allreduce_gradients(list(solver.model.parameters()), None)
+            if solver.use_amp:
                 solver.scaler.minimize(solver.optimizer, total_loss_scaled)
-                solver.optimizer.clear_grad()
-        else:
-            total_loss.backward()
-            if iter_id % solver.update_freq == 0:
+            else:
                 solver.optimizer.step()
-                solver.optimizer.clear_grad()
+            solver.optimizer.clear_grad()
 
         # update learning rate by step
         if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
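The hunk above relies on a common gradient-accumulation pattern under data parallelism: backward passes run inside no_sync so the automatic per-step gradient allreduce is skipped, and gradients are fused and all-reduced manually just before the optimizer update. Below is a minimal, self-contained sketch of that pattern, for illustration only; the toy model, data, and update_freq value are placeholders, and it assumes the script is launched with python -m paddle.distributed.launch.

import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu

dist.init_parallel_env()
model = paddle.DataParallel(paddle.nn.Linear(4, 1))
opt = paddle.optimizer.Adam(parameters=model.parameters())
update_freq = 4  # number of micro-batches to accumulate per parameter update
batches = [(paddle.randn([8, 4]), paddle.randn([8, 1])) for _ in range(8)]

for step, (x, y) in enumerate(batches, start=1):
    # no_sync(): skip the automatic per-backward gradient allreduce while accumulating
    with model.no_sync():
        loss = paddle.nn.functional.mse_loss(model(x), y) / update_freq
        loss.backward()
    if step % update_freq == 0 or step == len(batches):
        # sync once per update: fuse + allreduce gradients manually, then step
        hpu.fused_allreduce_gradients(list(model.parameters()), None)
        opt.step()
        opt.clear_grad()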
@@ -111,7 +123,7 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
         batch_tic = time.perf_counter()
 
 
-def train_LBFGS_epoch_func(solver, epoch_id: int, log_freq: int):
+def train_LBFGS_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int):
     """Train function for one epoch with L-BFGS optimizer.
 
     Args:
@@ -152,30 +164,38 @@ def closure():
             Tensor: Computed loss.
             """
             total_loss = 0
-            for i, _constraint in enumerate(solver.constraint.values()):
-                evaluator = expression.ExpressionSolver(
-                    _constraint.input_keys, _constraint.output_keys, solver.model
-                )
-                for output_name, output_formula in _constraint.output_expr.items():
-                    if output_name in label_dict:
-                        evaluator.add_target_expr(output_formula, output_name)
-
-                # forward for every batched data dict
-                with solver.autocast_context_manager():
-                    output_dict_i = evaluator(input_dicts[i])
-                    constraint_loss = _constraint.loss(
-                        output_dict_i, label_dicts[i], weight_dicts[i]
+            with solver.no_sync_context_manager(solver.world_size > 1, solver.model):
+                with solver.autocast_context_manager(solver.use_amp, solver.amp_level):
+                    # forward for every constraint, including model and equation expression
+                    constraint_losses = solver.forward_helper.train_forward(
+                        [
+                            _constraint.output_expr
+                            for _constraint in solver.constraint.values()
+                        ],
+                        input_dicts,
+                        solver.model,
+                        solver.constraint,
+                        label_dicts,
+                        weight_dicts,
                     )
-                total_loss += constraint_loss
+                # accumulate all losses
+                for i, _constraint in enumerate(solver.constraint.values()):
+                    total_loss += constraint_losses[i]
+                    loss_dict[_constraint.name] = float(constraint_losses[i])
+                loss_dict["loss"] = float(total_loss)
 
-                loss_dict[_constraint.name] = float(constraint_loss)
+                # backward
+                solver.optimizer.clear_grad()
+                total_loss.backward()
 
-            solver.optimizer.clear_grad()
-            total_loss.backward()
-            loss_dict["loss"] = float(total_loss)
+            if solver.world_size > 1:
+                # fuse + allreduce manually before optimization if use DDP model
+                # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
+                hpu.fused_allreduce_gradients(list(solver.model.parameters()), None)
 
             return total_loss
 
+        # update parameters
         solver.optimizer.step(closure)
 
         # update learning rate by step
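For reference, the closure handed to solver.optimizer.step(closure) follows the standard contract for L-BFGS-style optimizers: each call must clear gradients, redo the forward pass, backpropagate, and return the loss, because the optimizer may evaluate it several times per step during line search. A small stand-alone sketch of that contract follows; the toy model and data are placeholders, and depending on the Paddle version the optimizer class may live at paddle.incubate.optimizer.LBFGS rather than paddle.optimizer.LBFGS.

import paddle

model = paddle.nn.Linear(2, 1)
x = paddle.randn([16, 2])
y = paddle.randn([16, 1])
optimizer = paddle.optimizer.LBFGS(learning_rate=1.0, parameters=model.parameters())

def closure():
    # L-BFGS may call this several times per step (line search),
    # so it must redo clear_grad -> forward -> backward and return the loss
    optimizer.clear_grad()
    loss = paddle.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

for _ in range(5):
    optimizer.step(closure)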