File tree Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -82,17 +82,17 @@ def get_latitude_weight(self, num_lat: int = 720):
82
82
weight = weight .reshape ((1 , 1 , - 1 , 1 ))
83
83
return weight
84
84
85
- def expm1_data (self , x : paddle .Tensor ):
85
+ def scale_expm1 (self , x : paddle .Tensor ):
86
86
return self .scale * paddle .expm1 (x )
87
87
88
88
@paddle .no_grad ()
89
89
def forward (self , output_dict , label_dict ):
90
90
metric_dict = {}
91
91
for key in label_dict :
92
92
output = (
93
- self .expm1_data (output_dict [key ]) if self .unlog else output_dict [key ]
93
+ self .scale_expm1 (output_dict [key ]) if self .unlog else output_dict [key ]
94
94
)
95
- label = self .expm1_data (label_dict [key ]) if self .unlog else label_dict [key ]
95
+ label = self .scale_expm1 (label_dict [key ]) if self .unlog else label_dict [key ]
96
96
97
97
if self .mean is not None :
98
98
output = output - self .mean
Original file line number Diff line number Diff line change @@ -109,17 +109,17 @@ def get_latitude_weight(self, num_lat: int = 720):
109
109
weight = weight .reshape ((1 , 1 , - 1 , 1 ))
110
110
return weight
111
111
112
- def expm1_data (self , x : paddle .Tensor ):
112
+ def scale_expm1 (self , x : paddle .Tensor ):
113
113
return self .scale * paddle .expm1 (x )
114
114
115
115
@paddle .no_grad ()
116
116
def forward (self , output_dict , label_dict ):
117
117
metric_dict = {}
118
118
for key in label_dict :
119
119
output = (
120
- self .expm1_data (output_dict [key ]) if self .unlog else output_dict [key ]
120
+ self .scale_expm1 (output_dict [key ]) if self .unlog else output_dict [key ]
121
121
)
122
- label = self .expm1_data (label_dict [key ]) if self .unlog else label_dict [key ]
122
+ label = self .scale_expm1 (label_dict [key ]) if self .unlog else label_dict [key ]
123
123
124
124
mse = F .mse_loss (output , label , "none" )
125
125
rmse = (mse * self .weight ).mean (axis = (- 1 , - 2 )) ** 0.5
You can’t perform that action at this time.
0 commit comments