Skip to content

Commit 407551f

Browse files
Merge pull request #318 from HydrogenSulfate/fix_loss_weight
refine weight arg in loss
2 parents 17d315c + f155452 commit 407551f

File tree

6 files changed

+107
-45
lines changed

6 files changed

+107
-45
lines changed

ppsci/loss/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717
from typing import Union
1818

1919
import paddle.nn as nn
20+
from typing_extensions import Literal
2021

2122

2223
class Loss(nn.Layer):
2324
"""Base class for loss."""
2425

2526
def __init__(
26-
self, reduction: str, weight: Optional[Union[Dict[str, float], float]] = None
27+
self,
28+
reduction: Literal["mean", "sum"],
29+
weight: Optional[Union[float, Dict[str, float]]] = None,
2730
):
2831
super().__init__()
2932
self.reduction = reduction

ppsci/loss/integral.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Union
1818

1919
import paddle.nn.functional as F
20+
from typing_extensions import Literal
2021

2122
from ppsci.loss import base
2223

@@ -35,8 +36,8 @@ class IntegralLoss(base.Loss):
3536
$M$ is the number of samples in monte carlo integration.
3637
3738
Args:
38-
reduction (str, optional): Reduction method. Defaults to "mean".
39-
weight (Optional[Union[Dict[str, float], float]]): Weight for loss. Defaults to None.
39+
reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
40+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
4041
4142
Examples:
4243
>>> import ppsci
@@ -45,8 +46,8 @@ class IntegralLoss(base.Loss):
4546

4647
def __init__(
4748
self,
48-
reduction: str = "mean",
49-
weight: Optional[Union[Dict[str, float], float]] = None,
49+
reduction: Literal["mean", "sum"] = "mean",
50+
weight: Optional[Union[float, Dict[str, float]]] = None,
5051
):
5152
if reduction not in ["mean", "sum"]:
5253
raise ValueError(
@@ -64,14 +65,16 @@ def forward(self, output_dict, label_dict, weight_dict=None):
6465
)
6566
if weight_dict:
6667
loss *= weight_dict[key]
67-
if isinstance(self.weight, float):
68-
loss *= self.weight
69-
elif isinstance(self.weight, dict) and key in self.weight:
70-
loss *= self.weight[key]
7168

7269
if self.reduction == "sum":
7370
loss = loss.sum()
7471
elif self.reduction == "mean":
7572
loss = loss.mean()
73+
74+
if isinstance(self.weight, float):
75+
loss *= self.weight
76+
elif isinstance(self.weight, dict) and key in self.weight:
77+
loss *= self.weight[key]
78+
7679
losses += loss
7780
return losses

ppsci/loss/l1.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Union
1818

1919
import paddle.nn.functional as F
20+
from typing_extensions import Literal
2021

2122
from ppsci.loss import base
2223

@@ -33,8 +34,8 @@ class L1Loss(base.Loss):
3334
$$
3435
3536
Args:
36-
reduction (str, optional): Reduction method. Defaults to "mean".
37-
weight (Optional[Union[Dict[str, float], float]]): Weight for loss. Defaults to None.
37+
reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
38+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
3839
3940
Examples:
4041
>>> import ppsci
@@ -43,8 +44,8 @@ class L1Loss(base.Loss):
4344

