Skip to content

Commit 8d444e0

Browse files
refine weight arg in loss
1 parent a00eeee commit 8d444e0

File tree

5 files changed

+97
-25
lines changed

5 files changed

+97
-25
lines changed

ppsci/loss/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717
from typing import Union
1818

1919
import paddle.nn as nn
20+
from typing_extensions import Literal
2021

2122

2223
class Loss(nn.Layer):
2324
"""Base class for loss."""
2425

2526
def __init__(
26-
self, reduction: str, weight: Optional[Union[Dict[str, float], float]] = None
27+
self,
28+
reduction: Literal["mean", "sum"],
29+
weight: Optional[Union[float, Dict[str, float]]] = None,
2730
):
2831
super().__init__()
2932
self.reduction = reduction

ppsci/loss/integral.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Union
1818

1919
import paddle.nn.functional as F
20+
from typing_extensions import Literal
2021

2122
from ppsci.loss import base
2223

@@ -36,7 +37,7 @@ class IntegralLoss(base.Loss):
3637
3738
Args:
3839
reduction (str, optional): Reduction method. Defaults to "mean".
39-
weight (Optional[Union[Dict[str, float], float]]): Weight for loss. Defaults to None.
40+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
4041
4142
Examples:
4243
>>> import ppsci
@@ -45,8 +46,8 @@ class IntegralLoss(base.Loss):
4546

