# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
9-
9+
1010import time
1111import sys
1212import pdb
1616from weathergen .train .trainer import Trainer
1717from weathergen .train .utils import get_run_id
1818
19+
1920####################################################################################################
def evaluate(
    run_id,
    epoch,
    masking_mode=None,
    forecacast_steps=None,
    samples=10000000,
    shuffle=False,
    save_samples=True,
    gridded_output_streams=None,
) -> None:
    """Evaluate a stored run using its saved configuration.

    The configuration is loaded with ``Config.load``, a handful of fields are
    overridden for evaluation, and the result is handed to
    ``Trainer.evaluate``.

    Args:
        run_id: Identifier of the run whose configuration is loaded.
        epoch: Epoch passed to ``Config.load`` and ``Trainer.evaluate``. When
            ``None``, ``-1`` is passed to ``Config.load`` instead.
        masking_mode: If not ``None``, replaces ``cf.masking_mode``.
        forecacast_steps: If not ``None``, replaces ``cf.forecast_steps``.
            (The misspelled name is kept so existing keyword callers still
            work.) An explicit ``0`` is honoured rather than being treated as
            "not specified".
        samples: Written to ``cf.samples_per_validation``; also caps
            ``cf.loader_num_workers``.
        shuffle: Written to ``cf.shuffle``.
        save_samples: When true, ``cf.log_validation`` is set to ``samples``,
            otherwise to ``0``.
        gridded_output_streams: Accepted for interface compatibility; not read
            by this function. Defaults to ``None`` instead of a shared mutable
            list.
    """
    # load config if specified
    cf = Config.load(run_id, epoch if epoch is not None else -1)

    # record the (run_id, step) the evaluation starts from
    cf.run_history += [(cf.run_id, cf.istep)]

    cf.samples_per_validation = samples
    cf.log_validation = samples if save_samples else 0

    if masking_mode is not None:
        cf.masking_mode = masking_mode

    # Oct-Nov 2022 (values look like YYYYMMDDHHMM -- TODO confirm)
    cf.start_date_val = 202210011600
    cf.end_date_val = 202212010400
    # # 2022
    # cf.start_date_val = 202201010400
    # cf.end_date_val = 202301010400

    cf.step_hrs = 12

    cf.shuffle = shuffle

    # Only override when the caller actually supplied a value; a plain
    # truthiness test would silently drop an explicit 0.
    if forecacast_steps is not None:
        cf.forecast_steps = forecacast_steps
    # cf.forecast_policy = 'fixed'

    # cf.analysis_streams_output = ['Surface', 'Air', 'METEOSAT', 'ATMS', 'IASI', 'AMSR2']
    cf.analysis_streams_output = ["ERA5"]

    # make sure number of loaders does not exceed requested samples
    cf.loader_num_workers = min(cf.loader_num_workers, samples)

    trainer = Trainer()
    # NOTE(review): meaning of the trailing positional ``True`` is defined by
    # Trainer.evaluate, which is not visible here -- confirm before changing.
    trainer.evaluate(cf, run_id, epoch, True)
