Skip to content

Commit bf17f68

Browse files
[Refine] Refine loss and metric module (#919)
* return loss dict instead of loss summation for all loss.forward * adapt all mtl module for Dict[str, Tensor] type of input losses * fix * remove 'area' in Constriant.output_keys * fix eval.py * fix code * fix examples in func.py * fix examples in func.py * Fix for MSELossWithL2Decay and train_enn.py * fix doctest in loss/mse.py * fix epnn * fix
1 parent d5f10d5 commit bf17f68

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+367
-211
lines changed

examples/RegAE/RegAE.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ def loss_expr(output_dict, label_dict, weight_dict=None):
5656
base = paddle.exp(2.0 * log_sigma) + paddle.pow(mu, 2) - 1.0 - 2.0 * log_sigma
5757
KLLoss = 0.5 * paddle.sum(base) / mu.shape[0]
5858

59-
return F.mse_loss(output_dict["decoder_z"], label_dict["p_train"]) + KLLoss
59+
return {
60+
"decode_loss": F.mse_loss(output_dict["decoder_z"], label_dict["p_train"])
61+
+ KLLoss
62+
}
6063

6164
# set constraint
6265
sup_constraint = ppsci.constraint.SupervisedConstraint(

examples/amgnet/amgnet_airfoil.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def train_mse_func(
3737
label_dict: Dict[str, "pgl.Graph"],
3838
*args,
3939
) -> paddle.Tensor:
40-
return F.mse_loss(output_dict["pred"], label_dict["label"].y)
40+
return {"pred": F.mse_loss(output_dict["pred"], label_dict["label"].y)}
4141

4242

4343
def eval_rmse_func(

examples/amgnet/amgnet_cylinder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def train_mse_func(
3737
label_dict: Dict[str, "pgl.Graph"],
3838
*args,
3939
) -> paddle.Tensor:
40-
return F.mse_loss(output_dict["pred"], label_dict["label"].y)
40+
return {"pred": F.mse_loss(output_dict["pred"], label_dict["label"].y)}
4141

4242

4343
def eval_rmse_func(

examples/cfdgcn/cfdgcn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def train_mse_func(
3333
label_dict: Dict[str, "pgl.Graph"],
3434
*args,
3535
) -> paddle.Tensor:
36-
return F.mse_loss(output_dict["pred"], label_dict["label"].y)
36+
return {"pred": F.mse_loss(output_dict["pred"], label_dict["label"].y)}
3737

3838

3939
def eval_rmse_func(

examples/deepcfd/deepcfd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def loss_expr(
243243
loss_v = (output[:, 1:2, :, :] - y[:, 1:2, :, :]) ** 2
244244
loss_p = (output[:, 2:3, :, :] - y[:, 2:3, :, :]).abs()
245245
loss = (loss_u + loss_v + loss_p) / CHANNELS_WEIGHTS
246-
return loss.sum()
246+
return {"output": loss.sum()}
247247

248248
sup_constraint = ppsci.constraint.SupervisedConstraint(
249249
{

examples/deephpms/burgers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
def pde_loss_func(output_dict, *args):
3333
losses = F.mse_loss(output_dict["f_pde"], output_dict["du_t"], "sum")
34-
return losses
34+
return {"pde": losses}
3535

3636

3737
def pde_l2_rel_func(output_dict, *args):
@@ -53,7 +53,7 @@ def boundary_loss_func(output_dict, *args):
5353

5454
losses = F.mse_loss(u_lb, u_ub, "sum")
5555
losses += F.mse_loss(du_x_lb, du_x_ub, "sum")
56-
return losses
56+
return {"boundary": losses}
5757

5858

5959
def train(cfg: DictConfig):

examples/deephpms/korteweg_de_vries.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
def pde_loss_func(output_dict, *args):
3333
losses = F.mse_loss(output_dict["f_pde"], output_dict["du_t"], "sum")
34-
return losses
34+
return {"pde": losses}
3535

3636

3737
def pde_l2_rel_func(output_dict, *args):
@@ -56,7 +56,7 @@ def boundary_loss_func(output_dict, *args):
5656
losses = F.mse_loss(u_lb, u_ub, "sum")
5757
losses += F.mse_loss(du_x_lb, du_x_ub, "sum")
5858
losses += F.mse_loss(du_xx_lb, du_xx_ub, "sum")
59-
return losses
59+
return {"boundary": losses}
6060

6161

6262
def train(cfg: DictConfig):

examples/deephpms/kuramoto_sivashinsky.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
def pde_loss_func(output_dict, *args):
3333
losses = F.mse_loss(output_dict["f_pde"], output_dict["du_t"], "sum")
34-
return losses
34+
return {"pde": losses}
3535

3636

3737
def pde_l2_rel_func(output_dict, *args):
@@ -59,7 +59,7 @@ def boundary_loss_func(output_dict, *args):
5959
losses += F.mse_loss(du_x_lb, du_x_ub, "sum")
6060
losses += F.mse_loss(du_xx_lb, du_xx_ub, "sum")
6161
losses += F.mse_loss(du_xxx_lb, du_xxx_ub, "sum")
62-
return losses
62+
return {"boundary": losses}
6363

6464

6565
def train(cfg: DictConfig):

examples/deephpms/navier_stokes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
def pde_loss_func(output_dict, *args):
3333
losses = F.mse_loss(output_dict["f_pde"], output_dict["dw_t"], "sum")
34-
return losses
34+
return {"pde": losses}
3535

3636

3737
def pde_l2_rel_func(output_dict, *args):

examples/deephpms/schrodinger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
def pde_loss_func(output_dict, *args):
3333
losses = F.mse_loss(output_dict["f_pde"], output_dict["du_t"], "sum")
3434
losses += F.mse_loss(output_dict["g_pde"], output_dict["dv_t"], "sum")
35-
return losses
35+
return {"pde": losses}
3636

3737

3838
def pde_l2_rel_func(output_dict, *args):
@@ -62,7 +62,7 @@ def boundary_loss_func(output_dict, *args):
6262
losses += F.mse_loss(v_lb, v_ub, "sum")
6363
losses += F.mse_loss(du_x_lb, du_x_ub, "sum")
6464
losses += F.mse_loss(dv_x_lb, dv_x_ub, "sum")
65-
return losses
65+
return {"boundary": losses}
6666

6767

6868
def sol_l2_rel_func(output_dict, label_dict):

0 commit comments

Comments
 (0)