4445
def __init__(
4546
self,
46-
reduction: str = "mean",
47-
weight: Optional[Union[Dict[str, float], float]] = None,
47+
reduction: Literal["mean", "sum"] = "mean",
48+
weight: Optional[Union[float, Dict[str, float]]] = None,
4849
):
4950
if reduction not in ["mean", "sum"]:
5051
raise ValueError(
@@ -70,6 +71,12 @@ def forward(self, output_dict, label_dict, weight_dict=None):
7071
loss = loss.sum()
7172
elif self.reduction == "mean":
7273
loss = loss.mean()
74+
75+
if isinstance(self.weight, float):
76+
loss *= self.weight
77+
elif isinstance(self.weight, dict) and key in self.weight:
78+
loss *= self.weight[key]
79+
7380
losses += loss
7481
return losses
7582

@@ -78,16 +85,19 @@ class PeriodicL1Loss(base.Loss):
7885
"""Class for periodic l1 loss.
7986
8087
Args:
81-
reduction (str, optional): Reduction method. Defaults to "mean".
88+
reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
8289
"""
8390

84-
def __init__(self, reduction="mean"):
85-
super().__init__()
91+
def __init__(
92+
self,
93+
reduction: Literal["mean", "sum"] = "mean",
94+
weight: Optional[Union[float, Dict[str, float]]] = None,
95+
):
8696
if reduction not in ["mean", "sum"]:
8797
raise ValueError(
8898
f"reduction should be 'mean' or 'sum', but got {reduction}"
8999
)
90-
self.reduction = reduction
100+
super().__init__(reduction, weight)
91101

92102
def forward(self, output_dict, label_dict, weight_dict=None):
93103
losses = 0.0
@@ -111,5 +121,11 @@ def forward(self, output_dict, label_dict, weight_dict=None):
111121
loss = loss.sum()
112122
elif self.reduction == "mean":
113123
loss = loss.mean()
124+
125+
if isinstance(self.weight, float):
126+
loss *= self.weight
127+
elif isinstance(self.weight, dict) and key in self.weight:
128+
loss *= self.weight[key]
129+
114130
losses += loss
115131
return losses

ppsci/loss/l2.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ class L2Loss(base.Loss):
3131
$$
3232
3333
Args:
34-
weight (Optional[Union[Dict[str, float], float]]): Weight for loss. Defaults to None.
34+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
3535
3636
Examples:
3737
>>> import ppsci
3838
>>> loss = ppsci.loss.L2Loss()
3939
"""
4040

41-
def __init__(self, weight: Optional[Union[Dict[str, float], float]] = None):
41+
def __init__(self, weight: Optional[Union[float, Dict[str, float]]] = None):
4242
super().__init__("sum", weight)
4343

4444
def forward(self, output_dict, label_dict, weight_dict=None):
@@ -47,29 +47,33 @@ def forward(self, output_dict, label_dict, weight_dict=None):
4747
loss = F.mse_loss(output_dict[key], label_dict[key], "none")
4848
if weight_dict is not None:
4949
loss *= weight_dict[key]
50-
if isinstance(self.weight, float):
51-
loss *= self.weight
52-
elif isinstance(self.weight, dict) and key in self.weight:
53-
loss *= self.weight[key]
5450

5551
if "area" in output_dict:
5652
loss *= output_dict["area"]
5753

5854
loss = loss.sum()
55+
if isinstance(self.weight, float):
56+
loss *= self.weight
57+
elif isinstance(self.weight, dict) and key in self.weight:
58+
loss *= self.weight[key]
59+
5960
losses += loss
6061
return losses
6162

6263

6364
class PeriodicL2Loss(base.Loss):
6465
"""Class for Periodic l2 loss."""
6566

66-
def __init__(self, reduction="mean"):
67-
super().__init__()
67+
def __init__(
68+
self,
69+
reduction: Literal["mean", "sum"] = "mean",
70+
weight: Optional[Union[float, Dict[str, float]]] = None,
71+
):
6872
if reduction not in ["mean", "sum"]:
6973
raise ValueError(
7074
f"reduction should be 'mean' or 'sum', but got {reduction}"
7175
)
72-
self.reduction = reduction
76+
super().__init__(reduction, weight)
7377

7478
def forward(self, output_dict, label_dict, weight_dict=None):
7579
losses = 0.0
@@ -90,6 +94,11 @@ def forward(self, output_dict, label_dict, weight_dict=None):
9094
loss *= output_dict["area"]
9195

9296
loss = loss.sum()
97+
if isinstance(self.weight, float):
98+
loss *= self.weight
99+
elif isinstance(self.weight, dict) and key in self.weight:
100+
loss *= self.weight[key]
101+
93102
losses += loss
94103
return losses
95104

@@ -100,21 +109,30 @@ class L2RelLoss(base.Loss):
100109
$$
101110
L =
102111
\begin{cases}
103-
\dfrac{1}{N}\sum\limits_{i=1}^{N}{\dfrac{\Vert \bm{X_i}-\bm{Y_i}\Vert_2}{\Vert \bm{Y_i}\Vert_2}}, & \text{if reduction='mean'} \\
104-
\sum\limits_{i=1}^{N}{\dfrac{\Vert \bm{X_i}-\bm{Y_i}\Vert_2}{\Vert \bm{Y_i}\Vert_2}}, & \text{if reduction='sum'}
112+
\dfrac{1}{N}\sum\limits_{i=1}^{N}{\dfrac{\Vert \mathbf{X_i}-\mathbf{Y_i}\Vert_2}{\Vert \mathbf{Y_i}\Vert_2}}, & \text{if reduction='mean'} \\
113+
\sum\limits_{i=1}^{N}{\dfrac{\Vert \mathbf{X_i}-\mathbf{Y_i}\Vert_2}{\Vert \mathbf{Y_i}\Vert_2}}, & \text{if reduction='sum'}
105114
\end{cases}
106115
$$
107116
108117
Args:
109118
reduction (Literal["mean", "sum"], optional): Specifies the reduction to apply to the output: 'mean' | 'sum'. Defaults to "mean".
110-
119+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
120+
111121
Examples:
112122
>>> import ppsci
113123
>>> loss = ppsci.loss.L2RelLoss()
114124
"""
115125

116-
def __init__(self, reduction: Literal["mean", "sum"] = "mean"):
117-
super().__init__(reduction)
126+
def __init__(
127+
self,
128+
reduction: Literal["mean", "sum"] = "mean",
129+
weight: Optional[Union[float, Dict[str, float]]] = None,
130+
):
131+
if reduction not in ["mean", "sum"]:
132+
raise ValueError(
133+
f"reduction should be 'mean' or 'sum', but got {reduction}"
134+
)
135+
super().__init__(reduction, weight)
118136

119137
def rel_loss(self, x, y):
120138
batch_size = x.shape[0]
@@ -134,6 +152,12 @@ def forward(self, output_dict, label_dict, weight_dict=None):
134152
loss = loss.sum()
135153
elif self.reduction == "mean":
136154
loss = loss.mean()
155+
156+
if isinstance(self.weight, float):
157+
loss *= self.weight
158+
elif isinstance(self.weight, dict) and key in self.weight:
159+
loss *= self.weight[key]
160+
137161
losses += loss
138162

139163
return losses

ppsci/loss/mse.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ class MSELoss(base.Loss):
3535
$$
3636
3737
Args:
38-
reduction (str, optional): Reduction method. Defaults to "mean".
39-
weight (Optional[Union[Dict[str, float], float]]): Weight for loss. Defaults to None.
38+
reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
39+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
4040
4141
Examples:
4242
>>> import ppsci
@@ -45,8 +45,8 @@ class MSELoss(base.Loss):
4545

4646
def __init__(
4747
self,
48-
reduction: str = "mean",
49-
weight: Optional[Union[Dict[str, float], float]] = None,
48+
reduction: Literal["mean", "sum"] = "mean",
49+
weight: Optional[Union[float, Dict[str, float]]] = None,
5050
):
5151
if reduction not in ["mean", "sum"]:
5252
raise ValueError(
@@ -60,10 +60,6 @@ def forward(self, output_dict, label_dict, weight_dict=None):
6060
loss = F.mse_loss(output_dict[key], label_dict[key], "none")
6161
if weight_dict:
6262
loss *= weight_dict[key]
63-
if isinstance(self.weight, (float, int)):
64-
loss *= self.weight
65-
elif isinstance(self.weight, dict) and key in self.weight:
66-
loss *= self.weight[key]
6763

6864
if "area" in output_dict:
6965
loss *= output_dict["area"]
@@ -72,6 +68,11 @@ def forward(self, output_dict, label_dict, weight_dict=None):
7268
loss = loss.sum()
7369
elif self.reduction == "mean":
7470
loss = loss.mean()
71+
if isinstance(self.weight, float):
72+
loss *= self.weight
73+
elif isinstance(self.weight, dict) and key in self.weight:
74+
loss *= self.weight[key]
75+
7576
losses += loss
7677
return losses
7778

@@ -92,6 +93,7 @@ class MSELossWithL2Decay(MSELoss):
9293
Args:
9394
reduction (Literal["mean", "sum"], optional): Specifies the reduction to apply to the output: 'mean' | 'sum'. Defaults to "mean".
9495
regularization_dict (Optional[Dict[str, float]]): Regularization dictionary. Defaults to None.
96+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
9597
9698
Raises:
9799
ValueError: reduction should be 'mean' or 'sum'.
@@ -105,8 +107,13 @@ def __init__(
105107
self,
106108
reduction: Literal["mean", "sum"] = "mean",
107109
regularization_dict: Optional[Dict[str, float]] = None,
110+
weight: Optional[Union[float, Dict[str, float]]] = None,
108111
):
109-
super().__init__(reduction)
112+
if reduction not in ["mean", "sum"]:
113+
raise ValueError(
114+
f"reduction should be 'mean' or 'sum', but got {reduction}"
115+
)
116+
super().__init__(reduction, weight)
110117
self.regularization_dict = regularization_dict
111118

112119
def forward(self, output_dict, label_dict, weight_dict=None):
@@ -123,16 +130,20 @@ class PeriodicMSELoss(base.Loss):
123130
"""Class for periodic mean squared error loss.
124131
125132
Args:
126-
reduction (str, optional): Reduction method. Defaults to "mean".
133+
reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
134+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
127135
"""
128136

129-
def __init__(self, reduction="mean"):
130-
super().__init__()
137+
def __init__(
138+
self,
139+
reduction: Literal["mean", "sum"] = "mean",
140+
weight: Optional[Union[float, Dict[str, float]]] = None,
141+
):
131142
if reduction not in ["mean", "sum"]:
132143
raise ValueError(
133144
f"reduction should be 'mean' or 'sum', but got {reduction}"
134145
)
135-
self.reduction = reduction
146+
super().__init__(reduction, weight)
136147

137148
def forward(self, output_dict, label_dict, weight_dict=None):
138149
losses = 0.0
@@ -156,5 +167,11 @@ def forward(self, output_dict, label_dict, weight_dict=None):
156167
loss = loss.sum()
157168
elif self.reduction == "mean":
158169
loss = loss.mean()
170+
171+
if isinstance(self.weight, float):
172+
loss *= self.weight
173+
elif isinstance(self.weight, dict) and key in self.weight:
174+
loss *= self.weight[key]
175+
159176
losses += loss
160177
return losses

ppsci/visualize/visualizer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,6 @@ class Visualizer3D(base.Visualizer):
292292
output_expr (Dict[str, Callable]): Output expression.
293293
batch_size (int, optional): Batch size of data when computing result in visu.py. Defaults to 64.
294294
label_dict (Dict[str, np.ndarray]): Label dict.
295-
transforms (Optional[Dict[str, ]]): Transformer dict.
296295
time_list (Optional[Tuple[float, ...]]): Time list.
297296
num_timestamps (int, optional): Number of timestamps.
298297
prefix (str, optional): Prefix for output file.

0 commit comments

Comments
 (0)