
Commit 8671505

Sgrasse/develop/issue 106 (ecmwf#159)
* update dependencies to include 'omegaconf'
* rewrite deserialization of streams config for omegaconf
* move deserialization of stream configs
* fix streams deserialization
* rewrite deserialization of private config for omegaconf
* rewrite deserialization of overwrite config for omegaconf
* add method to deserialize default config
* consistent names for deserialization methods
* transform methods of Config object into functions
* change implementation of config object to OmegaConf
* add method to load/create new omega config
* formatting
* fix: create "result" directory if not already present
* fix: set 'Path' as default factory in CLI. Using 'Path' as default factory in the CLI instead of 'str' avoids the 'str'-to-'Path' cast potentially casting 'None' to 'Path' (which fails). The cast would be needed to comply with the 'Path' | 'None' interface of subsequent methods.
* fix: correctly call 'OmegaConf.to_container'
* fix: change return type of stream deserialization to comply with usage
* remove unused private_conf from call to trainer
* load default config from file instead of setting in code
* make evaluate / train continue compatible
* fix: use -2 to select None as epoch value
* remove redundant options that don't differ from default values
* change load_config to account for evaluation/continue
* use 'load_config' for evaluation/continue
* autoformatting
* re-enable --config flag for training
* small improvements suggested by review:
  - docstrings
  - pure stylistic code changes
  - slightly change logging/exceptions
  - improved comments
  - pin OmegaConf version
  - remove faulty code
* remove default None args
* mark stuff private in `config`
* remove __main__ from `trainer_base`
* rework `load_streams` after comments
* update uv.lock, fix missing adjustment in evaluate
* cosmetic fixes
1 parent b034986 commit 8671505
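The core of this commit is replacing the hand-written `Config` class with OmegaConf and funneling all three entry points (`train`, `train_continue`, `evaluate`) through a single `config.load_config` function. A minimal sketch of what such a function could look like, inferred only from the call sites in the diffs below; the model directory layout, the merge order, and the handling of `epoch` are assumptions, not the committed implementation:

```python
from pathlib import Path

from omegaconf import DictConfig, OmegaConf


def load_config(
    private_config: Path | None,
    run_id: str | None,
    epoch: int | None,
    overwrite_config: Path | None,
) -> DictConfig:
    """Sketch: build a run config by layering OmegaConf sources."""
    # Defaults now come from a file rather than being set in code.
    cf = OmegaConf.load("config/default_config.yml")
    # For evaluate/train_continue, layer the stored run config on top
    # (model directory layout assumed for illustration; `epoch` would
    # select a checkpoint-specific config and is omitted here).
    if run_id is not None:
        cf = OmegaConf.merge(cf, OmegaConf.load(f"./models/{run_id}/config.yml"))
    # Optional experiment-specific overrides passed via --config.
    if overwrite_config is not None:
        cf = OmegaConf.merge(cf, OmegaConf.load(overwrite_config))
    # Machine-specific paths passed via --private_config win last.
    if private_config is not None:
        cf = OmegaConf.merge(cf, OmegaConf.load(private_config))
    return cf
```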

File tree

9 files changed: +286 additions, -304 deletions

config/default_config.yml

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+streams_directory: "./config/streams/streams_anemoi/"
+
+embed_orientation: "channels"
+embed_local_coords: True
+embed_centroids_local_coords: False
+embed_size_centroids: 64
+embed_unembed_mode: "block"
+
+target_cell_local_prediction: True
+target_coords_local: True
+
+ae_local_dim_embed: 1024
+ae_local_num_blocks: 2
+ae_local_num_heads: 16
+ae_local_dropout_rate: 0.1
+ae_local_with_qk_lnorm: True
+
+ae_local_num_queries: 2
+ae_local_queries_per_cell: False
+ae_adapter_num_heads: 16
+ae_adapter_embed: 128
+ae_adapter_with_qk_lnorm: True
+ae_adapter_with_residual: True
+ae_adapter_dropout_rate: 0.1
+
+ae_global_dim_embed: 2048
+ae_global_num_blocks: 8
+ae_global_num_heads: 32
+ae_global_dropout_rate: 0.1
+ae_global_with_qk_lnorm: True
+ae_global_att_dense_rate: 0.2
+ae_global_block_factor: 64
+ae_global_mlp_hidden_factor: 2
+
+pred_adapter_kv: False
+pred_self_attention: True
+pred_dyadic_dims: False
+pred_mlp_adaln: True
+
+forecast_delta_hrs: 0
+forecast_steps: 0
+forecast_policy: null
+forecast_freeze_model: False
+forecast_att_dense_rate: 0.25
+fe_num_blocks: 0
+fe_num_heads: 16
+fe_dropout_rate: 0.1
+fe_with_qk_lnorm: True
+
+healpix_level: 5
+
+with_mixed_precision: True
+with_flash_attention: True
+compile_model: False
+with_fsdp: True
+
+loss_fcts:
+  -
+    - "mse"
+    - 1.0
+loss_fcts_val:
+  -
+    - "mse"
+    - 1.0
+
+batch_size: 1
+batch_size_validation: 1
+
+masking_mode: "forecast"
+masking_rate: 0.0
+masking_rate_sampling: True
+sampling_rate_target: 1.0
+
+num_epochs: 24
+samples_per_epoch: 4096
+samples_per_validation: 512
+shuffle: True
+
+lr_scaling_policy: "sqrt"
+lr_start: 0.000001
+lr_max: 0.00003
+lr_final_decay: 0.000001
+lr_final: 0.0
+lr_steps_warmup: 256
+lr_steps_cooldown: 4096
+lr_policy_warmup: "cosine"
+lr_policy_decay: "linear"
+lr_policy_cooldown: "linear"
+
+grad_clip: 5.0
+weight_decay: 0.1
+norm_type: "LayerNorm"
+nn_module: "te"
+
+start_date: 201301010000
+end_date: 202012310000
+start_date_val: 202101010000
+end_date_val: 202201010000
+len_hrs: 6
+step_hrs: 6
+input_window_steps: 1
+
+val_initial: False
+
+loader_num_workers: 8
+log_validation: 0
+
+istep: 0
+run_history: []
+
+desc: ""
+data_loader_rng_seed: ???
+run_id: ???
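The two `???` entries at the bottom are OmegaConf's marker for mandatory values: the file loads fine, but reading those keys before they are set raises an error. A small illustration of that behavior (not part of the commit):

```python
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cf = OmegaConf.load("config/default_config.yml")
print(cf.num_epochs)  # 24 -- ordinary defaults resolve immediately

try:
    _ = cf.run_id  # declared as ??? in the YAML above
except MissingMandatoryValue:
    print("run_id must be provided at runtime")

cf.run_id = "abc123"  # e.g. set by the trainer before the config is used
```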

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ dependencies = [
     'psutil',
     "flash-attn; sys_platform == 'linux'",
     "polars~=1.25.2",
+    "omegaconf~=2.3.0",
 ]

 [project.urls]

src/weathergen/__init__.py

Lines changed: 13 additions & 167 deletions
@@ -12,15 +12,15 @@
 import sys
 import time
 import traceback
+from pathlib import Path

 import pandas as pd

+import weathergen.utils.config as config
 from weathergen.train.trainer import Trainer
-from weathergen.utils.config import Config, load_overwrite_conf, load_private_conf
 from weathergen.utils.logger import init_loggers


-####################################################################################################
 def evaluate():
     """
     Evaluation function for WeatherGenerator model.
@@ -79,26 +79,17 @@ def evaluate():
     )
     parser.add_argument(
         "--private_config",
-        type=str,
+        type=Path,
         default=None,
         help="Path to private configuration file for paths.",
     )

     args = parser.parse_args()

-    # get the paths from the private config
-    private_cf = load_private_conf(args.private_config)
-
     # TODO: move somewhere else
     init_loggers()

-    # load config: if run_id is full path, it loads from there
-    model_path = private_cf["model_path"] if "model_path" in private_cf.keys() else "./models"
-    cf = Config.load(args.run_id, args.epoch, model_path)
-
-    # add parameters from private (paths) config
-    for k, v in private_cf.items():
-        setattr(cf, k, v)
+    cf = config.load_config(args.private_config, args.run_id, args.epoch, None)

     cf.run_history += [(cf.run_id, cf.istep)]

@@ -154,7 +145,7 @@ def train_continue() -> None:
     )
     parser.add_argument(
         "--private_config",
-        type=str,
+        type=Path,
         default=None,
         help="Path to private configuration file for paths.",
     )
@@ -165,16 +156,10 @@ def train_continue() -> None:
     )

     args = parser.parse_args()
-    # get the paths from the private config
-    private_cf = load_private_conf(args.private_config)

-    # load config if specified
-    model_path = private_cf["model_path"] if "model_path" in private_cf.keys() else "./models"
-    cf = Config.load(args.run_id, args.epoch, model_path)
+    cf = config.load_config(args.private_config, args.run_id, args.epoch, None)

     # track history of run to ensure traceability of results
-    if "run_history" not in cf.__dict__:
-        cf.run_history = []
     cf.run_history += [(cf.run_id, cf.istep)]

     #########################
@@ -210,7 +195,7 @@ def train_continue() -> None:
     cf.istep = 0

     trainer = Trainer()
-    trainer.run(cf, private_cf, args.run_id, args.epoch, args.run_id_new)
+    trainer.run(cf, args.run_id, args.epoch, args.run_id_new)


 ####################################################################################################
@@ -235,171 +220,32 @@ def train() -> None:
     )
     parser.add_argument(
         "--private_config",
-        type=str,
+        type=Path,
        default=None,
         help="Path to private configuration file for paths",
     )
     parser.add_argument(
         "--config",
-        type=str,
+        type=Path,
         default=None,
-        help="Path to private configuration file for overwriting the defaults in the function body. Defaults to None.",
+        help="Optional experiment specific configuration file",
     )

     args = parser.parse_args()

     # TODO: move somewhere else
     init_loggers()

-    # get the non-default configs: private and overwrite
-    private_cf = load_private_conf(args.private_config)
-    overwrite_cf = load_overwrite_conf(args.config)
-
-    cf = Config()
-
-    # directory where input streams are specified
-    # cf.streams_directory = './streams_large/'
-    cf.streams_directory = "./config/streams/streams_anemoi/"
-    # cf.streams_directory = "./config/streams/streams_mixed/"
-    # cf.streams_directory = "./streams_mixed/"
-
-    # embed_orientation : 'channels' or 'columns'
-    # channels: embedding is per channel for a token (#tokens=num_channels)
-    # columns: embedding is per "column", all channels are embedded together (#tokens=token_size)
-    # the per-stream embedding paramters, in particular dim_embed, have to be chosen accordingly
-    cf.embed_orientation = "channels"
-    cf.embed_local_coords = True
-    # False since per cell coords are meaningless for cells
-    cf.embed_centroids_local_coords = False
-    cf.embed_size_centroids = 64
-    cf.embed_unembed_mode = "block"
-
-    cf.target_cell_local_prediction = True
-    cf.target_coords_local = True
-
-    # parameters for local assimilation engine
-    cf.ae_local_dim_embed = 1024  # 2048 #1024
-    cf.ae_local_num_blocks = 2
-    cf.ae_local_num_heads = 16
-    cf.ae_local_dropout_rate = 0.1
-    cf.ae_local_with_qk_lnorm = True
-
-    # assimilation engine local -> global adapter
-    cf.ae_local_num_queries = 2
-    cf.ae_local_queries_per_cell = False
-    cf.ae_adapter_num_heads = 16
-    cf.ae_adapter_embed = 128
-    cf.ae_adapter_with_qk_lnorm = True
-    cf.ae_adapter_with_residual = True
-    cf.ae_adapter_dropout_rate = 0.1
-
-    # parameters for global assimilation engine
-    cf.ae_global_dim_embed = 2048
-    cf.ae_global_num_blocks = 8
-    cf.ae_global_num_heads = 32
-    cf.ae_global_dropout_rate = 0.1
-    cf.ae_global_with_qk_lnorm = True
-    cf.ae_global_att_dense_rate = 0.2  # 0.25 : every 4-th block is dense attention
-    cf.ae_global_block_factor = 64
-    cf.ae_global_mlp_hidden_factor = 2
-
-    cf.pred_adapter_kv = False
-    cf.pred_self_attention = True
-    cf.pred_dyadic_dims = False
-    cf.pred_mlp_adaln = True
-
-    # forecasting engine
-    cf.forecast_delta_hrs = 0
-    cf.forecast_steps = 0  # [j for j in range(1,11) for i in range(1)]
-    cf.forecast_policy = None  # 'fixed', 'sequential'
-    cf.forecast_freeze_model = False  # False
-    cf.forecast_att_dense_rate = 0.25
-
-    cf.fe_num_blocks = 0
-    cf.fe_num_heads = 16
-    cf.fe_dropout_rate = 0.1
-    cf.fe_with_qk_lnorm = True
-
-    cf.healpix_level = 5
-
-    # working precision
-    cf.with_mixed_precision = True
-    cf.with_flash_attention = True
+    cf = config.load_config(args.private_config, None, None, args.config)
+
     if cf.with_flash_attention:
         assert cf.with_mixed_precision
-    # compile entire model
-    cf.compile_model = False
-
-    cf.with_fsdp = True
-
-    cf.loss_fcts = [["mse", 1.0]]
-    cf.loss_fcts_val = [["mse", 1.0]]
-    # cf.loss_fcts = [['mse', 0.5], ['stats', 0.5]]
-    # cf.loss_fcts_val = [['mse', 0.5], ['stats', 0.5]]
-
-    cf.batch_size = 1
-    cf.batch_size_validation = 1
-
-    # forecast
-    cf.masking_mode = "forecast"
-    cf.masking_rate = 0.0
-    cf.masking_rate_sampling = True  # False
-    cf.sampling_rate_target = 1.0
-
-    cf.num_epochs = 24
-    cf.samples_per_epoch = 4096
-    cf.samples_per_validation = 512
-    cf.shuffle = True
-
-    cf.lr_scaling_policy = "sqrt"
-    cf.lr_start = 0.000001
-    cf.lr_max = 0.00003
-    cf.lr_final_decay = 0.000001
-    cf.lr_final = 0.0
-    cf.lr_steps_warmup = 256
-    cf.lr_steps_cooldown = 4096
-    cf.lr_policy_warmup = "cosine"
-    cf.lr_policy_decay = "linear"
-    cf.lr_policy_cooldown = "linear"
-
-    cf.grad_clip = 5.0
-    cf.weight_decay = 0.1
-    cf.norm_type = "LayerNorm"  # 'LayerNorm' #'RMSNorm'
-    cf.nn_module = "te"
-
-    cf.start_date = 201301010000
-    cf.end_date = 202012310000
-    cf.start_date_val = 202101010000
-    cf.end_date_val = 202201010000
-    cf.len_hrs = 6
-    cf.step_hrs = 6
-    cf.input_window_steps = 1
-
-    cf.val_initial = False
-
-    cf.loader_num_workers = 8
     cf.data_loader_rng_seed = int(time.time())
-    cf.log_validation = 0
-
-    cf.istep = 0
-    cf.run_history = []
-
-    cf.run_id = args.run_id
-    cf.desc = ""
-
-    # overwrite parameters from private config
-    for k, v in private_cf.items():
-        setattr(cf, k, v)
-    cf.data_path = private_cf["data_path_anemoi"]  # for backward compatibility
-
-    # overwrite parameters from overwrite config
-    for k, v in overwrite_cf.items():
-        setattr(cf, k, v)

     trainer = Trainer(log_freq=20, checkpoint_freq=250, print_freq=10)

     try:
-        trainer.run(cf, private_cf)
+        trainer.run(cf)
     except Exception:
         extype, value, tb = sys.exc_info()
         traceback.print_exc()
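Seen end to end, the rework replaces the old `setattr` loops with OmegaConf merges, where later sources override earlier ones. A self-contained demonstration of that precedence rule (keys and values invented for illustration):

```python
from omegaconf import OmegaConf

default = OmegaConf.create({"num_epochs": 24, "batch_size": 1, "desc": ""})
overwrite = OmegaConf.create({"num_epochs": 48})           # e.g. from --config
private = OmegaConf.create({"data_path_anemoi": "/data"})  # e.g. from --private_config

cf = OmegaConf.merge(default, overwrite, private)
assert cf.num_epochs == 48             # overwrite wins over the default
assert cf.batch_size == 1              # untouched defaults pass through
assert cf.data_path_anemoi == "/data"  # private keys are simply added
```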
