Skip to content

Commit 966083a

Browse files
jpolz and iluise authored
Add forecast and observation activity (ecmwf#1126)
* Add calculation methods for forecast and observation activity metrics in Scores class
* Add new calculation methods for forecast activity metrics in Scores class
* ruff
* fix func name
* Rename observation activity calculation method to target activity in Scores class
* typo
* refactor to common calc_act function for activity
* fix cases
* have calc_tact and calc_fact that use _calc_act for maintainability
* fix small thing in style

---------

Co-authored-by: iluise <luise.ilaria@gmail.com>
1 parent ecaffd2 commit 966083a

File tree

1 file changed

+143
-18
lines changed
  • packages/evaluate/src/weathergen/evaluate

1 file changed

+143
-18
lines changed

packages/evaluate/src/weathergen/evaluate/score.py

Lines changed: 143 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _get_skill_score(
5353
5454
Returns
5555
----------
56-
skiil_score : xr.DataArray
56+
skill_score : xr.DataArray
5757
Skill score data array
5858
"""
5959

@@ -184,6 +184,8 @@ def __init__(
184184
"acc": self.calc_acc,
185185
"froct": self.calc_froct,
186186
"troct": self.calc_troct,
187+
"fact": self.calc_fact,
188+
"tact": self.calc_tact,
187189
"grad_amplitude": self.calc_spatial_variability,
188190
"psnr": self.calc_psnr,
189191
"seeps": self.calc_seeps,
@@ -266,21 +268,25 @@ def get_score(
266268

267269
arg_names: list[str] = inspect.getfullargspec(f).args[1:]
268270

269-
if score_name in ["froct", "troct"]:
270-
args = {
271-
"p": data.prediction,
272-
"gt": data.ground_truth,
273-
"p_next": data.prediction_next,
274-
"gt_next": data.ground_truth_next,
275-
}
276-
elif score_name == "acc":
277-
args = {
278-
"p": data.prediction,
279-
"gt": data.ground_truth,
280-
"c": data.climatology,
281-
}
282-
else:
283-
args = {"p": data.prediction, "gt": data.ground_truth}
271+
score_args_map = {
272+
"froct": ["p", "gt", "p_next", "gt_next"],
273+
"troct": ["p", "gt", "p_next", "gt_next"],
274+
"acc": ["p", "gt", "c"],
275+
"fact": ["p", "c"],
276+
"tact": ["gt", "c"],
277+
}
278+
279+
available = {
280+
"p": data.prediction,
281+
"gt": data.ground_truth,
282+
"p_next": data.prediction_next,
283+
"gt_next": data.ground_truth_next,
284+
"c": data.climatology,
285+
}
286+
287+
#assign p and gt by default if metrics do not have specific args
288+
keys = score_args_map.get(score_name, ["p", "gt"])
289+
args = {k: available[k] for k in keys}
284290

285291
# Add group_by_coord if provided
286292
if group_by_coord is not None:
@@ -760,7 +766,7 @@ def calc_froct(
760766
"""
761767
if self._agg_dims is None:
762768
raise ValueError(
763-
"Cannot calculate forecast activity without aggregation dimensions (agg_dims=None)."
769+
"Cannot calculate rate of change without aggregation dimensions (agg_dims=None)."
764770
)
765771

766772
froct = self.calc_change_rate(p, p_next)
@@ -799,7 +805,7 @@ def calc_troct(
799805
"""
800806
if self._agg_dims is None:
801807
raise ValueError(
802-
"Cannot calculate forecast activity without aggregation dimensions (agg_dims=None)."
808+
"Cannot calculate rate of change without aggregation dimensions (agg_dims=None)."
803809
)
804810

805811
troct = self.calc_change_rate(gt, gt_next)
@@ -811,14 +817,131 @@ def calc_troct(
811817

812818
return troct
813819

820+
def _calc_act(
821+
self,
822+
x: xr.DataArray,
823+
c: xr.DataArray,
824+
group_by_coord: str | None = None,
825+
spatial_dims: list = None,
826+
):
827+
"""
828+
Calculate activity metric as standard deviation of forecast or target anomaly.
829+
830+
NOTE:
831+
The climatlogical mean data clim_mean must fit to the forecast and ground truth data.
832+
833+
Parameters
834+
----------
835+
x: xr.DataArray
836+
Forecast or target data array
837+
c: xr.DataArray
838+
Climatological mean data array, which is used to calculate anomalies
839+
group_by_coord: str
840+
Name of the coordinate to group by.
841+
If provided, the coordinate becomes a new dimension of the activity score.
842+
spatial_dims: List[str]
843+
Names of spatial dimensions over which activity is calculated.
844+
Note: No averaging is possible over these dimensions.
845+
"""
846+
847+
# Check if spatial_dims are in the data
848+
spatial_dims = ["ipoint"] if spatial_dims is None else to_list(spatial_dims)
849+
850+
for dim in spatial_dims:
851+
if dim not in x.dims:
852+
raise ValueError(
853+
f"Spatial dimension '{dim}' not found in prediction data dimensions: {x.dims}"
854+
)
855+
if c is None:
856+
return xr.full_like(x.sum(spatial_dims), np.nan)
857+
858+
# Calculate anomalies
859+
ano = x - c
860+
861+
if group_by_coord:
862+
# Apply groupby and calculate activity within each group using apply
863+
ano_grouped = ano.groupby(group_by_coord)
864+
865+
# Use apply to calculate activity for each group - this preserves the coordinate structure
866+
act = xr.concat(
867+
[ano_group.std(dim=spatial_dims) for group_label, ano_group in ano_grouped],
868+
dim=group_by_coord,
869+
).assign_coords({group_by_coord: list(ano_grouped.groups.keys())})
870+
871+
else:
872+
# Calculate forecast activity over spatial dimensions (no grouping)
873+
act = ano.std(dim=spatial_dims)
874+
875+
return act
876+
877+
def calc_fact(
    self,
    p: xr.DataArray,
    c: xr.DataArray,
    group_by_coord: str | None = None,
    spatial_dims: list = None,
):
    """
    Forecast activity: standard deviation of the forecast anomaly with respect to climatology.

    Thin wrapper that forwards the forecast data to the shared _calc_act helper.

    NOTE:
    The climatological mean data c must fit to the forecast data.

    Parameters
    ----------
    p: xr.DataArray
        Forecast data array
    c: xr.DataArray
        Climatological mean data array, which is used to calculate anomalies
    group_by_coord: str
        Name of the coordinate to group by.
        If provided, the coordinate becomes a new dimension of the activity score.
    spatial_dims: List[str]
        Names of spatial dimensions over which activity is calculated.
        Note: No averaging is possible over these dimensions.
    """
    return self._calc_act(
        x=p,
        c=c,
        group_by_coord=group_by_coord,
        spatial_dims=spatial_dims,
    )
905+
906+
def calc_tact(
    self,
    gt: xr.DataArray,
    c: xr.DataArray,
    group_by_coord: str | None = None,
    spatial_dims: list = None,
):
    """
    Target activity: standard deviation of the target anomaly with respect to climatology.

    Thin wrapper that forwards the target data to the shared _calc_act helper.

    NOTE:
    The climatological mean data c must fit to the target data.

    Parameters
    ----------
    gt: xr.DataArray
        Target data array
    c: xr.DataArray
        Climatological mean data array, which is used to calculate anomalies
    group_by_coord: str
        Name of the coordinate to group by.
        If provided, the coordinate becomes a new dimension of the activity score.
    spatial_dims: List[str]
        Names of spatial dimensions over which activity is calculated.
        Note: No averaging is possible over these dimensions.
    """
    return self._calc_act(
        x=gt,
        c=c,
        group_by_coord=group_by_coord,
        spatial_dims=spatial_dims,
    )
934+
814935
def _calc_acc_group(
815936
self, fcst: xr.DataArray, obs: xr.DataArray, spatial_dims: list[str]
816937
) -> xr.DataArray:
817938
"""Calculate ACC for a single group
818939
Parameters
819940
----------
941+
----------
820942
fcst: xr.DataArray
821943
Forecast data for the group
944+
Forecast data for the group
822945
obs: xr.DataArray
823946
Observation data for the group
824947
spatial_dims: List[str]
@@ -896,6 +1019,8 @@ def calc_acc(
8961019
# Calculate ACC over spatial dimensions (no grouping)
8971020
acc = self._calc_acc_group(fcst_ano, obs_ano, spatial_dims)
8981021

1022+
acc = self._calc_acc_group(fcst_ano, obs_ano, spatial_dims)
1023+
8991024
return acc
9001025

9011026
def calc_bias(self, p: xr.DataArray, gt: xr.DataArray, group_by_coord: str | None = None):

0 commit comments

Comments
 (0)