Skip to content

Commit 3f97922

Browse files
authored
Kacpnowak/develop/fix finetune forecast (ecmwf#257)
* Implement enforcement of training mode with finetune_forecast * Fix minor bug in fesom dataset * Ruffed
1 parent d2947c6 commit 3f97922

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/weathergen/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def train_continue() -> None:
184184

185185
#########################
186186
if args.finetune_forecast:
187+
cf.training_mode = "forecast"
187188
cf.forecast_delta_hrs = 0 # 12
188189
cf.forecast_steps = 1 # [j for j in range(1,9) for i in range(4)]
189190
cf.forecast_policy = "fixed" # 'sequential_random' # 'fixed' #'sequential' #_random'

src/weathergen/datasets/fesom_dataset.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def __init__(
139139

140140
source_channels = stream_info["source"] if "source" in stream_info else None
141141
if source_channels:
142-
self.source_channels, self.source_idx = self.selec(source_channels)
142+
self.source_channels, self.source_idx = self.select(source_channels)
143143
else:
144144
self.source_channels = self.colnames
145145
self.source_idx = self.cols_idx
@@ -159,11 +159,10 @@ def select(self, ch_filters: list[str]) -> None:
159159
Allow user to specify which columns they want to access.
160160
Get functions only returned for these specified columns.
161161
"""
162-
163162
mask = [np.array([f in c for f in ch_filters]).any() for c in self.colnames]
164163

165-
selected_cols_idx = np.where(mask)[0]
166-
selected_colnames = [self.colnames[i] for i in selected_cols_idx]
164+
selected_cols_idx = self.cols_idx[np.where(mask)[0]]
165+
selected_colnames = [self.colnames[i] for i in np.where(mask)[0]]
167166

168167
return selected_colnames, selected_cols_idx
169168

@@ -343,7 +342,7 @@ def normalize_target_channels(self, target: torch.tensor) -> torch.tensor:
343342
"""
344343
assert target.shape[1] == len(self.target_idx)
345344
for i, ch in enumerate(self.target_idx):
346-
target[..., i] = (target[..., i] - self.mean[ch + 2]) / self.stdev[ch + 2]
345+
target[..., i] = (target[..., i] - self.mean[ch]) / self.stdev[ch]
347346

348347
return target
349348

@@ -380,7 +379,7 @@ def denormalize_target_channels(self, data: torch.tensor) -> torch.tensor:
380379
"""
381380
assert data.shape[-1] == len(self.target_idx), "incorrect number of channels"
382381
for i, ch in enumerate(self.target_idx):
383-
data[..., i] = (data[..., i] * self.stdev[ch + 2]) + self.mean[ch + 2]
382+
data[..., i] = (data[..., i] * self.stdev[ch]) + self.mean[ch]
384383

385384
return data
386385

0 commit comments

Comments
 (0)