5765
5866####################################################################################################
def train(run_id=None) -> None:
    """Build the hard-coded training configuration and run the trainer.

    A fresh ``Config`` is populated field by field and passed to
    ``Trainer.run``. If training raises an ordinary exception, the traceback
    is printed and a ``pdb`` post-mortem session is opened on it; the
    exception is not re-raised afterwards. ``KeyboardInterrupt`` and
    ``SystemExit`` are allowed to propagate instead of opening the debugger.

    Args:
        run_id: Stored as ``cf.run_id``; defaults to ``None``.
    """
    cf = Config()

    # directory where input streams are specified
    # cf.streams_directory = './streams_large/'
    # cf.streams_directory = './streams_anemoi/'
    cf.streams_directory = "./streams_mixed/"

    # embed_orientation : 'channels' or 'columns'
    # channels: embedding is per channel for a token (#tokens=num_channels)
    # columns: embedding is per "column", all channels are embedded together (#tokens=token_size)
    # the per-stream embedding parameters, in particular dim_embed, have to be chosen accordingly
    cf.embed_orientation = "channels"
    cf.embed_local_coords = True
    # False since per cell coords are meaningless for cells
    cf.embed_centroids_local_coords = False
    cf.embed_size_centroids = 64
    cf.embed_unembed_mode = "block"

    cf.target_cell_local_prediction = True
    cf.target_coords_local = True

    # parameters for local assimilation engine
    cf.ae_local_dim_embed = 1024  # 2048 #1024
    cf.ae_local_num_blocks = 2
    cf.ae_local_num_heads = 16
    cf.ae_local_dropout_rate = 0.1
    cf.ae_local_with_qk_lnorm = True

    # assimilation engine local -> global adapter
    cf.ae_local_num_queries = 2
    cf.ae_local_queries_per_cell = False
    cf.ae_adapter_num_heads = 16
    cf.ae_adapter_embed = 128
    cf.ae_adapter_with_qk_lnorm = True
    cf.ae_adapter_with_residual = True
    cf.ae_adapter_dropout_rate = 0.1

    # parameters for global assimilation engine
    cf.ae_global_dim_embed = 2048
    cf.ae_global_num_blocks = 8
    cf.ae_global_num_heads = 32
    cf.ae_global_dropout_rate = 0.1
    cf.ae_global_with_qk_lnorm = True
    cf.ae_global_att_dense_rate = 0.2  # 0.25 : every 4-th block is dense attention
    cf.ae_global_block_factor = 64
    cf.ae_global_mlp_hidden_factor = 2

    cf.pred_adapter_kv = False
    cf.pred_self_attention = True
    cf.pred_dyadic_dims = False
    cf.pred_mlp_adaln = True

    # forecasting engine
    cf.forecast_delta_hrs = 0
    cf.forecast_steps = 0  # [j for j in range(1,11) for i in range(1)]
    cf.forecast_policy = None  # 'fixed', 'sequential'
    cf.forecast_freeze_model = False  # False
    cf.forecast_att_dense_rate = 0.25

    cf.fe_num_blocks = 0
    cf.fe_num_heads = 16
    cf.fe_dropout_rate = 0.1
    cf.fe_with_qk_lnorm = True

    cf.healpix_level = 5

    # working precision
    cf.with_mixed_precision = True
    cf.with_flash_attention = True
    if cf.with_flash_attention:
        # developer sanity check on the two literals above
        assert cf.with_mixed_precision, "with_flash_attention requires with_mixed_precision"
    # compile entire model
    cf.compile_model = False

    cf.with_fsdp = True

    cf.loss_fcts = [["mse", 1.0]]
    cf.loss_fcts_val = [["mse", 1.0]]
    # cf.loss_fcts = [['mse', 0.5], ['stats', 0.5]]
    # cf.loss_fcts_val = [['mse', 0.5], ['stats', 0.5]]

    cf.batch_size = 1
    cf.batch_size_validation = 1

    # forecast
    cf.masking_mode = "forecast"
    cf.masking_rate = 0.0
    cf.masking_rate_sampling = True  # False
    cf.sampling_rate_target = 1.0

    cf.num_epochs = 24
    cf.samples_per_epoch = 4096
    cf.samples_per_validation = 512
    cf.shuffle = True

    cf.lr_scaling_policy = "sqrt"
    cf.lr_start = 0.000001
    cf.lr_max = 0.00003
    cf.lr_final_decay = 0.000001
    cf.lr_final = 0.0
    cf.lr_steps_warmup = 256
    cf.lr_steps_cooldown = 4096
    cf.lr_policy_warmup = "cosine"
    cf.lr_policy_decay = "linear"
    cf.lr_policy_cooldown = "linear"

    cf.grad_clip = 5.0
    cf.weight_decay = 0.1
    cf.norm_type = "LayerNorm"  # 'LayerNorm' #'RMSNorm'
    cf.nn_module = "te"

    cf.data_path = "/home/mlx/ai-ml/datasets/stable/"
    # cf.data_path = '/lus/h2resw01/fws4/lb/project/ai-ml/observations/v1'
    # cf.data_path = '/leonardo_scratch/large/userexternal/clessig0/obs/v1'
    cf.start_date = 201301010000
    cf.end_date = 202012310000
    cf.start_date_val = 202101010000
    cf.end_date_val = 202201010000
    cf.len_hrs = 6
    cf.step_hrs = 6
    cf.input_window_steps = 1

    cf.val_initial = False

    cf.loader_num_workers = 8
    # seed from wall-clock time, so every launch gets a different seed
    cf.data_loader_rng_seed = int(time.time())
    cf.log_validation = 0

    cf.istep = 0
    cf.run_history = []

    cf.run_id = run_id
    cf.desc = ""

    trainer = Trainer(log_freq=20, checkpoint_freq=250, print_freq=10)

    try:
        trainer.run(cf)
    except Exception as exc:
        # Debugging aid: show the failure, then open the debugger on it.
        # Catching Exception (not a bare except) lets Ctrl-C and sys.exit()
        # terminate the process normally.
        traceback.print_exc()
        pdb.post_mortem(exc.__traceback__)
# Script entry point: start training with the default run id when executed directly.
if __name__ == "__main__":
    train()
0 commit comments