@@ -34,7 +34,7 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
         total_loss = 0
         loss_dict = misc.Prettydefaultdict(float)
         loss_dict["loss"] = 0.0
-        total_batch_size = []
+        total_batch_size = 0
         reader_cost = 0
         batch_cost = 0
         reader_tic = time.perf_counter()
@@ -48,7 +48,7 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
             for key in solver.train_time_info:
                 solver.train_time_info[key].reset()
             reader_cost += time.perf_counter() - reader_tic
-            total_batch_size.append(next(iter(input_dict.values())).shape[0])
+            total_batch_size += next(iter(input_dict.values())).shape[0]

             for v in input_dict.values():
                 v.stop_gradient = False
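
The list-plus-`sum()` accumulator is replaced with a running integer in both epoch functions (see also the matching hunks below). A minimal sketch of the before/after behaviour; `constraint_batches` is an illustrative stand-in, not a name from this PR:

```python
# Sketch of the batch-size accumulation change (illustrative names only).
import numpy as np

constraint_batches = [np.zeros((16, 2)), np.zeros((32, 2))]

# Before: collect per-constraint sizes, then sum once at logging time.
sizes = []
for batch in constraint_batches:
    sizes.append(batch.shape[0])
total_batch_size_old = sum(sizes)  # 48

# After: keep a running integer, so no extra list or final sum() is needed.
total_batch_size = 0
for batch in constraint_batches:
    total_batch_size += batch.shape[0]

assert total_batch_size == total_batch_size_old == 48
```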
@@ -80,21 +80,20 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
             if iter_id % solver.update_freq == 0:
                 solver.scaler.minimize(solver.optimizer, total_loss_scaled)
                 solver.optimizer.clear_grad()
-                if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
-                    solver.lr_scheduler.step()
         else:
             total_loss.backward()
             if iter_id % solver.update_freq == 0:
                 solver.optimizer.step()
                 solver.optimizer.clear_grad()
-                if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
-                    solver.lr_scheduler.step()
+
+        # update learning rate by step
+        if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
+            solver.lr_scheduler.step()

         batch_cost += time.perf_counter() - batch_tic

         # update and log training information
         solver.global_step += 1
-        total_batch_size = sum(total_batch_size)
         solver.train_time_info["reader_cost"].update(reader_cost)
         solver.train_time_info["batch_cost"].update(batch_cost)
         printer.update_train_loss(solver, loss_dict, total_batch_size)
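
This hunk hoists the per-step scheduler update out of the AMP and non-AMP branches, so a step-wise scheduler now advances on every iteration rather than only when `iter_id % update_freq == 0`. A hedged sketch of that control-flow change, with a toy `Scheduler` class standing in for the real one:

```python
# Toy illustration of the unified scheduler stepping (Scheduler is a stand-in).
class Scheduler:
    def __init__(self, by_epoch: bool):
        self.by_epoch = by_epoch
        self.steps = 0

    def step(self):
        self.steps += 1

lr_scheduler = Scheduler(by_epoch=False)
update_freq = 2

for iter_id in range(1, 5):
    if iter_id % update_freq == 0:
        pass  # optimizer.step() / scaler.minimize() would happen here
    # unified location: stepped every iteration when the scheduler is per-step
    if lr_scheduler is not None and not lr_scheduler.by_epoch:
        lr_scheduler.step()

assert lr_scheduler.steps == 4  # previously this would have been 2
```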
@@ -117,23 +116,26 @@ def train_LBFGS_epoch_func(solver, epoch_id: int, log_freq: int):
     for iter_id in range(1, solver.iters_per_epoch + 1):
         loss_dict = misc.Prettydefaultdict(float)
         loss_dict["loss"] = 0.0
-        total_batch_size = []
+        total_batch_size = 0
         reader_cost = 0
         batch_cost = 0
         reader_tic = time.perf_counter()

-        input_dict_list = []
-        label_dict_list = []
-        weight_dict_list = []
+        input_dicts = []
+        label_dicts = []
+        weight_dicts = []
         for _, _constraint in solver.constraint.items():
             input_dict, label_dict, weight_dict = next(_constraint.data_iter)
             reader_cost += time.perf_counter() - reader_tic
             for v in input_dict.values():
                 v.stop_gradient = False
-            input_dict_list.append(input_dict)
-            label_dict_list.append(label_dict)
-            weight_dict_list.append(weight_dict)
-            total_batch_size.append(next(iter(input_dict.values())).shape[0])
+
+            # gather all constraint data into list
+            input_dicts.append(input_dict)
+            label_dicts.append(label_dict)
+            weight_dicts.append(weight_dict)
+            total_batch_size += next(iter(input_dict.values())).shape[0]
+            reader_tic = time.perf_counter()

         def closure():
             """Forward-backward closure function for LBFGS optimizer.
@@ -147,13 +149,14 @@ def closure():
                     _constraint.input_keys, _constraint.output_keys, solver.model
                 )
                 for output_name, output_formula in _constraint.output_expr.items():
-                    evaluator.add_target_expr(output_formula, output_name)
+                    if output_name in label_dict:
+                        evaluator.add_target_expr(output_formula, output_name)

                 # forward for every batched data dict
                 with solver.autocast_context_manager():
-                    output_dict_i = evaluator(input_dict_list[i])
+                    output_dict_i = evaluator(input_dicts[i])
                     constraint_loss = _constraint.loss(
-                        output_dict_i, label_dict_list[i], weight_dict_list[i]
+                        output_dict_i, label_dicts[i], weight_dicts[i]
                     )
                     total_loss += constraint_loss
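
The new `if output_name in label_dict` guard registers only the output expressions that actually have labels in the current batch, so unlabelled outputs are not evaluated. A minimal sketch of the guard, with plain dicts and lambdas standing in for the evaluator and expression objects:

```python
# Sketch of the label-guard added above (stand-in objects, not the real API).
output_expr = {"u": lambda d: d["x"] * 2, "v": lambda d: d["x"] * 3}
label_dict = {"u": [1.0]}  # only "u" is labelled in this batch

registered = {}
for output_name, output_formula in output_expr.items():
    if output_name in label_dict:  # skip expressions without labels
        registered[output_name] = output_formula

assert list(registered) == ["u"]  # "v" is never registered or evaluated
```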
@@ -165,8 +168,6 @@ def closure():

             return total_loss

-        reader_tic = time.perf_counter()
-
         solver.optimizer.step(closure)
         if solver.lr_scheduler is not None and not solver.lr_scheduler.by_epoch:
             solver.lr_scheduler.step()
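
`solver.optimizer.step(closure)` follows the usual L-BFGS contract: the optimizer may call the closure several times per step (e.g. during line search), and each call must recompute the loss, backpropagate, and return the loss. A toy illustration of that contract; `ToyLBFGS` is a stand-in, not Paddle's optimizer:

```python
# Toy illustration of the L-BFGS closure contract.
class ToyLBFGS:
    def step(self, closure):
        # line search may re-evaluate the objective multiple times per step
        loss = closure()
        for _ in range(3):
            loss = closure()
        return loss

calls = 0

def closure():
    global calls
    calls += 1
    # real code: clear_grad(); loss = forward(); loss.backward()
    return 0.0

ToyLBFGS().step(closure)
assert calls == 4  # one optimizer step may invoke the closure many times
```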
@@ -175,7 +176,6 @@ def closure():

         # update and log training information
         solver.global_step += 1
-        total_batch_size = sum(total_batch_size)
         solver.train_time_info["reader_cost"].update(reader_cost)
         solver.train_time_info["batch_cost"].update(batch_cost)
         printer.update_train_loss(solver, loss_dict, total_batch_size)