Skip to content

Commit f26dcf5

Browse files
committed
allow int or float loss weight type
1 parent bac040b commit f26dcf5

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed

ppsci/loss/integral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
7171
elif self.reduction == "mean":
7272
loss = loss.mean()
7373

74-
if isinstance(self.weight, float):
74+
if isinstance(self.weight, (float, int)):
7575
loss *= self.weight
7676
elif isinstance(self.weight, dict) and key in self.weight:
7777
loss *= self.weight[key]

ppsci/loss/l1.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
5959
loss = F.l1_loss(output_dict[key], label_dict[key], "none")
6060
if weight_dict:
6161
loss *= weight_dict[key]
62-
if isinstance(self.weight, float):
62+
if isinstance(self.weight, (float, int)):
6363
loss *= self.weight
6464
elif isinstance(self.weight, dict) and key in self.weight:
6565
loss *= self.weight[key]
@@ -72,7 +72,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
7272
elif self.reduction == "mean":
7373
loss = loss.mean()
7474

75-
if isinstance(self.weight, float):
75+
if isinstance(self.weight, (float, int)):
7676
loss *= self.weight
7777
elif isinstance(self.weight, dict) and key in self.weight:
7878
loss *= self.weight[key]
@@ -122,7 +122,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
122122
elif self.reduction == "mean":
123123
loss = loss.mean()
124124

125-
if isinstance(self.weight, float):
125+
if isinstance(self.weight, (float, int)):
126126
loss *= self.weight
127127
elif isinstance(self.weight, dict) and key in self.weight:
128128
loss *= self.weight[key]

ppsci/loss/l2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
5252
loss *= output_dict["area"]
5353

5454
loss = loss.sum()
55-
if isinstance(self.weight, float):
55+
if isinstance(self.weight, (float, int)):
5656
loss *= self.weight
5757
elif isinstance(self.weight, dict) and key in self.weight:
5858
loss *= self.weight[key]
@@ -94,7 +94,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
9494
loss *= output_dict["area"]
9595

9696
loss = loss.sum()
97-
if isinstance(self.weight, float):
97+
if isinstance(self.weight, (float, int)):
9898
loss *= self.weight
9999
elif isinstance(self.weight, dict) and key in self.weight:
100100
loss *= self.weight[key]

ppsci/loss/mse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
6868
loss = loss.sum()
6969
elif self.reduction == "mean":
7070
loss = loss.mean()
71-
if isinstance(self.weight, float):
71+
if isinstance(self.weight, (float, int)):
7272
loss *= self.weight
7373
elif isinstance(self.weight, dict) and key in self.weight:
7474
loss *= self.weight[key]
@@ -168,7 +168,7 @@ def forward(self, output_dict, label_dict, weight_dict=None):
168168
elif self.reduction == "mean":
169169
loss = loss.mean()
170170

171-
if isinstance(self.weight, float):
171+
if isinstance(self.weight, (float, int)):
172172
loss *= self.weight
173173
elif isinstance(self.weight, dict) and key in self.weight:
174174
loss *= self.weight[key]

0 commit comments

Comments
 (0)