Skip to content

Commit 20309ca

Browse files
fix
1 parent 8d444e0 commit 20309ca

File tree

3 files changed

+3
-12
lines changed

3 files changed

+3
-12
lines changed

ppsci/loss/integral.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ def forward(self, output_dict, label_dict, weight_dict=None):
6565
)
6666
if weight_dict is not None:
6767
loss *= weight_dict[key]
68-
if isinstance(self.weight, float):
69-
loss *= self.weight
70-
elif isinstance(self.weight, dict) and key in self.weight:
71-
loss *= self.weight[key]
7268

7369
if self.reduction == "sum":
7470
loss = loss.sum()

ppsci/loss/l2.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,6 @@ def forward(self, output_dict, label_dict, weight_dict=None):
4747
loss = F.mse_loss(output_dict[key], label_dict[key], "none")
4848
if weight_dict is not None:
4949
loss *= weight_dict[key]
50-
if isinstance(self.weight, float):
51-
loss *= self.weight
52-
elif isinstance(self.weight, dict) and key in self.weight:
53-
loss *= self.weight[key]
5450

5551
if "area" in output_dict:
5652
loss *= output_dict["area"]
@@ -120,6 +116,7 @@ class L2RelLoss(base.Loss):
120116
121117
Args:
122118
reduction (Literal["mean", "sum"], optional): Specifies the reduction to apply to the output: 'mean' | 'sum'. Defaults to "mean".
119+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
123120
124121
Examples:
125122
>>> import ppsci

ppsci/loss/mse.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ def forward(self, output_dict, label_dict, weight_dict=None):
6060
loss = F.mse_loss(output_dict[key], label_dict[key], "none")
6161
if weight_dict is not None:
6262
loss *= weight_dict[key]
63-
if isinstance(self.weight, (float, int)):
64-
loss *= self.weight
65-
elif isinstance(self.weight, dict) and key in self.weight:
66-
loss *= self.weight[key]
6763

6864
if "area" in output_dict:
6965
loss *= output_dict["area"]
@@ -97,6 +93,7 @@ class MSELossWithL2Decay(MSELoss):
9793
Args:
9894
reduction (Literal["mean", "sum"], optional): Specifies the reduction to apply to the output: 'mean' | 'sum'. Defaults to "mean".
9995
regularization_dict (Optional[Dict[str, float]]): Regularization dictionary. Defaults to None.
96+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
10097
10198
Raises:
10299
ValueError: reduction should be 'mean' or 'sum'.
@@ -134,6 +131,7 @@ class PeriodicMSELoss(base.Loss):
134131
135132
Args:
136133
reduction (str, optional): Reduction method. Defaults to "mean".
134+
weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None.
137135
"""
138136

139137
def __init__(

0 commit comments

Comments
 (0)