Skip to content

Commit 7daa624

Browse files
Merge pull request #322 from lijialin03/develop
update train for lbfgs
2 parents 797cb06 + 33e43f7 commit 7daa624

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

ppsci/solver/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
6464
constraint_loss = _constraint.loss(output_dict, label_dict, weight_dict)
6565
total_loss += constraint_loss
6666

67-
loss_dict[_constraint.name] += float(constraint_loss)
67+
loss_dict[_constraint.name] = float(constraint_loss)
6868

6969
reader_tic = time.perf_counter()
7070

@@ -156,8 +156,9 @@ def closure():
156156
)
157157
total_loss += constraint_loss
158158

159-
loss_dict[_constraint.name] += float(constraint_loss)
159+
loss_dict[_constraint.name] = float(constraint_loss)
160160

161+
solver.optimizer.clear_grad()
161162
total_loss.backward()
162163
loss_dict["loss"] = float(total_loss)
163164

0 commit comments

Comments
 (0)