Commit b034986

Sgrasse/develop/issue 169 (ecmwf#170)
* Remove dead code: `Trainer.evaluate_jac`
* Remove unused arg `jac` in `validation_io.write_validation`
1 parent c42c3e7 commit b034986

2 files changed: +0 −141 lines

src/weathergen/train/trainer.py

Lines changed: 0 additions & 139 deletions
@@ -139,145 +139,6 @@ def evaluate(self, cf, run_id_trained, epoch, run_id_new=False):
         self.validate(epoch=0)
         print(f"Finished evaluation run with id: {cf.run_id}")

-    ###########################################
-    def evaluate_jac(self, cf, run_id, epoch, mode="row", date=None, obs_id=0, sample_id=0):
-        """Computes a row or column of the Jacobian as determined by mode ('row' or 'col'), i.e.
-        determines sensitivities with respect to outputs or inputs
-        """
-        # TODO: this function is not complete
-
-        # general initalization
-        self.init(cf, run_id, epoch, run_id_new=True, run_mode="offline")
-
-        self.dataset = MultiStreamDataSampler(
-            cf,
-            cf.start_date_val,
-            cf.end_date_val,
-            cf.delta_time,
-            1,
-            cf.masking_mode,
-            cf.masking_rate_sampling,
-            cf.t_win_hour,
-            cf.loss_chs,
-            shuffle=False,
-            source_chs=cf.source_chs,
-            forecast_steps=cf.forecast_steps,
-            forecast_policy=cf.forecast_policy,
-            healpix_level=cf.healpix_level,
-        )
-
-        num_channels = self.dataset.get_num_chs()
-
-        self.model = Model(cf, num_channels).create().to(self.devices[0])
-        self.model.load(run_id, epoch)
-        print(f"Loaded model id={run_id}.")
-
-        # TODO: support loading of specific data
-        dataset_iter = iter(self.dataset)
-        (sources, targets, targets_idxs, s_lens) = next(dataset_iter)
-
-        dev = self.devices[0]
-        sources = [source.to(dev, non_blocking=True) for source in sources]
-        targets = [[toks.to(dev, non_blocking=True) for toks in target] for target in targets]
-
-        # evaluate model
-        with torch.autocast(
-            device_type="cuda", dtype=torch.float16, enabled=cf.with_mixed_precision
-        ):
-            if mode == "row":
-                sources_in = [*sources, s_lens.to(torch.float32)]
-                y = self.model(sources, s_lens)
-                # vectors used to extract row from Jacobian
-                vs_sources = [torch.zeros_like(y_obs) for y_obs in y[0]]
-                vs_sources[obs_id][sample_id] = 1.0
-                # evaluate
-                out = torch.autograd.functional.vjp(
-                    self.model.forward_jac, tuple(sources_in), tuple(vs_sources)
-                )
-
-            elif mode == "col":
-                # vectors used to extract col from Jacobian
-                vs_sources = [torch.zeros_like(s_obs) for s_obs in sources]
-                vs_sources[obs_id][sample_id] = 1.0
-                vs_s_lens = torch.zeros_like(s_lens, dtype=torch.float32)
-                # provide one tuple in the end
-                sources_in = [*sources, s_lens.to(torch.float32)]
-                vs_sources.append(vs_s_lens)
-                # evaluate
-                out = torch.autograd.functional.jvp(
-                    self.model.forward_jac, tuple(sources_in), tuple(vs_sources)
-                )
-            else:
-                assert False, "Unsupported mode."
-
-        # extract and write output
-        # TODO: refactor and try to combine with the code in compute_loss
-
-        preds = out[0]
-        jac = [j_obs.cpu().detach().numpy() for j_obs in out[1]]
-
-        sources_all, preds_all = [[] for _ in cf.streams], [[] for _ in cf.streams]
-        targets_all, targets_coords_all = [[] for _ in cf.streams], [[] for _ in cf.streams]
-        targets_idxs_all = [[] for _ in cf.streams]
-        sources_lens = [toks.shape[0] for toks in sources]
-        targets_lens = [[toks.shape[0] for toks in target] for target in targets]
-
-        for i_obs, b_targets_idxs in enumerate(targets_idxs):
-            for i_b, target_idxs_obs in enumerate(b_targets_idxs):  # 1 batch
-                if len(targets[i_obs][i_b]) == 0:
-                    continue
-
-                gs = self.cf.geoinfo_size
-                target_i_obs = torch.cat([t[:, gs:].unsqueeze(0) for t in targets[i_obs][i_b]], 0)
-                preds_i_obs = preds[i_obs][target_idxs_obs]
-                preds_i_obs = preds_i_obs.reshape([*preds_i_obs.shape[:2], *target_i_obs.shape[1:]])
-
-                if self.cf.loss_chs is not None:
-                    if len(self.cf.loss_chs[i_obs]) == 0:
-                        continue
-                    target_i_obs = target_i_obs[..., self.cf.loss_chs[i_obs]]
-                    preds_i_obs = preds_i_obs[..., self.cf.loss_chs[i_obs]]

-                ds_val = self.dataset
-                n = self.cf.geoinfo_size
-
-                sources[i_obs][:, :, n:] = ds_val.denormalize_data(i_obs, sources[i_obs][:, :, n:])
-                sources[i_obs][:, :, :n] = ds_val.denormalize_coords(
-                    i_obs, sources[i_obs][:, :, :n]
-                )
-                sources_all[i_obs] += [sources[i_obs].detach().cpu()]
-
-                preds_all[i_obs] += [ds_val.denormalize_data(i_obs, preds_i_obs).detach().cpu()]
-                targets_all[i_obs] += [ds_val.denormalize_data(i_obs, target_i_obs).detach().cpu()]
-
-                target_i_coords = (
-                    torch.cat([t[:, :n].unsqueeze(0) for t in targets[i_obs][i_b]], 0)
-                    .detach()
-                    .cpu()
-                )
-                targets_coords_all[i_obs] += [
-                    ds_val.denormalize_coords(i_obs, target_i_coords).detach().cpu()
-                ]
-                targets_idxs_all[i_obs] += [target_idxs_obs]
-
-        # cols = [ds[0][0].colnames for ds in dataset_val.obs_datasets_norm]
-        cols = []  # TODO
-        write_validation(
-            self.cf,
-            self.path_run,
-            self.cf.rank,
-            epoch,
-            cols,
-            sources_all,
-            preds_all,
-            targets_all,
-            targets_coords_all,
-            targets_idxs_all,
-            sources_lens,
-            targets_lens,
-            jac,
-        )
-
     ###########################################
     def run(self, cf, private_cf, run_id_contd=None, epoch_contd=None, run_id_new=False):
         # general initalization
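
For context on what was dropped: the removed method seeded a one-hot vector in output or input space and used `torch.autograd.functional.vjp` / `jvp` to pull out a single row or column of the model Jacobian. Below is a minimal, self-contained sketch of that pattern on a toy function; `toy_model`, its shapes, and the index values are illustrative assumptions, not WeatherGen code.

```python
import torch

def toy_model(x):
    # stand-in for the model's forward pass: any differentiable map R^4 -> R^3
    return torch.tanh(x @ torch.ones(4, 3))

x = torch.randn(4)

# Row i of the Jacobian: one-hot vector in *output* space, pulled back with vjp
# (mirrors the mode == "row" branch of the removed evaluate_jac).
i = 1
v_out = torch.zeros(3)
v_out[i] = 1.0
_, jac_row = torch.autograd.functional.vjp(toy_model, x, v_out)  # shape (4,)

# Column j of the Jacobian: one-hot vector in *input* space, pushed forward with jvp
# (mirrors the mode == "col" branch).
j = 2
v_in = torch.zeros(4)
v_in[j] = 1.0
_, jac_col = torch.autograd.functional.jvp(toy_model, x, v_in)   # shape (3,)
```

In the removed method the same seeding was applied per observation stream (`vs_sources[obs_id][sample_id] = 1.0`) on a tuple of source tensors, and the resulting Jacobian slice was denormalized and written out via `write_validation`.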

src/weathergen/utils/validation_io.py

Lines changed: 0 additions & 2 deletions
@@ -67,13 +67,11 @@ def write_validation(
     targets_coords_all,
     targets_times_all,
     targets_lens,
-    jac=None,
 ):
     if len(cf.analysis_streams_output) == 0:
         return

     fname = f"validation_epoch{epoch:05d}_rank{rank:04d}"
-    fname += "" if jac is None else "_jac"
     fname += ".zarr"

     store = zarr.DirectoryStore(base_path / fname)
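
With `jac` gone, the `"_jac"` suffix branch disappears and every rank writes the plain per-epoch file. A minimal sketch of the resulting naming, where `base_path` and the epoch/rank values are assumed for illustration:

```python
import zarr
from pathlib import Path

base_path = Path("results")  # assumed output directory
epoch, rank = 5, 3           # assumed values

# Same pattern as write_validation after this change:
fname = f"validation_epoch{epoch:05d}_rank{rank:04d}.zarr"
store = zarr.DirectoryStore(base_path / fname)

print(base_path / fname)     # results/validation_epoch00005_rank0003.zarr
```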