4647
def __init__(
4748
self,
48-
reduction: str = "mean",
49-
weight: Optional[Union[Dict[str, float], float]] = None,
49+
reduction: Literal["mean", "sum"] = "mean",
50+
weight: Optional[Union[float, Dict[str, float]]] = None,
5051
):
5152
if reduction not in ["mean", "sum"]:
5253
raise ValueError(
@@ -73,5 +74,11 @@ def forward(self, output_dict, label_dict, weight_dict=None):
7374
loss = loss.sum()
7475
elif self.reduction == "mean":
7576
loss = loss.mean()
77+
78+
if isinstance(self.weight, float):
79+
loss *= self.weight
80+
elif isinstance(self.weight, dict) and key in self.weight:
81+
loss *= self.weight[key]
82+
7683
losses += loss
7784
return losses

ppsci/loss/l1.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Union
1818

1919
import paddle.nn.functional as F
20+
from typing_extensions import Literal
2021

2122
from ppsci.loss import base
2223

@@ -34,7 +35,7 @@ class L1Loss(base.Loss):
3435
3536
Args:
3637
reduction (str, optional): Reduction method. Defaults to "mean".
37-
weight (Optional[Union[Dict[str, float], float]]): Weight for loss. Defaults to None.
38+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
3839
3940
Examples:
4041
>>> import ppsci
@@ -43,8 +44,8 @@ class L1Loss(base.Loss):
4344

4445
def __init__(
4546
self,
46-
reduction: str = "mean",
47-
weight: Optional[Union[Dict[str, float], float]] = None,
47+
reduction: Literal["mean", "sum"] = "mean",
48+
weight: Optional[Union[float, Dict[str, float]]] = None,
4849
):
4950
if reduction not in ["mean", "sum"]:
5051
raise ValueError(
@@ -70,6 +71,12 @@ def forward(self, output_dict, label_dict, weight_dict=None):
7071
loss = loss.sum()
7172
elif self.reduction == "mean":
7273
loss = loss.mean()
74+
75+
if isinstance(self.weight, float):
76+
loss *= self.weight
77+
elif isinstance(self.weight, dict) and key in self.weight:
78+
loss *= self.weight[key]
79+
7380
losses += loss
7481
return losses
7582

@@ -81,13 +88,16 @@ class PeriodicL1Loss(base.Loss):
8188
reduction (str, optional): Reduction method. Defaults to "mean".
8289
"""
8390

84-
def __init__(self, reduction="mean"):
85-
super().__init__()
91+
def __init__(
92+
self,
93+
reduction: Literal["mean", "sum"] = "mean",
94+
weight: Optional[Union[float, Dict[str, float]]] = None,
95+
):
8696
if reduction not in ["mean", "sum"]:
8797
raise ValueError(
8898
f"reduction should be 'mean' or 'sum', but got {reduction}"
8999
)
90-
self.reduction = reduction
100+
super().__init__(reduction, weight)
91101

92102
def forward(self, output_dict, label_dict, weight_dict=None):
93103
losses = 0.0
@@ -111,5 +121,11 @@ def forward(self, output_dict, label_dict, weight_dict=None):
111121
loss = loss.sum()
112122
elif self.reduction == "mean":
113123
loss = loss.mean()
124+
125+
if isinstance(self.weight, float):
126+
loss *= self.weight
127+
elif isinstance(self.weight, dict) and key in self.weight:
128+
loss *= self.weight[key]
129+
114130
losses += loss
115131
return losses

ppsci/loss/l2.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ class L2Loss(base.Loss):
3131
$$
3232
3333
Args:
34-
weight (Optional[Union[Dict[str, float], float]]): Weight for loss. Defaults to None.
34+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
3535
3636
Examples:
3737
>>> import ppsci
3838
>>> loss = ppsci.loss.L2Loss()
3939
"""
4040

41-
def __init__(self, weight: Optional[Union[Dict[str, float], float]] = None):
41+
def __init__(self, weight: Optional[Union[float, Dict[str, float]]] = None):
4242
super().__init__("sum", weight)
4343

4444
def forward(self, output_dict, label_dict, weight_dict=None):
@@ -56,20 +56,28 @@ def forward(self, output_dict, label_dict, weight_dict=None):
5656
loss *= output_dict["area"]
5757

5858
loss = loss.sum()
59+
if isinstance(self.weight, float):
60+
loss *= self.weight
61+
elif isinstance(self.weight, dict) and key in self.weight:
62+
loss *= self.weight[key]
63+
5964
losses += loss
6065
return losses
6166

6267

6368
class PeriodicL2Loss(base.Loss):
6469
"""Class for Periodic l2 loss."""
6570

66-
def __init__(self, reduction="mean"):
67-
super().__init__()
71+
def __init__(
72+
self,
73+
reduction: Literal["mean", "sum"] = "mean",
74+
weight: Optional[Union[float, Dict[str, float]]] = None,
75+
):
6876
if reduction not in ["mean", "sum"]:
6977
raise ValueError(
7078
f"reduction should be 'mean' or 'sum', but got {reduction}"
7179
)
72-
self.reduction = reduction
80+
super().__init__(reduction, weight)
7381

7482
def forward(self, output_dict, label_dict, weight_dict=None):
7583
losses = 0.0
@@ -90,6 +98,11 @@ def forward(self, output_dict, label_dict, weight_dict=None):
9098
loss *= output_dict["area"]
9199

92100
loss = loss.sum()
101+
if isinstance(self.weight, float):
102+
loss *= self.weight
103+
elif isinstance(self.weight, dict) and key in self.weight:
104+
loss *= self.weight[key]
105+
93106
losses += loss
94107
return losses
95108

@@ -107,14 +120,22 @@ class L2RelLoss(base.Loss):
107120
108121
Args:
109122
reduction (Literal["mean", "sum"], optional): Specifies the reduction to apply to the output: 'mean' | 'sum'. Defaults to "mean".
110-
123+
111124
Examples:
112125
>>> import ppsci
113126
>>> loss = ppsci.loss.L2RelLoss()
114127
"""
115128

116-
def __init__(self, reduction: Literal["mean", "sum"] = "mean"):
117-
super().__init__(reduction)
129+
def __init__(
130+
self,
131+
reduction: Literal["mean", "sum"] = "mean",
132+
weight: Optional[Union[float, Dict[str, float]]] = None,
133+
):
134+
if reduction not in ["mean", "sum"]:
135+
raise ValueError(
136+
f"reduction should be 'mean' or 'sum', but got {reduction}"
137+
)
138+
super().__init__(reduction, weight)
118139

119140
def rel_loss(self, x, y):
120141
batch_size = x.shape[0]
@@ -134,6 +155,12 @@ def forward(self, output_dict, label_dict, weight_dict=None):
134155
loss = loss.sum()
135156
elif self.reduction == "mean":
136157
loss = loss.mean()
158+
159+
if isinstance(self.weight, float):
160+
loss *= self.weight
161+
elif isinstance(self.weight, dict) and key in self.weight:
162+
loss *= self.weight[key]
163+
137164
losses += loss
138165

139166
return losses

ppsci/loss/mse.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class MSELoss(base.Loss):
3636
3737
Args:
3838
reduction (str, optional): Reduction method. Defaults to "mean".
39-
weight (Optional[Union[Dict[str, float], float]]): Weight for loss. Defaults to None.
39+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
4040
4141
Examples:
4242
>>> import ppsci
@@ -45,8 +45,8 @@ class MSELoss(base.Loss):
4545

4646
def __init__(
4747
self,
48-
reduction: str = "mean",
49-
weight: Optional[Union[Dict[str, float], float]] = None,
48+
reduction: Literal["mean", "sum"] = "mean",
49+
weight: Optional[Union[float, Dict[str, float]]] = None,
5050
):
5151
if reduction not in ["mean", "sum"]:
5252
raise ValueError(
@@ -72,6 +72,11 @@ def forward(self, output_dict, label_dict, weight_dict=None):
7272
loss = loss.sum()
7373
elif self.reduction == "mean":
7474
loss = loss.mean()
75+
if isinstance(self.weight, float):
76+
loss *= self.weight
77+
elif isinstance(self.weight, dict) and key in self.weight:
78+
loss *= self.weight[key]
79+
7580
losses += loss
7681
return losses
7782

@@ -105,8 +110,13 @@ def __init__(
105110
self,
106111
reduction: Literal["mean", "sum"] = "mean",
107112
regularization_dict: Optional[Dict[str, float]] = None,
113+
weight: Optional[Union[float, Dict[str, float]]] = None,
108114
):
109-
super().__init__(reduction)
115+
if reduction not in ["mean", "sum"]:
116+
raise ValueError(
117+
f"reduction should be 'mean' or 'sum', but got {reduction}"
118+
)
119+
super().__init__(reduction, weight)
110120
self.regularization_dict = regularization_dict
111121

112122
def forward(self, output_dict, label_dict, weight_dict=None):
@@ -126,13 +136,16 @@ class PeriodicMSELoss(base.Loss):
126136
reduction (str, optional): Reduction method. Defaults to "mean".
127137
"""
128138

129-
def __init__(self, reduction="mean"):
130-
super().__init__()
139+
def __init__(
140+
self,
141+
reduction: Literal["mean", "sum"] = "mean",
142+
weight: Optional[Union[float, Dict[str, float]]] = None,
143+
):
131144
if reduction not in ["mean", "sum"]:
132145
raise ValueError(
133146
f"reduction should be 'mean' or 'sum', but got {reduction}"
134147
)
135-
self.reduction = reduction
148+
super().__init__(reduction, weight)
136149

137150
def forward(self, output_dict, label_dict, weight_dict=None):
138151
losses = 0.0
@@ -156,5 +169,11 @@ def forward(self, output_dict, label_dict, weight_dict=None):
156169
loss = loss.sum()
157170
elif self.reduction == "mean":
158171
loss = loss.mean()
172+
173+
if isinstance(self.weight, float):
174+
loss *= self.weight
175+
elif isinstance(self.weight, dict) and key in self.weight:
176+
loss *= self.weight[key]
177+
159178
losses += loss
160179
return losses

0 commit comments

Comments
 (0)