@@ -139,145 +139,6 @@ def evaluate(self, cf, run_id_trained, epoch, run_id_new=False):
139139 self .validate (epoch = 0 )
140140 print (f"Finished evaluation run with id: { cf .run_id } " )
141141
142- ###########################################
def evaluate_jac(self, cf, run_id, epoch, mode="row", date=None, obs_id=0, sample_id=0):
    """Compute one row or one column of the model Jacobian and write it to disk.

    The first sample yielded by the validation data sampler is pushed through the
    model. Depending on ``mode`` a vector-Jacobian product (``"row"``: sensitivities
    with respect to one output) or a Jacobian-vector product (``"col"``:
    sensitivities with respect to one input) of ``self.model.forward_jac`` is
    evaluated with a one-hot vector. Predictions, targets, sources and the
    Jacobian slice are then handed to ``write_validation``.

    Args:
        cf: Run configuration; forwarded to ``self.init``, the data sampler and
            the model constructor.
        run_id: Id of the trained run whose weights are loaded.
        epoch: Epoch of the checkpoint to load; also passed to ``write_validation``.
        mode: ``"row"`` (uses ``vjp``) or ``"col"`` (uses ``jvp``).
        date: Currently unused (loading of specific data is still a TODO).
        obs_id: Index of the stream whose entry of the one-hot vector is set.
        sample_id: Index along the first dimension of that stream's tensor that
            is set to 1.0.

    Raises:
        ValueError: If ``mode`` is neither ``"row"`` nor ``"col"``.
    """
    # TODO: this function is not complete

    # Fail fast, before the expensive initialization and model loading. A real
    # exception is used instead of `assert`, which is stripped under `python -O`.
    if mode not in ("row", "col"):
        raise ValueError("Unsupported mode.")

    # general initialization
    self.init(cf, run_id, epoch, run_id_new=True, run_mode="offline")

    self.dataset = MultiStreamDataSampler(
        cf,
        cf.start_date_val,
        cf.end_date_val,
        cf.delta_time,
        1,
        cf.masking_mode,
        cf.masking_rate_sampling,
        cf.t_win_hour,
        cf.loss_chs,
        shuffle=False,
        source_chs=cf.source_chs,
        forecast_steps=cf.forecast_steps,
        forecast_policy=cf.forecast_policy,
        healpix_level=cf.healpix_level,
    )

    num_channels = self.dataset.get_num_chs()

    self.model = Model(cf, num_channels).create().to(self.devices[0])
    self.model.load(run_id, epoch)
    print(f"Loaded model id={ run_id } .")

    # TODO: support loading of specific data
    dataset_iter = iter(self.dataset)
    (sources, targets, targets_idxs, s_lens) = next(dataset_iter)

    dev = self.devices[0]
    sources = [source.to(dev, non_blocking=True) for source in sources]
    targets = [[toks.to(dev, non_blocking=True) for toks in target] for target in targets]

    # Both modes differentiate forward_jac with respect to the same inputs: all
    # source tensors followed by the source lengths cast to float32.
    sources_in = (*sources, s_lens.to(torch.float32))

    # evaluate model
    with torch.autocast(
        device_type="cuda", dtype=torch.float16, enabled=cf.with_mixed_precision
    ):
        if mode == "row":
            y = self.model(sources, s_lens)
            # one-hot vector (shaped like the outputs) that extracts a Jacobian row
            vs_sources = [torch.zeros_like(y_obs) for y_obs in y[0]]
            vs_sources[obs_id][sample_id] = 1.0
            out = torch.autograd.functional.vjp(
                self.model.forward_jac, sources_in, tuple(vs_sources)
            )
        else:  # mode == "col", guaranteed by the validation above
            # one-hot vector (shaped like the inputs) that extracts a Jacobian column
            vs_sources = [torch.zeros_like(s_obs) for s_obs in sources]
            vs_sources[obs_id][sample_id] = 1.0
            # zero tangent for s_lens so the tangents match sources_in one-to-one
            vs_sources.append(torch.zeros_like(s_lens, dtype=torch.float32))
            out = torch.autograd.functional.jvp(
                self.model.forward_jac, sources_in, tuple(vs_sources)
            )

    # extract and write output
    # TODO: refactor and try to combine with the code in compute_loss

    # vjp/jvp return (function output, product); the product is the Jacobian slice
    preds = out[0]
    jac = [j_obs.cpu().detach().numpy() for j_obs in out[1]]

    sources_all, preds_all = [[] for _ in cf.streams], [[] for _ in cf.streams]
    targets_all, targets_coords_all = [[] for _ in cf.streams], [[] for _ in cf.streams]
    targets_idxs_all = [[] for _ in cf.streams]
    sources_lens = [toks.shape[0] for toks in sources]
    targets_lens = [[toks.shape[0] for toks in target] for target in targets]

    # loop invariants
    ds_val = self.dataset
    n = self.cf.geoinfo_size  # leading entries of the last dim hold geoinfo/coords
    loss_chs = self.cf.loss_chs

    for i_obs, b_targets_idxs in enumerate(targets_idxs):
        for i_b, target_idxs_obs in enumerate(b_targets_idxs):  # 1 batch
            target_toks = targets[i_obs][i_b]
            if len(target_toks) == 0:
                continue

            target_i_obs = torch.cat([t[:, n:].unsqueeze(0) for t in target_toks], 0)
            preds_i_obs = preds[i_obs][target_idxs_obs]
            preds_i_obs = preds_i_obs.reshape(
                [*preds_i_obs.shape[:2], *target_i_obs.shape[1:]]
            )

            if loss_chs is not None:
                if len(loss_chs[i_obs]) == 0:
                    continue
                target_i_obs = target_i_obs[..., loss_chs[i_obs]]
                preds_i_obs = preds_i_obs[..., loss_chs[i_obs]]

            # NOTE(review): sources[i_obs] is denormalized in place on every inner
            # iteration; this looks correct only for a single batch item per
            # stream (the sampler is created with 1 above) -- confirm before
            # using larger batches.
            sources[i_obs][:, :, n:] = ds_val.denormalize_data(i_obs, sources[i_obs][:, :, n:])
            sources[i_obs][:, :, :n] = ds_val.denormalize_coords(
                i_obs, sources[i_obs][:, :, :n]
            )
            sources_all[i_obs].append(sources[i_obs].detach().cpu())

            preds_all[i_obs].append(ds_val.denormalize_data(i_obs, preds_i_obs).detach().cpu())
            targets_all[i_obs].append(
                ds_val.denormalize_data(i_obs, target_i_obs).detach().cpu()
            )

            target_i_coords = (
                torch.cat([t[:, :n].unsqueeze(0) for t in target_toks], 0).detach().cpu()
            )
            targets_coords_all[i_obs].append(
                ds_val.denormalize_coords(i_obs, target_i_coords).detach().cpu()
            )
            targets_idxs_all[i_obs].append(target_idxs_obs)

    # cols = [ds[0][0].colnames for ds in dataset_val.obs_datasets_norm]
    cols = []  # TODO
    write_validation(
        self.cf,
        self.path_run,
        self.cf.rank,
        epoch,
        cols,
        sources_all,
        preds_all,
        targets_all,
        targets_coords_all,
        targets_idxs_all,
        sources_lens,
        targets_lens,
        jac,
    )
280-
281142 ###########################################
282143 def run (self , cf , private_cf , run_id_contd = None , epoch_contd = None , run_id_new = False ):
283144 # general initalization
0 commit comments