Skip to content

Commit d19f08b

Browse files
clessigkacpnowaktjhunterclessig
authored
Clessig/develop/channel logging 282 (ecmwf#615)
* Fix bug with seed being divided by 0 for worker ID=0 * Fix bug causing crash when secrets aren't in private config * Implement logging losses per channel * Fix issue with empty targets * Rework loss logging * ruff * Remove computing max_channels * Change variables names * ruffed * Remove redundant enumerations * Use stages for logging * Add type hints * Apply the review * ruff * fix * Fix type hints * ruff * Implement sending tensors of different shapes * ruff * Fix merge * Fix docstring * rerun workflow * Review * Change default colums name * Fix merge * - Added ddp_average_nan that is robust to NaN/0 entries when computing mean - Switched from all_gather to this function in trainer to robustly average - Some code cleanup * use all_to_all communication * Fixing problem with single-worker (non-DDP) training * Ruffed * Re-enabled validation loss output in terminal * Simplified handling of dist initalized --------- Co-authored-by: Kacper Nowak <kacper.nowak@awi.de> Co-authored-by: Tim Hunter <tim.hunter@ecmwf.int> Co-authored-by: clessig <christian.lessig@ecwmf.int>
1 parent 511f036 commit d19f08b

File tree

8 files changed

+312
-195
lines changed

8 files changed

+312
-195
lines changed

.gitignore

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,8 @@ output/
204204
logs/
205205
models/
206206
results/
207+
plots/
207208
models
208209
results
209210
playground/
210-
plots/
211211
.config/
212-
213-

config/streams/streams_ocean/fesom.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
FESOM :
1111
type : fesom
12-
filenames : ['fesom_ifs_awi']
12+
filenames : ['test4.zarr']
1313
loss_weight : 1.
1414
source : null
1515
target : ['sst']

src/weathergen/datasets/data_reader_fesom.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def __init__(
143143

144144
self.mean = np.concatenate((np.array([0, 0]), np.array(self.ds.data.attrs["means"])))
145145
self.stdev = np.sqrt(
146-
np.concatenate((np.array([1, 1]), np.array(self.ds.data.attrs["vars"])))
146+
np.concatenate((np.array([1, 1]), np.array(self.ds.data.attrs["std"])))
147147
)
148148

149149
source_channels = stream_info.get("source")

src/weathergen/train/loss.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import numpy as np
1212
import torch
1313

14+
stat_loss_fcts = ["stats", "kernel_crps"] # Names of loss functions that need std computed
15+
1416

1517
####################################################################################################
1618
def gaussian(x, mu=0.0, std_dev=1.0):

0 commit comments

Comments
 (0)