99from abc import ABC , abstractmethod
1010from functools import reduce
1111from importlib .util import find_spec
12- from typing import Any
12+ from typing import Any , Optional
1313
1414import pandas as pd
1515import xarray as xr
@@ -74,19 +74,20 @@ def __init__(
7474 token: The API token for authentication. Defaults to None.
7575 application_id: The Application ID for authentication. Defaults to None.
7676 """
77-
77+
7878 self .territory = territory # "FRANCE", "ANTIL", or others (see API doc)
7979 self .precision = precision
8080 self ._validate_parameters ()
8181
8282 self ._capabilities : pd .DataFrame | None = None
83-
84- if self .MODEL_TYPE == "ENSEMBLE" :
85- self ._entry_point : str = (
83+ self ._entry_point : str
84+
85+ if self .MODEL_TYPE == "ENSEMBLE" :
86+ self ._entry_point = (
8687 f"{ self .BASE_ENTRY_POINT } xxx-{ self .PRECISION_FLOAT_TO_STR [self .precision ]} -{ self .territory } -WCS"
8788 )
8889 else :
89- self ._entry_point : str = (
90+ self ._entry_point = (
9091 f"{ self .BASE_ENTRY_POINT } -{ self .PRECISION_FLOAT_TO_STR [self .precision ]} -{ self .territory } -WCS"
9192 )
9293 self ._model_base_path = self .MODEL_NAME + "/" + self .API_VERSION
@@ -125,34 +126,37 @@ def get_capabilities(self) -> pd.DataFrame:
125126 """
126127 return self .capabilities
127128
128- def get_coverage_description (self , coverage_id : str , ensemble_numbers : list [int ] | None = None ) -> dict [str , Any ]:
129+ def get_coverage_description (
130+ self , coverage_id : str , ensemble_numbers : list [Optional [int ]] | None = None
131+ ) -> dict [str , Any ]:
129132 """Return the available axis (times, heights) of a coverage.
130133
131134 TODO: Other informations can be fetched, not yet implemented.
132135
133136 Args:
134137 coverage_id: An id of a coverage, use get_capabilities() to get them.
135- ensemble_numbers: For ensemble models only, numbers of the desired
138+ ensemble_numbers: For ensemble models only, numbers of the desired
136139 ensemble members. If None, defaults to the member 0.
137140 Returns:
138141 A dictionary containing more info on the coverage.
139142 """
140-
143+ numbers_to_fetch : list [Optional [int ]]
144+
141145 if self .MODEL_TYPE == "ENSEMBLE" :
142146 if ensemble_numbers is None :
143147 numbers_to_fetch = [0 ]
144- else :
148+ else :
145149 numbers_to_fetch = ensemble_numbers
146150 coverage_description = {}
147151 else :
148152 numbers_to_fetch = [None ]
149-
153+
150154 for ensemble_number in numbers_to_fetch :
151155 description = self ._get_coverage_description (coverage_id , ensemble_number )
152156 grid_axis = description ["wcs:CoverageDescriptions" ]["wcs:CoverageDescription" ]["gml:domainSet" ][
153157 "gmlrgrid:ReferenceableGridByVectors"
154158 ]["gmlrgrid:generalGridAxis" ]
155-
159+
156160 coverage_description_single = {
157161 "forecast_horizons" : [
158162 dt .timedelta (seconds = time ) for time in self ._get_available_feature (grid_axis , "time" )
@@ -164,7 +168,7 @@ def get_coverage_description(self, coverage_id: str, ensemble_numbers: list[int]
164168 coverage_description = coverage_description_single
165169 else :
166170 coverage_description [f"number_{ ensemble_number } " ] = coverage_description_single
167-
171+
168172 return coverage_description
169173
170174 def get_coverage (
@@ -173,7 +177,7 @@ def get_coverage(
173177 lat : tuple = FRANCE_METRO_LATITUDES ,
174178 long : tuple = FRANCE_METRO_LONGITUDES ,
175179 ensemble_numbers : list [int ] | None = None ,
176- heights : list [int ] | None = None ,
180+ heights : list [int ] | None = None ,
177181 pressures : list [int ] | None = None ,
178182 forecast_horizons : list [dt .timedelta ] | None = None ,
179183 run : str | None = None ,
@@ -187,7 +191,7 @@ def get_coverage(
187191 indicator: Indicator of a coverage to retrieve.
188192 lat: Minimum and maximum latitude.
189193 long: Minimum and maximum longitude.
190- ensemble_numbers: For ensemble models only, numbers of the desired
194+ ensemble_numbers: For ensemble models only, numbers of the desired
191195 ensemble members. If None, defaults to the member 0.
192196 heights: Heights in meters.
193197 pressures: Pressures in hPa.
@@ -208,24 +212,23 @@ def get_coverage(
208212 if ensemble_numbers is None :
209213 ensemble_numbers = [0 ]
210214 logger .info (f"Using { len (ensemble_numbers )} ensemble members" )
211-
215+
212216 # Ensure we only have one of coverage_id, indicator
213217 if not bool (indicator ) ^ bool (coverage_id ):
214218 raise ValueError ("Argument `indicator` or `coverage_id` need to be set (only one of them)" )
215219 if indicator is not None :
216220 coverage_id = self ._get_coverage_id (indicator , run , interval )
217-
221+
218222 logger .info (f"Using `coverage_id={ coverage_id } `" )
219-
220-
223+
221224 axis = self .get_coverage_description (coverage_id )
222225
223226 heights = self ._raise_if_invalid_or_fetch_default ("heights" , heights , axis ["heights" ])
224227 pressures = self ._raise_if_invalid_or_fetch_default ("pressures" , pressures , axis ["pressures" ])
225228 forecast_horizons = self ._raise_if_invalid_or_fetch_default (
226229 "forecast_horizons" , forecast_horizons , axis ["forecast_horizons" ]
227230 )
228-
231+
229232 df_list = [
230233 self ._get_data_single_forecast (
231234 coverage_id = coverage_id ,
@@ -254,8 +257,7 @@ def _build_capabilities(self) -> pd.DataFrame:
254257 """
255258
256259 logger .info ("Fetching all available coverages..." )
257-
258-
260+
259261 capabilities = self ._fetch_capabilities ()
260262 df_capabilities = pd .DataFrame (capabilities ["wcs:Capabilities" ]["wcs:Contents" ]["wcs:CoverageSummary" ])
261263 df_capabilities = df_capabilities .rename (
@@ -272,13 +274,11 @@ def _build_capabilities(self) -> pd.DataFrame:
272274 df_capabilities ["interval" ] = [
273275 coverage_id .split ("___" )[1 ].split ("Z" )[1 ].strip ("_" ) for coverage_id in df_capabilities ["id" ]
274276 ]
275-
277+
276278 nb_indicators = len (df_capabilities ["indicator" ].unique ())
277279 nb_coverage_ids = df_capabilities .shape [0 ]
278280 runs = df_capabilities ["run" ].unique ()
279-
280-
281-
281+
282282 logger .info (
283283 f"\n "
284284 f"\t Successfully fetched { nb_coverage_ids } coverages,\n "
@@ -287,8 +287,7 @@ def _build_capabilities(self) -> pd.DataFrame:
287287 f"\n "
288288 f"\t Default run for `get_coverage`: { runs .max ()} )"
289289 )
290-
291-
290+
292291 return df_capabilities
293292
294293 def _get_coverage_id (
@@ -400,12 +399,12 @@ def _fetch_capabilities(self) -> dict[Any, Any]:
400399 Returns:
401400 Raw capabilities (dictionary).
402401 """
403-
404- if self .MODEL_TYPE == "ENSEMBLE" :
402+
403+ if self .MODEL_TYPE == "ENSEMBLE" :
405404 url = f"{ self ._model_base_path } /{ self ._entry_point .replace ('xxx' , '001' )} /GetCapabilities"
406405 else :
407406 url = f"{ self ._model_base_path } /{ self ._entry_point } /GetCapabilities"
408-
407+
409408 params = {
410409 "service" : "WCS" ,
411410 "version" : "2.0.1" ,
@@ -442,14 +441,16 @@ def _get_coverage_description(self, coverage_id: str, ensemble_number: int | Non
442441 coverage_id (str): the Coverage ID. Use :meth:`get_coverage` to access the available coverage ids.
443442 By default use the latest temperature coverage ID.
444443 ensemble_number: For ensemble models only, number of the desired ensemble member.
445-
444+
446445 Returns:
447446 description (dict): the description of the coverage.
448447 """
449448 if self .MODEL_TYPE == "ENSEMBLE" :
450- url = f"{ self ._model_base_path } /{ self ._entry_point .replace ('xxx' , '{:03}' .format (ensemble_number ))} /DescribeCoverage"
449+ url = (
450+ f"{ self ._model_base_path } /{ self ._entry_point .replace ('xxx' , f'{ ensemble_number :03} ' )} /DescribeCoverage"
451+ )
451452 else :
452- url = f"{ self ._model_base_path } /{ self ._entry_point } /DescribeCoverage"
453+ url = f"{ self ._model_base_path } /{ self ._entry_point } /DescribeCoverage"
453454 params = {
454455 "service" : "WCS" ,
455456 "version" : "2.0.1" ,
@@ -567,9 +568,17 @@ def _get_data_single_forecast(
567568 )
568569 if ensemble_number is None :
569570 known_columns = {"latitude" , "longitude" , "run" , "forecast_horizon" , "heightAboveGround" , "isobaricInhPa" }
570- else :
571- known_columns = {"latitude" , "longitude" , "number" , "run" , "forecast_horizon" , "heightAboveGround" , "isobaricInhPa" }
572-
571+ else :
572+ known_columns = {
573+ "latitude" ,
574+ "longitude" ,
575+ "number" ,
576+ "run" ,
577+ "forecast_horizon" ,
578+ "heightAboveGround" ,
579+ "isobaricInhPa" ,
580+ }
581+
573582 indicator_column = (set (df .columns ) - known_columns ).pop ()
574583
575584 if indicator_column == "unknown" :
@@ -588,7 +597,7 @@ def _get_data_single_forecast(
588597 df .rename (columns = {indicator_column : new_indicator_column }, inplace = True )
589598 if self .MODEL_TYPE == "ENSEMBLE" :
590599 df .rename (columns = {"number" : "ensemble_number" }, inplace = True )
591-
600+
592601 df .drop (
593602 columns = ["isobaricInhPa" , "heightAboveGround" , "meanSea" , "potentialVorticity" ],
594603 errors = "ignore" ,
@@ -634,9 +643,9 @@ def _get_coverage_file(
634643 raster.plot_tiff_file: Method for plotting raster data stored in TIFF format.
635644 """
636645 if ensemble_number is None :
637- url = f"{ self ._model_base_path } /{ self ._entry_point } /GetCoverage"
646+ url = f"{ self ._model_base_path } /{ self ._entry_point } /GetCoverage"
638647 else :
639- url = f"{ self ._model_base_path } /{ self ._entry_point .replace ('xxx' , '{ :03}'. format ( ensemble_number )) } /GetCoverage"
648+ url = f"{ self ._model_base_path } /{ self ._entry_point .replace ('xxx' , f' { ensemble_number :03} ') } /GetCoverage"
640649
641650 params = {
642651 "service" : "WCS" ,
@@ -703,7 +712,7 @@ def get_combined_coverage(
703712 Args:
704713 indicator_names (list[str]): A list of indicator names to retrieve data for.
705714 runs (list[str]): A list of runs for each indicator. Format should be "YYYY-MM-DDTHH:MM:SSZ".
706- ensemble_numbers: For ensemble models only, numbers of the desired
715+ ensemble_numbers: For ensemble models only, numbers of the desired
707716 ensemble members. If None, defaults to the member 0.
708717 heights (list[int] | None): A list of heights in meters to filter by (default is None).
709718 pressures (list[int] | None): A list of pressures in hPa to filter by (default is None).
@@ -724,7 +733,7 @@ def get_combined_coverage(
724733 # Numbers cannot be None if the model type is ENSEMBLE
725734 if self .MODEL_TYPE == "ENSEMBLE" and ensemble_numbers is None :
726735 ensemble_numbers = [0 ]
727-
736+
728737 if runs is None :
729738 runs = [None ]
730739 coverages = [
@@ -767,7 +776,7 @@ def _get_combined_coverage_for_single_run(
767776 Args:
768777 indicator_names (list[str]): A list of indicator names to retrieve data for.
769778 run (str): A single runs for each indicator. Format should be "YYYY-MM-DDTHH:MM:SSZ".
770- ensemble_numbers: For ensemble models only, numbers of the desired
779+ ensemble_numbers: For ensemble models only, numbers of the desired
771780 ensemble members. If None, defaults to the member 0.
772781 heights (list[int] | None): A list of heights in meters to filter by (default is None).
773782 pressures (list[int] | None): A list of pressures in hPa to filter by (default is None).
@@ -828,37 +837,43 @@ def _check_params_length(params: list[Any] | None, arg_name: str) -> list[Any]:
828837 else :
829838 forecast_horizons = [self .find_common_forecast_horizons (coverage_ids )[0 ]]
830839 logger .info (f"Using common forecast_horizons `forecast_horizons={ forecast_horizons } `." )
831-
832- coverages = [[
833- self .get_coverage (
834- coverage_id = coverage_id ,
835- run = run ,
836- lat = lat ,
837- long = long ,
838- ensemble_numbers = [ensemble_number ] if ensemble_number is not None else None ,
839- heights = [height ] if height is not None else [],
840- pressures = [pressure ] if pressure is not None else [],
841- forecast_horizons = forecast_horizons ,
842- temp_dir = temp_dir ,
843- )
844- for coverage_id , height , pressure in zip (coverage_ids , heights , pressures )
845- ] for ensemble_number in ([None ] if (ensemble_numbers is None ) else ensemble_numbers )
840+
841+ coverages = [
842+ [
843+ self .get_coverage (
844+ coverage_id = coverage_id ,
845+ run = run ,
846+ lat = lat ,
847+ long = long ,
848+ ensemble_numbers = [ensemble_number ] if ensemble_number is not None else None ,
849+ heights = [height ] if height is not None else [],
850+ pressures = [pressure ] if pressure is not None else [],
851+ forecast_horizons = forecast_horizons ,
852+ temp_dir = temp_dir ,
853+ )
854+ for coverage_id , height , pressure in zip (coverage_ids , heights , pressures )
855+ ]
856+ for ensemble_number in ([None ] if (ensemble_numbers is None ) else ensemble_numbers )
846857 ]
847-
848- coverages_concat = pd .concat ([reduce (
849- lambda left , right : pd .merge (
850- left ,
851- right ,
852- on = ["latitude" , "longitude" , "ensemble_number" , "run" , "forecast_horizon" ] if self .MODEL_TYPE == "ENSEMBLE" else ["latitude" , "longitude" , "run" , "forecast_horizon" ],
853- how = "inner" ,
854- validate = "one_to_one" ,
855- ),
856- coverages [i ],
857- )
858- for i in range (len (coverages ))
858+
859+ coverages_concat = pd .concat (
860+ [
861+ reduce (
862+ lambda left , right : pd .merge (
863+ left ,
864+ right ,
865+ on = ["latitude" , "longitude" , "ensemble_number" , "run" , "forecast_horizon" ]
866+ if self .MODEL_TYPE == "ENSEMBLE"
867+ else ["latitude" , "longitude" , "run" , "forecast_horizon" ],
868+ how = "inner" ,
869+ validate = "one_to_one" ,
870+ ),
871+ coverages [i ],
872+ )
873+ for i in range (len (coverages ))
859874 ]
860- )
861-
875+ )
876+
862877 return coverages_concat
863878
864879 def _get_forecast_horizons (self , coverage_ids : list [str ]) -> list [list [dt .timedelta ]]:
@@ -923,4 +938,4 @@ def _validate_forecast_horizons(self, coverage_ids: list[str], forecast_horizons
923938 if not set (forecast_horizons ).issubset (times )
924939 ]
925940
926- return invalid_coverage_ids
941+ return invalid_coverage_ids
0 commit comments