Skip to content

Commit b74a0f2

Browse files
authored
Fixed broken evaluate (ecmwf#131)
* Fixed broken evaluate: - evaluate() used all channels infos for model - file name for latest checkpoint was incorrect (introduced in pathlib changes) * Fix path handling to make it backward compatible. (ecmwf#133) * Fixed broken evaluate: - evaluate() used all channels infos for model - file name for latest checkpoint was incorrect (introduced in pathlib changes) * Fixed issues with evaluate. * Ruffed.
1 parent 6267cd1 commit b74a0f2

File tree

2 files changed

+10
-26
lines changed

2 files changed

+10
-26
lines changed

src/weathergen/__init__.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,6 @@ def evaluate():
2525
"""
2626
Evaluation function for WeatherGenerator model.
2727
Entry point for calling the evaluation code from the command line.
28-
29-
Args:
30-
run_id (str): Run/model id of pretrained WeatherGenerator model.
31-
start_date (str): Start date for evaluation. Format must be parsable with pd.to_datetime.
32-
end_date (str): End date for evaluation. Format must be parsable with pd.to_datetime.
33-
epoch (int, optional): Epoch of pretrained WeatherGenerator model used for evaluation (-1 corresponds to last epoch). Defaults to -1.
34-
masking_mode (str, optional): Masking mode for evaluation. Defaults to None.
35-
forecast_steps (int, optional): Number of forecast steps for evaluation. Defaults to None.
36-
samples (int, optional): Number of samples for evaluation. Defaults to 10000000.
37-
shuffle (bool, optional): Shuffle samples for evaluation. Defaults to False.
38-
save_samples (bool, optional): Save samples for evaluation. Defaults to True.
39-
analysis_streams_output (list, optional): Analysis output streams during evaluation. Defaults to ['ERA5'].
40-
gridded_output_streams(list, optional): Currently unused and threrefore omitted here
4128
"""
4229
parser = argparse.ArgumentParser()
4330

@@ -52,13 +39,15 @@ def evaluate():
5239
"-start",
5340
type=str,
5441
required=False,
42+
default="2022-10-01",
5543
help="Start date for evaluation. Format must be parsable with pd.to_datetime.",
5644
)
5745
parser.add_argument(
5846
"--end_date",
5947
"-end",
6048
type=str,
6149
required=False,
50+
default="2022-12-01",
6251
help="End date for evaluation. Format must be parsable with pd.to_datetime.",
6352
)
6453
parser.add_argument(
@@ -104,7 +93,8 @@ def evaluate():
10493
init_loggers()
10594

10695
# load config: if run_id is full path, it loads from there
107-
cf = Config.load(args.run_id, args.epoch, private_cf["model_path"])
96+
model_path = private_cf["model_path"] if hasattr(private_cf, "model_path") else "./models"
97+
cf = Config.load(args.run_id, args.epoch, model_path)
10898

10999
# add parameters from private (paths) config
110100
for k, v in private_cf.items():
@@ -117,16 +107,8 @@ def evaluate():
117107

118108
start_date, end_date = pd.to_datetime(args.start_date), pd.to_datetime(args.end_date)
119109

120-
cf.start_date_val = start_date.strftime(
121-
"%Y%m%d%H%M"
122-
) # ML: would be better to use datetime-objects
110+
cf.start_date_val = start_date.strftime("%Y%m%d%H%M")
123111
cf.end_date_val = end_date.strftime("%Y%m%d%H%M")
124-
# # Oct-Nov 2022
125-
# cf.start_date_val = 202210011600
126-
# cf.end_date_val = 202212010400
127-
# # 2022
128-
# cf.start_date_val = 202201010400
129-
# cf.end_date_val = 202301010400
130112

131113
cf.shuffle = args.shuffle
132114

src/weathergen/train/trainer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,12 @@ def evaluate(self, cf, run_id_trained, epoch, run_id_new=False):
120120
self.dataset_val, **loader_params, sampler=None
121121
)
122122

123-
num_channels = self.dataset_val.get_num_chs()
124-
self.geoinfo_sizes = self.dataset_val.get_geoinfo_sizes()
123+
sources_size = self.dataset_val.get_sources_size()
124+
targets_num_channels = self.dataset_val.get_targets_num_channels()
125+
targets_coords_size = self.dataset_val.get_targets_coords_size()
125126

126-
self.model = Model(cf, num_channels, self.geoinfo_sizes).create().to(self.devices[0])
127+
self.model = Model(cf, sources_size, targets_num_channels, targets_coords_size).create()
128+
self.model = self.model.to(self.devices[0])
127129
self.model.load(run_id_trained, epoch)
128130
print(f"Loaded model {run_id_trained} at epoch {epoch}.")
129131
self.ddp_model = self.model

0 commit comments

Comments
 (0)