@@ -59,7 +59,7 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
59
59
evaluator .add_target_expr (output_formula , output_name )
60
60
61
61
# forward for every constraint
62
- with solver ._autocast_context_manager ():
62
+ with solver .autocast_context_manager ():
63
63
output_dict = evaluator (input_dict )
64
64
constraint_loss = _constraint .loss (output_dict , label_dict , weight_dict )
65
65
total_loss += constraint_loss
@@ -114,16 +114,16 @@ def train_LBFGS_epoch_func(solver, epoch_id: int, log_freq: int):
114
114
batch_tic = time .perf_counter ()
115
115
116
116
for iter_id in range (1 , solver .iters_per_epoch + 1 ):
117
- reader_cost = 0
118
- batch_cost = 0
119
117
loss_dict = misc .Prettydefaultdict (float )
120
118
loss_dict ["loss" ] = 0.0
119
+ total_batch_size = []
120
+ reader_cost = 0
121
+ batch_cost = 0
122
+ reader_tic = time .perf_counter ()
123
+
121
124
input_dict_list = []
122
125
label_dict_list = []
123
126
weight_dict_list = []
124
- batch_cost = 0
125
- total_batch_size = []
126
- reader_tic = time .perf_counter ()
127
127
for _ , _constraint in solver .constraint .items ():
128
128
input_dict , label_dict , weight_dict = next (_constraint .data_iter )
129
129
reader_cost += time .perf_counter () - reader_tic
@@ -133,10 +133,9 @@ def train_LBFGS_epoch_func(solver, epoch_id: int, log_freq: int):
133
133
label_dict_list .append (label_dict )
134
134
weight_dict_list .append (weight_dict )
135
135
total_batch_size .append (next (iter (input_dict .values ())).shape [0 ])
136
- total_batch_size = sum (total_batch_size )
137
136
138
137
def closure ():
139
- """Closure function for LBFGS optimizer.
138
+ """Forward-backward closure function for LBFGS optimizer.
140
139
141
140
Returns:
142
141
Tensor: Computed loss.
@@ -149,21 +148,25 @@ def closure():
149
148
for output_name , output_formula in _constraint .output_expr .items ():
150
149
evaluator .add_target_expr (output_formula , output_name )
151
150
152
- # forward for every constraint
153
- output_dict_i = evaluator (input_dict_list [i ])
154
- constraint_loss = _constraint .loss (
155
- output_dict_i , label_dict_list [i ], weight_dict_list [i ]
156
- )
157
- total_loss += constraint_loss
151
+ # forward for every batched data dict
152
+ with solver .autocast_context_manager ():
153
+ output_dict_i = evaluator (input_dict_list [i ])
154
+ constraint_loss = _constraint .loss (
155
+ output_dict_i , label_dict_list [i ], weight_dict_list [i ]
156
+ )
157
+ total_loss += constraint_loss
158
158
159
159
loss_dict [_constraint .name ] += float (constraint_loss )
160
160
161
161
total_loss .backward ()
162
+ loss_dict ["loss" ] = float (total_loss )
162
163
163
164
return total_loss
164
165
166
+ reader_tic = time .perf_counter ()
167
+
165
168
solver .optimizer .step (closure )
166
- if not getattr ( solver .lr_scheduler , " by_epoch" , False ) :
169
+ if solver . lr_scheduler is not None and not solver .lr_scheduler . by_epoch :
167
170
solver .lr_scheduler .step ()
168
171
169
172
batch_cost += time .perf_counter () - batch_tic
0 commit comments