Skip to content

Commit 6b8dec2

Browse files
authored
[23] Applying basic linter rules (ecmwf#24)
* reformating code * cicd * linting fixes * format
1 parent 145c18d commit 6b8dec2

27 files changed

+220
-280
lines changed

pyproject.toml

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,35 @@ line-length = 100
5454
# All disabled until the code is formatted.
5555
select = [
5656
# pycodestyle
57-
# "E",
57+
"E",
5858
# Pyflakes
59-
# "F",
59+
"F",
6060
# pyupgrade
61-
# "UP",
61+
"UP",
6262
# flake8-bugbear
63-
# "B",
63+
"B",
6464
# flake8-simplify
65-
# "SIM",
65+
"SIM",
6666
# isort
67-
# "I",
67+
"I",
6868
]
69-
ignore = [
7069

70+
# These rules are sensible and should be enabled at a later stage.
71+
ignore = [
72+
"E501",
73+
"E721",
74+
"E722",
75+
"B006",
76+
"B011",
77+
"UP008",
78+
"SIM115",
79+
"SIM117",
80+
"SIM118",
81+
"SIM102",
82+
"SIM210",
83+
"SIM212",
84+
"SIM401",
85+
"F811",
86+
# To ignore, not relevant for us
87+
"E741",
7188
]

src/weathergen/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10-
import time
11-
import sys
1210
import pdb
11+
import sys
12+
import time
1313
import traceback
1414

15-
from weathergen.utils.config import Config
1615
from weathergen.train.trainer import Trainer
17-
from weathergen.train.utils import get_run_id
16+
from weathergen.utils.config import Config
1817

1918

2019
####################################################################################################

src/weathergen/datasets/anemoi_dataset.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10-
import code
1110
import datetime
1211

1312
import numpy as np
14-
1513
from anemoi.datasets import open_dataset
1614

1715

src/weathergen/datasets/batchifyer.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,20 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10-
import torch
11-
import numpy as np
12-
import code
1310
import warnings
14-
import time
11+
from functools import partial
1512

1613
import astropy_healpix as hp
17-
from astropy_healpix.healpy import ang2pix, pix2ang
18-
19-
from functools import partial
14+
import numpy as np
15+
import torch
16+
from astropy_healpix.healpy import ang2pix
2017

2118
from weathergen.datasets.utils import (
22-
vecs_to_rots,
23-
s2tor3,
24-
r3tos2,
25-
locs_to_cell_coords,
26-
coords_to_hpyidxs,
27-
healpix_verts,
28-
get_target_coords_local,
29-
get_target_coords_local_fast,
3019
get_target_coords_local_ffast,
3120
healpix_verts_rots,
3221
locs_to_cell_coords_ctrs,
22+
r3tos2,
23+
s2tor3,
3324
)
3425

3526

@@ -64,7 +55,6 @@ def tokenize_window_space(
6455
)
6556
hpy_idxs_ord_split = np.split(hpy_idxs_ord, splits + 1)
6657

67-
lens = []
6858
for i, c in enumerate(cells_idxs):
6959
thetas_sorted = torch.argsort(thetas[hpy_idxs_ord_split[i]], stable=True)
7060
posr3_cell = posr3[hpy_idxs_ord_split[i]][thetas_sorted]
@@ -110,7 +100,7 @@ def tokenize_window_spacetime(
110100
mr,
111101
):
112102
t_unique = np.unique(times)
113-
for i, t in enumerate(t_unique):
103+
for _, t in enumerate(t_unique):
114104
mask = t == times
115105
tokens_cells = tokenize_window_space(
116106
source[mask],

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,18 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10-
import numpy as np
11-
import torch
12-
import math
1310
import datetime
14-
from copy import deepcopy
1511
import logging
16-
import time
17-
import code
18-
import os
19-
import yaml
2012

13+
import numpy as np
2114
import pandas as pd
15+
import torch
2216

23-
from weathergen.datasets.obs_dataset import ObsDataset
2417
from weathergen.datasets.anemoi_dataset import AnemoiDataset
25-
from weathergen.datasets.normalizer import DataNormalizer
2618
from weathergen.datasets.batchifyer import Batchifyer
19+
from weathergen.datasets.normalizer import DataNormalizer
20+
from weathergen.datasets.obs_dataset import ObsDataset
2721
from weathergen.datasets.utils import merge_cells
28-
2922
from weathergen.utils.logger import logger
3023

3124

@@ -69,7 +62,7 @@ def __init__(
6962
self.len_hrs = len_hrs
7063
self.step_hrs = step_hrs
7164

72-
fc_policy_seq = "sequential" == forecast_policy or "sequential_random" == forecast_policy
65+
fc_policy_seq = forecast_policy == "sequential" or forecast_policy == "sequential_random"
7366
assert forecast_steps >= 0 if not fc_policy_seq else True
7467
self.forecast_delta_hrs = forecast_delta_hrs if forecast_delta_hrs > 0 else self.len_hrs
7568
self.forecast_steps = np.array(
@@ -111,7 +104,7 @@ def __init__(
111104
# the processing here is not natural but a workaround to various inconsistencies in the
112105
# current datasets
113106
data_idxs = [
114-
i for i, cn in enumerate(ds.selected_colnames[do:]) if "obsvalue_" == cn[:9]
107+
i for i, cn in enumerate(ds.selected_colnames[do:]) if cn[:9] == "obsvalue_"
115108
]
116109
mask = np.ones(len(ds.selected_colnames[do:]), dtype=np.int32).astype(bool)
117110
mask[data_idxs] = False
@@ -272,7 +265,7 @@ def __iter__(self):
272265
# idx_raw is used to index into the dataset; the decoupling is needed
273266
# since there are empty batches
274267
idx_raw = iter_start
275-
for i, bidx in enumerate(range(iter_start, iter_end, self.batch_size)):
268+
for i, _bidx in enumerate(range(iter_start, iter_end, self.batch_size)):
276269
# targets, targets_coords, targets_idxs = [], [], [],
277270
tcs, tcs_lens, target_tokens, source_tokens_cells, source_tokens_lens = (
278271
[],
@@ -314,7 +307,7 @@ def __iter__(self):
314307
c_source_raw = []
315308

316309
for obs_id, (stream_info, stream_dsn, stream_idxs) in enumerate(
317-
zip(self.streams, self.obs_datasets_norm, self.obs_datasets_idxs)
310+
zip(self.streams, self.obs_datasets_norm, self.obs_datasets_idxs, strict=False)
318311
):
319312
s_tcs = []
320313
s_tcs_lens = []
@@ -326,17 +319,17 @@ def __iter__(self):
326319
s_source_raw = []
327320

328321
token_size = stream_info["token_size"]
329-
grid = (
330-
stream_info["gridded_output"] if "gridded_output" in stream_info else None
331-
)
332-
grid_info = (
333-
stream_info["gridded_output_info"]
334-
if "gridded_output_info" in stream_info
335-
else None
336-
)
322+
# grid = (
323+
# stream_info["gridded_output"] if "gridded_output" in stream_info else None
324+
# )
325+
# grid_info = (
326+
# stream_info["gridded_output_info"]
327+
# if "gridded_output_info" in stream_info
328+
# else None
329+
# )
337330

338331
for i_source, ((ds, normalizer, do), s_idxs) in enumerate(
339-
zip(stream_dsn, stream_idxs)
332+
zip(stream_dsn, stream_idxs, strict=False)
340333
):
341334
# source window (of potentially multi-step length)
342335
(source1, times1) = ds[idx]
@@ -417,7 +410,7 @@ def __iter__(self):
417410
for fstep in range(forecast_dt + 1):
418411
# collect all streams
419412
for i_source, ((ds, normalizer, do), s_idxs) in enumerate(
420-
zip(stream_dsn, stream_idxs)
413+
zip(stream_dsn, stream_idxs, strict=False)
421414
):
422415
(source2, times2) = ds[idx + step_forecast_dt]
423416

@@ -534,15 +527,17 @@ def __iter__(self):
534527
idxs = torch.cat(
535528
[
536529
torch.arange(o, o + l, dtype=torch.int64)
537-
for o, l in zip(offsets, source_tokens_lens[ib, itype])
530+
for o, l in zip(offsets, source_tokens_lens[ib, itype], strict=False)
538531
]
539532
)
540533
idxs_embed[-1] += [idxs.unsqueeze(1)]
541534
idxs_embed_pe[-1] += [
542535
torch.cat(
543536
[
544537
torch.arange(o, o + l, dtype=torch.int32)
545-
for o, l in zip(offsets_pe, source_tokens_lens[ib][itype])
538+
for o, l in zip(
539+
offsets_pe, source_tokens_lens[ib][itype], strict=False
540+
)
546541
]
547542
)
548543
]

src/weathergen/datasets/normalizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ def normalize_coords(self, data, normalize_latlon=True):
6969

7070
go = self.geoinfo_offset
7171
for i, ch in enumerate(self.geoinfo_idx):
72-
if 0 == i: # lats
72+
if i == 0: # lats
7373
if normalize_latlon:
7474
data[..., go + i] = np.sin(np.deg2rad(data[..., go + i]))
7575
pass
76-
elif 1 == i: # lons
76+
elif i == 1: # lons
7777
if normalize_latlon:
7878
data[..., go + i] = np.sin(0.5 * np.deg2rad(data[..., go + i]))
7979
else:

src/weathergen/datasets/obs_dataset.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10+
import code
1011
import datetime
1112

1213
import numpy as np
1314
import zarr
14-
import code
1515

1616

1717
class ObsDataset:
@@ -37,12 +37,14 @@ def __init__(
3737

3838
# self.selected_colnames = self.colnames
3939
# self.selected_cols_idx = np.arange(len(self.colnames))
40+
idx = 0
4041
for i, col in enumerate(reversed(self.colnames)):
42+
idx = i
4143
# if col[:9] == 'obsvalue_' :
4244
if not (col[:4] == "sin_" or col[:4] == "cos_"):
4345
break
44-
self.selected_colnames = self.colnames[: len(self.colnames) - i]
45-
self.selected_cols_idx = np.arange(len(self.colnames))[: len(self.colnames) - i]
46+
self.selected_colnames = self.colnames[: len(self.colnames) - idx]
47+
self.selected_cols_idx = np.arange(len(self.colnames))[: len(self.colnames) - idx]
4648

4749
# Create index for samples
4850
self._setup_sample_index(start, end, self.len_hrs, self.step_hrs)
@@ -190,7 +192,7 @@ def _load_properties(self) -> None:
190192

191193
####################################################################################################
192194
if __name__ == "__main__":
193-
zarrpath = config.zarrpath
195+
# zarrpath = config.zarrpath
194196
zarrpath = "/lus/h2resw01/fws4/lb/project/ai-ml/observations/zarr/v0.2"
195197

196198
# # polar orbiting satellites

0 commit comments

Comments
 (0)