Skip to content

Commit de731a3

Browse files
tjhunter and Timothee Hunter
authored
Removing some hardcoded paths and starting refactor (ecmwf#32)
* reformating code * cicd * linting fixes * format * changes' * work * work * work * changes * changed refactor settings * misconfigured * fixes * fixes * fixes * reset bad formatting * reset files * reset files * style * changes * comments * changes --------- Co-authored-by: Timothee Hunter <ecm8774@ac6-100.bullx>
1 parent 6b8dec2 commit de731a3

File tree

9 files changed

+132
-43
lines changed

9 files changed

+132
-43
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# (C) Copyright 2024 WeatherGenerator contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
ERA5 :
11+
type : anemoi
12+
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr']
13+
loss_weight : 1.
14+
source_variables : [null]
15+
target_variables : [null]
16+
diagnostic : False
17+
masking_rate : 0.6
18+
masking_rate_none : 0.05
19+
token_size : 32
20+
embed :
21+
net : transformer
22+
num_tokens : 1
23+
num_heads : 8
24+
dim_embed : 256
25+
num_blocks : 2
26+
embed_target_coords :
27+
net : linear
28+
dim_embed : 256
29+
target_readout :
30+
type : 'obs_value' # token or obs_value
31+
num_layers : 2
32+
num_heads : 4
33+
# sampling_rate : 0.2
34+
pred_head :
35+
ens_size : 1
36+
num_layers : 1

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ dev = [
4444
]
4545

4646

47+
[tool.black]
48+
49+
# Wide rows
50+
line-length = 100
51+
52+
4753
# The linting configuration
4854
[tool.ruff]
4955

src/weathergen/__init__.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
import traceback
1414

1515
from weathergen.train.trainer import Trainer
16-
from weathergen.utils.config import Config
16+
from weathergen.utils.config import Config, private_conf
17+
from weathergen.utils.logger import init_loggers
1718

1819

1920
####################################################################################################
@@ -27,6 +28,8 @@ def evaluate(
2728
save_samples=True,
2829
gridded_output_streams=[],
2930
):
31+
# TODO: move somewhere else
32+
init_loggers()
3033
# load config if specified
3134
cf = Config.load(run_id, epoch if epoch is not None else -1)
3235

@@ -45,7 +48,7 @@ def evaluate(
4548
# cf.start_date_val = 202201010400
4649
# cf.end_date_val = 202301010400
4750

48-
cf.step_hrs = 12
51+
# cf.step_hrs = 12
4952

5053
cf.shuffle = shuffle
5154

@@ -64,12 +67,15 @@ def evaluate(
6467

6568
####################################################################################################
6669
def train(run_id=None) -> None:
70+
# TODO: move somewhere else
71+
init_loggers()
72+
private_cf = private_conf()
6773
cf = Config()
6874

6975
# directory where input streams are specified
7076
# cf.streams_directory = './streams_large/'
71-
# cf.streams_directory = './streams_anemoi/'
72-
cf.streams_directory = "./streams_mixed/"
77+
cf.streams_directory = "./config/streams/streams_anemoi/"
78+
# cf.streams_directory = "./streams_mixed/"
7379

7480
# embed_orientation : 'channels' or 'columns'
7581
# channels: embedding is per channel for a token (#tokens=num_channels)
@@ -175,7 +181,8 @@ def train(run_id=None) -> None:
175181
cf.norm_type = "LayerNorm" #'LayerNorm' #'RMSNorm'
176182
cf.nn_module = "te"
177183

178-
cf.data_path = "/home/mlx/ai-ml/datasets/stable/"
184+
cf.data_path = private_cf["data_path"]
185+
# "/home/mlx/ai-ml/datasets/stable/"
179186
# cf.data_path = '/lus/h2resw01/fws4/lb/project/ai-ml/observations/v1'
180187
# cf.data_path = '/leonardo_scratch/large/userexternal/clessig0/obs/v1'
181188
cf.start_date = 201301010000
@@ -201,7 +208,7 @@ def train(run_id=None) -> None:
201208
trainer = Trainer(log_freq=20, checkpoint_freq=250, print_freq=10)
202209

203210
try:
204-
trainer.run(cf)
211+
trainer.run(cf, private_cf)
205212
except:
206213
extype, value, tb = sys.exc_info()
207214
traceback.print_exc()

src/weathergen/train/trainer.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
from weathergen.utils.train_logger import TrainLogger
3636
from weathergen.utils.validation_io import write_validation
3737

38+
_logger = logging.getLogger(__name__)
39+
3840

3941
class Trainer(Trainer_Base):
4042
###########################################
@@ -47,7 +49,14 @@ def __init__(self, log_freq=20, checkpoint_freq=250, print_freq=10):
4749
self.print_freq = print_freq
4850

4951
###########################################
50-
def init(self, cf, run_id_contd=None, epoch_contd=None, run_id_new=False, run_mode="training"):
52+
def init(
53+
self,
54+
cf,
55+
run_id_contd=None,
56+
epoch_contd=None,
57+
run_id_new=False,
58+
run_mode="training",
59+
):
5160
self.cf = cf
5261

5362
if isinstance(run_id_new, str):
@@ -284,7 +293,7 @@ def evaluate_jac(self, cf, run_id, epoch, mode="row", date=None, obs_id=0, sampl
284293
)
285294

286295
###########################################
287-
def run(self, cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
296+
def run(self, cf, private_cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
288297
# general initialization
289298
self.init(cf, run_id_contd, epoch_contd, run_id_new)
290299

@@ -419,18 +428,23 @@ def run(self, cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
419428
)
420429
self.grad_scaler = torch.amp.GradScaler("cuda")
421430

431+
assert len(self.dataset) > 0, f"No data found in {self.dataset}"
432+
422433
# lr is updated after each batch so account for this
434+
# TODO: conf should be read-only, do not modify the conf in flight
423435
cf.lr_steps = int((len(self.dataset) * cf.num_epochs) / cf.batch_size)
436+
424437
steps_decay = cf.lr_steps - cf.lr_steps_warmup - cf.lr_steps_cooldown
438+
_logger.debug(f"steps_decay={steps_decay} lr_steps={cf.lr_steps}")
425439
# ensure that steps_decay has a reasonable value
426440
if steps_decay < int(0.2 * cf.lr_steps):
427441
cf.lr_steps_warmup = int(0.1 * cf.lr_steps)
428442
cf.lr_steps_cooldown = int(0.05 * cf.lr_steps)
429443
steps_decay = cf.lr_steps - cf.lr_steps_warmup - cf.lr_steps_cooldown
430-
str = f"cf.lr_steps_warmup and cf.lr_steps_cooldown were larger than cf.lr_steps={cf.lr_steps}"
431-
str += ". The value have been adjusted to cf.lr_steps_warmup={cf.lr_steps_warmup} and "
432-
str += " cf.lr_steps_cooldown={cf.lr_steps_cooldown} so that steps_decay={steps_decay}."
433-
logging.getLogger("obslearn").warning("")
444+
s = f"cf.lr_steps_warmup and cf.lr_steps_cooldown were larger than cf.lr_steps={cf.lr_steps}"
445+
s += f". The value have been adjusted to cf.lr_steps_warmup={cf.lr_steps_warmup} and "
446+
s += f" cf.lr_steps_cooldown={cf.lr_steps_cooldown} so that steps_decay={steps_decay}."
447+
logging.getLogger("obslearn").warning(s)
434448
self.lr_scheduler = LearningRateScheduler(
435449
self.optimizer,
436450
cf.batch_size,
@@ -558,7 +572,11 @@ def compute_loss(
558572
)
559573
if tro_type == "token":
560574
pred = pred.reshape(
561-
[*pred.shape[:2], target.shape[-2], target.shape[-1] - gs]
575+
[
576+
*pred.shape[:2],
577+
target.shape[-2],
578+
target.shape[-1] - gs,
579+
]
562580
)
563581
pred = torch.cat([pred[:, i, :l] for i, l in enumerate(sl)], 1)
564582
else:
@@ -600,7 +618,7 @@ def compute_loss(
600618
target_data[mask, i],
601619
pred[:, mask, i],
602620
pred[:, mask, i].mean(0),
603-
pred[:, mask, i].std(0) if ens else torch.zeros(1),
621+
(pred[:, mask, i].std(0) if ens else torch.zeros(1)),
604622
)
605623
val_uw += temp.item()
606624
val = val + channel_loss_weight[i] * temp # * tw[jj]
@@ -613,9 +631,11 @@ def compute_loss(
613631
target_data[mask_nan[:, i], i],
614632
pred[:, mask_nan[:, i], i],
615633
pred[:, mask_nan[:, i], i].mean(0),
616-
pred[:, mask_nan[:, i], i].std(0)
617-
if ens
618-
else torch.zeros(1),
634+
(
635+
pred[:, mask_nan[:, i], i].std(0)
636+
if ens
637+
else torch.zeros(1)
638+
),
619639
)
620640
val_uw += temp.item()
621641
val = val + channel_loss_weight[i] * temp
@@ -1028,7 +1048,10 @@ def log_terminal(self, bidx, epoch):
10281048
)
10291049
print("\t", end="")
10301050
for i_obs, rt in enumerate(self.cf.streams):
1031-
print("{}".format(rt["name"]) + f" : {l_avg[0, i_obs]:0.4E} \t", end="")
1051+
print(
1052+
"{}".format(rt["name"]) + f" : {l_avg[0, i_obs]:0.4E} \t",
1053+
end="",
1054+
)
10321055
print("\n", flush=True)
10331056

10341057
self.t_start = time.time()

src/weathergen/train/trainer_base.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import itertools
1212
import logging
1313
import os
14-
import pathlib
14+
from pathlib import Path
1515

1616
import pynvml
1717
import torch
@@ -22,6 +22,8 @@
2222
from weathergen.train.utils import str_to_tensor, tensor_to_str
2323
from weathergen.utils.config import Config
2424

25+
_logger = logging.getLogger(__name__)
26+
2527

2628
class Trainer_Base:
2729
def __init__(self):
@@ -114,11 +116,14 @@ def init_streams(cf: Config, run_id_contd):
114116
# warn if specified dir does not exist
115117
if not os.path.isdir(cf.streams_directory):
116118
sd = cf.streams_directory
117-
logging.getLogger("obslearn").warning(f"Streams directory {sd} does not exist.")
119+
_logger.warning(f"Streams directory {sd} does not exist.")
118120

119121
# read all reportypes from directory, append to existing ones
120122
temp = {}
121-
for fh in sorted(pathlib.Path(cf.streams_directory).rglob("*.yml")):
123+
streams_dir = Path(cf.streams_directory).absolute()
124+
_logger.info(f"Reading streams from {streams_dir}")
125+
126+
for fh in sorted(streams_dir.rglob("*.yml")):
122127
stream_parsed = yaml.safe_load(fh.read_text())
123128
if stream_parsed is not None:
124129
temp.update(stream_parsed)
@@ -131,7 +136,7 @@ def init_streams(cf: Config, run_id_contd):
131136
# flatten list
132137
rts = list(itertools.chain.from_iterable(rts))
133138
if len(rts) != len(list(set(rts))):
134-
logging.getLogger("obslearn").warning("Duplicate reportypes specified.")
139+
_logger.warning("Duplicate reportypes specified.")
135140

136141
return cf
137142

src/weathergen/utils/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99

1010
import json
1111
import os
12+
from pathlib import Path
13+
from typing import Any
14+
15+
import yaml
1216

1317

1418
###########################################
@@ -63,3 +67,13 @@ def load(run_id, epoch=None):
6367
cf.__dict__ = json.loads(json_str[0])
6468

6569
return cf
70+
71+
72+
# Function that checks if WEATHERGEN_PRIVATE_CONF is set and returns the parsed YAML content of the file it points to:
73+
def private_conf() -> Any:
74+
if "WEATHERGEN_PRIVATE_CONF" in os.environ:
75+
private_home = Path(os.environ["WEATHERGEN_PRIVATE_CONF"])
76+
private_conf = yaml.safe_load(private_home.read_text())
77+
return private_conf
78+
else:
79+
raise ValueError("WEATHERGEN_PRIVATE_CONF is not set.")

src/weathergen/utils/logger.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,21 @@ def format(self, record):
2323
return super().format(record)
2424

2525

26-
logger = logging.getLogger("obslearn")
27-
logger.setLevel(logging.DEBUG)
28-
ch = logging.StreamHandler()
29-
formatter = RelPathFormatter("%(pathname)s:%(lineno)d : %(levelname)-8s : %(message)s")
30-
ch.setFormatter(formatter)
31-
logger.handlers.clear()
32-
logger.addHandler(ch)
26+
def init_loggers():
27+
"""
28+
Initialize the logger for the package.
29+
30+
WARNING: this function clears any existing handlers on the 'obslearn' and 'weathergen' loggers.
31+
"""
32+
formatter = RelPathFormatter("%(pathname)s:%(lineno)d : %(levelname)-8s : %(message)s")
33+
for package in ["obslearn", "weathergen"]:
34+
logger = logging.getLogger(package)
35+
logger.handlers.clear()
36+
logger.setLevel(logging.DEBUG)
37+
ch = logging.StreamHandler()
38+
ch.setFormatter(formatter)
39+
logger.addHandler(ch)
40+
41+
42+
# TODO: remove, it should be module-level loggers
43+
logger = logging.getLogger("weathergen")

src/weathergen/utils/run_id.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10-
from obslearn.train.utils import get_run_id
10+
from weathergen.train.utils import get_run_id
1111

1212
if __name__ == "__main__":
1313
print(get_run_id())

src/weathergenerator_utils/__init__.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

0 commit comments

Comments
 (0)