
Commit 145c18d

Reformatting with black (ecmwf#22)

* reformatting code
* CI/CD

1 parent: d47ceca

29 files changed (+5757 −4585 lines)

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ jobs:
       - name: Run ruff (black)
         # Do not attempt to install the default dependencies, this is much faster.
         # Run temporarily on a sub directory before the main restyling.
-        run: uv run --no-project --with "ruff==0.9.7" ruff format --check -n src/weathergenerator_utils
+        run: uv run --no-project --with "ruff==0.9.7" ruff format --check -n src/

       - name: Run ruff (flake8)
         # Do not attempt to install the default dependencies, this is much faster.
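
For reference, the same check can be reproduced locally from the repository root with the command the workflow runs (assuming uv is installed; the pinned ruff version is fetched into an ephemeral environment by --with):

    uv run --no-project --with "ruff==0.9.7" ruff format --check -n src/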

src/weathergen/__init__.py

Lines changed: 185 additions & 177 deletions
@@ -6,7 +6,7 @@
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
-
+
 import time
 import sys
 import pdb
@@ -16,190 +16,198 @@
 from weathergen.train.trainer import Trainer
 from weathergen.train.utils import get_run_id
 
+
 ####################################################################################################
-def evaluate( run_id, epoch, masking_mode = None, forecacast_steps = None,
-              samples = 10000000, shuffle=False,
-              save_samples=True, gridded_output_streams=[]) :
+def evaluate(
+    run_id,
+    epoch,
+    masking_mode=None,
+    forecacast_steps=None,
+    samples=10000000,
+    shuffle=False,
+    save_samples=True,
+    gridded_output_streams=[],
+):
+    # load config if specified
+    cf = Config.load(run_id, epoch if epoch is not None else -1)
+
+    cf.run_history += [(cf.run_id, cf.istep)]
+
+    cf.samples_per_validation = samples
+    cf.log_validation = samples if save_samples else 0
+
+    if masking_mode is not None:
+        cf.masking_mode = masking_mode
 
-  # load config if specified
-  cf = Config.load( run_id, epoch if epoch is not None else -1)
-
-  cf.run_history += [ (cf.run_id, cf.istep) ]
+    # Oct-Nov 2022
+    cf.start_date_val = 202210011600
+    cf.end_date_val = 202212010400
+    # # 2022
+    # cf.start_date_val = 202201010400
+    # cf.end_date_val = 202301010400
 
-  cf.samples_per_validation = samples
-  cf.log_validation = samples if save_samples else 0
+    cf.step_hrs = 12
 
-  if masking_mode is not None :
-    cf.masking_mode = masking_mode
+    cf.shuffle = shuffle
 
-  # Oct-Nov 2022
-  cf.start_date_val = 202210011600
-  cf.end_date_val = 202212010400
-  # # 2022
-  # cf.start_date_val = 202201010400
-  # cf.end_date_val = 202301010400
-
-  cf.step_hrs = 12
+    cf.forecast_steps = forecacast_steps if forecacast_steps else cf.forecast_steps
+    # cf.forecast_policy = 'fixed'
 
-  cf.shuffle = shuffle
+    # cf.analysis_streams_output = ['Surface', 'Air', 'METEOSAT', 'ATMS', 'IASI', 'AMSR2']
+    cf.analysis_streams_output = ["ERA5"]
 
-  cf.forecast_steps = forecacast_steps if forecacast_steps else cf.forecast_steps
-  # cf.forecast_policy = 'fixed'
+    # make sure number of loaders does not exceed requested samples
+    cf.loader_num_workers = min(cf.loader_num_workers, samples)
 
-  # cf.analysis_streams_output = ['Surface', 'Air', 'METEOSAT', 'ATMS', 'IASI', 'AMSR2']
-  cf.analysis_streams_output = ['ERA5']
-
-  # make sure number of loaders does not exceed requested samples
-  cf.loader_num_workers = min( cf.loader_num_workers, samples)
+    trainer = Trainer()
+    trainer.evaluate(cf, run_id, epoch, True)
 
-  trainer = Trainer()
-  trainer.evaluate( cf, run_id, epoch, True)
 
 ####################################################################################################
-def train( run_id = None) -> None :
-
-  cf = Config()
-
-  # directory where input streams are specified
-  # cf.streams_directory = './streams_large/'
-  # cf.streams_directory = './streams_anemoi/'
-  cf.streams_directory = './streams_mixed/'
-
-  # embed_orientation : 'channels' or 'columns'
-  # channels: embedding is per channel for a token (#tokens=num_channels)
-  # columns: embedding is per "column", all channels are embedded together (#tokens=token_size)
-  # the per-stream embedding paramters, in particular dim_embed, have to be chosen accordingly
-  cf.embed_orientation = 'channels'
-  cf.embed_local_coords = True
-  # False since per cell coords are meaningless for cells
-  cf.embed_centroids_local_coords = False
-  cf.embed_size_centroids = 64
-  cf.embed_unembed_mode = 'block'
-
-  cf.target_cell_local_prediction = True
-  cf.target_coords_local = True
-
-  # parameters for local assimilation engine
-  cf.ae_local_dim_embed = 1024 #2048 #1024
-  cf.ae_local_num_blocks = 2
-  cf.ae_local_num_heads = 16
-  cf.ae_local_dropout_rate = 0.1
-  cf.ae_local_with_qk_lnorm = True
-
-  # assimilation engine local -> global adapter
-  cf.ae_local_num_queries = 2
-  cf.ae_local_queries_per_cell = False
-  cf.ae_adapter_num_heads = 16
-  cf.ae_adapter_embed = 128
-  cf.ae_adapter_with_qk_lnorm = True
-  cf.ae_adapter_with_residual = True
-  cf.ae_adapter_dropout_rate = 0.1
-
-  # parameters for global assimilation engine
-  cf.ae_global_dim_embed = 2048
-  cf.ae_global_num_blocks = 8
-  cf.ae_global_num_heads = 32
-  cf.ae_global_dropout_rate = 0.1
-  cf.ae_global_with_qk_lnorm = True
-  cf.ae_global_att_dense_rate = 0.2 # 0.25 : every 4-th block is dense attention
-  cf.ae_global_block_factor = 64
-  cf.ae_global_mlp_hidden_factor = 2
-
-  cf.pred_adapter_kv = False
-  cf.pred_self_attention = True
-  cf.pred_dyadic_dims = False
-  cf.pred_mlp_adaln = True
-
-  # forecasting engine
-  cf.forecast_delta_hrs = 0
-  cf.forecast_steps = 0 # [j for j in range(1,11) for i in range(1)]
-  cf.forecast_policy = None #'fixed', 'sequential'
-  cf.forecast_freeze_model = False # False
-  cf.forecast_att_dense_rate = 0.25
-
-  cf.fe_num_blocks = 0
-  cf.fe_num_heads = 16
-  cf.fe_dropout_rate = 0.1
-  cf.fe_with_qk_lnorm = True
-
-  cf.healpix_level = 5
-
-  # working precision
-  cf.with_mixed_precision = True
-  cf.with_flash_attention = True
-  if cf.with_flash_attention :
-    assert cf.with_mixed_precision
-  # compile entire model
-  cf.compile_model = False
-
-  cf.with_fsdp = True
-
-  cf.loss_fcts = [['mse', 1.0]]
-  cf.loss_fcts_val = [['mse', 1.0]]
-  # cf.loss_fcts = [['mse', 0.5], ['stats', 0.5]]
-  # cf.loss_fcts_val = [['mse', 0.5], ['stats', 0.5]]
-
-  cf.batch_size = 1
-  cf.batch_size_validation = 1
-
-  # forecast
-  cf.masking_mode = 'forecast'
-  cf.masking_rate = 0.0
-  cf.masking_rate_sampling = True #False
-  cf.sampling_rate_target = 1.0
-
-  cf.num_epochs = 24
-  cf.samples_per_epoch = 4096
-  cf.samples_per_validation = 512
-  cf.shuffle = True
-
-  cf.lr_scaling_policy = 'sqrt'
-  cf.lr_start = 0.000001
-  cf.lr_max = 0.00003
-  cf.lr_final_decay = 0.000001
-  cf.lr_final = 0.0
-  cf.lr_steps_warmup = 256
-  cf.lr_steps_cooldown = 4096
-  cf.lr_policy_warmup = 'cosine'
-  cf.lr_policy_decay = 'linear'
-  cf.lr_policy_cooldown = 'linear'
-
-  cf.grad_clip = 5.
-  cf.weight_decay = 0.1
-  cf.norm_type = 'LayerNorm' #'LayerNorm' #'RMSNorm'
-  cf.nn_module = 'te'
-
-  cf.data_path = '/home/mlx/ai-ml/datasets/stable/'
-  # cf.data_path = '/lus/h2resw01/fws4/lb/project/ai-ml/observations/v1'
-  # cf.data_path = '/leonardo_scratch/large/userexternal/clessig0/obs/v1'
-  cf.start_date = 201301010000
-  cf.end_date = 202012310000
-  cf.start_date_val = 202101010000
-  cf.end_date_val = 202201010000
-  cf.len_hrs = 6
-  cf.step_hrs = 6
-  cf.input_window_steps = 1
-
-  cf.val_initial = False
-
-  cf.loader_num_workers = 8
-  cf.data_loader_rng_seed = int(time.time())
-  cf.log_validation = 0
-
-  cf.istep = 0
-  cf.run_history = []
-
-  cf.run_id = run_id
-  cf.desc = ''
-
-  trainer = Trainer( log_freq=20, checkpoint_freq=250, print_freq=10)
-
-  try :
-    trainer.run( cf)
-  except :
-    extype, value, tb = sys.exc_info()
-    traceback.print_exc()
-    pdb.post_mortem(tb)
-
-if __name__ == '__main__':
-  train()
+def train(run_id=None) -> None:
+    cf = Config()
+
+    # directory where input streams are specified
+    # cf.streams_directory = './streams_large/'
+    # cf.streams_directory = './streams_anemoi/'
+    cf.streams_directory = "./streams_mixed/"
+
+    # embed_orientation : 'channels' or 'columns'
+    # channels: embedding is per channel for a token (#tokens=num_channels)
+    # columns: embedding is per "column", all channels are embedded together (#tokens=token_size)
+    # the per-stream embedding paramters, in particular dim_embed, have to be chosen accordingly
+    cf.embed_orientation = "channels"
+    cf.embed_local_coords = True
+    # False since per cell coords are meaningless for cells
+    cf.embed_centroids_local_coords = False
+    cf.embed_size_centroids = 64
+    cf.embed_unembed_mode = "block"
+
+    cf.target_cell_local_prediction = True
+    cf.target_coords_local = True
+
+    # parameters for local assimilation engine
+    cf.ae_local_dim_embed = 1024 # 2048 #1024
+    cf.ae_local_num_blocks = 2
+    cf.ae_local_num_heads = 16
+    cf.ae_local_dropout_rate = 0.1
+    cf.ae_local_with_qk_lnorm = True
+
+    # assimilation engine local -> global adapter
+    cf.ae_local_num_queries = 2
+    cf.ae_local_queries_per_cell = False
+    cf.ae_adapter_num_heads = 16
+    cf.ae_adapter_embed = 128
+    cf.ae_adapter_with_qk_lnorm = True
+    cf.ae_adapter_with_residual = True
+    cf.ae_adapter_dropout_rate = 0.1
+
+    # parameters for global assimilation engine
+    cf.ae_global_dim_embed = 2048
+    cf.ae_global_num_blocks = 8
+    cf.ae_global_num_heads = 32
+    cf.ae_global_dropout_rate = 0.1
+    cf.ae_global_with_qk_lnorm = True
+    cf.ae_global_att_dense_rate = 0.2 # 0.25 : every 4-th block is dense attention
+    cf.ae_global_block_factor = 64
+    cf.ae_global_mlp_hidden_factor = 2
+
+    cf.pred_adapter_kv = False
+    cf.pred_self_attention = True
+    cf.pred_dyadic_dims = False
+    cf.pred_mlp_adaln = True
+
+    # forecasting engine
+    cf.forecast_delta_hrs = 0
+    cf.forecast_steps = 0 # [j for j in range(1,11) for i in range(1)]
+    cf.forecast_policy = None #'fixed', 'sequential'
+    cf.forecast_freeze_model = False # False
+    cf.forecast_att_dense_rate = 0.25
+
+    cf.fe_num_blocks = 0
+    cf.fe_num_heads = 16
+    cf.fe_dropout_rate = 0.1
+    cf.fe_with_qk_lnorm = True
+
+    cf.healpix_level = 5
+
+    # working precision
+    cf.with_mixed_precision = True
+    cf.with_flash_attention = True
+    if cf.with_flash_attention:
+        assert cf.with_mixed_precision
+    # compile entire model
+    cf.compile_model = False
+
+    cf.with_fsdp = True
+
+    cf.loss_fcts = [["mse", 1.0]]
+    cf.loss_fcts_val = [["mse", 1.0]]
+    # cf.loss_fcts = [['mse', 0.5], ['stats', 0.5]]
+    # cf.loss_fcts_val = [['mse', 0.5], ['stats', 0.5]]
+
+    cf.batch_size = 1
+    cf.batch_size_validation = 1
+
+    # forecast
+    cf.masking_mode = "forecast"
+    cf.masking_rate = 0.0
+    cf.masking_rate_sampling = True # False
+    cf.sampling_rate_target = 1.0
+
+    cf.num_epochs = 24
+    cf.samples_per_epoch = 4096
+    cf.samples_per_validation = 512
+    cf.shuffle = True
+
+    cf.lr_scaling_policy = "sqrt"
+    cf.lr_start = 0.000001
+    cf.lr_max = 0.00003
+    cf.lr_final_decay = 0.000001
+    cf.lr_final = 0.0
+    cf.lr_steps_warmup = 256
+    cf.lr_steps_cooldown = 4096
+    cf.lr_policy_warmup = "cosine"
+    cf.lr_policy_decay = "linear"
+    cf.lr_policy_cooldown = "linear"
+
+    cf.grad_clip = 5.0
+    cf.weight_decay = 0.1
+    cf.norm_type = "LayerNorm" #'LayerNorm' #'RMSNorm'
+    cf.nn_module = "te"
+
+    cf.data_path = "/home/mlx/ai-ml/datasets/stable/"
+    # cf.data_path = '/lus/h2resw01/fws4/lb/project/ai-ml/observations/v1'
+    # cf.data_path = '/leonardo_scratch/large/userexternal/clessig0/obs/v1'
+    cf.start_date = 201301010000
+    cf.end_date = 202012310000
+    cf.start_date_val = 202101010000
+    cf.end_date_val = 202201010000
+    cf.len_hrs = 6
+    cf.step_hrs = 6
+    cf.input_window_steps = 1
+
+    cf.val_initial = False
+
+    cf.loader_num_workers = 8
+    cf.data_loader_rng_seed = int(time.time())
+    cf.log_validation = 0
+
+    cf.istep = 0
+    cf.run_history = []
+
+    cf.run_id = run_id
+    cf.desc = ""
+
+    trainer = Trainer(log_freq=20, checkpoint_freq=250, print_freq=10)
+
+    try:
+        trainer.run(cf)
+    except:
+        extype, value, tb = sys.exc_info()
+        traceback.print_exc()
+        pdb.post_mortem(tb)
+
+
+if __name__ == "__main__":
+    train()
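
The changes above are stylistic only (double quotes, four-space indentation, wrapped argument lists, 5. to 5.0); behaviour is unchanged. A sketch of how formatting like this can be applied locally, assuming the same pinned ruff version as the workflow and not necessarily the exact invocation used for this commit:

    # rewrites the file in place; pass src/ instead to format the whole tree
    uv run --no-project --with "ruff==0.9.7" ruff format src/weathergen/__init__.py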
