Skip to content

Commit 1ecf161

Browse files
authored
[103] Initial profiling tools (ecmwf#104)
* profiling * adding viztracer * work * annotations * better * comments * working * cleanup * changes * comments * changes * changes * changes * fix
1 parent 05c06f0 commit 1ecf161

File tree

9 files changed

+733
-49
lines changed

9 files changed

+733
-49
lines changed

config/profiling/annotations.json

Lines changed: 654 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ requires-python = ">=3.11,<3.13"
1212
dependencies = [
1313
'torch==2.6.0',
1414
'numpy~=2.2',
15-
'astropy_healpix~=1.0',
15+
'astropy_healpix~=1.1.2',
1616
'zarr~=2.17',
17-
'anemoi-datasets~=0.5',
17+
'anemoi-datasets~=0.5.16',
1818
'pandas~=2.2',
1919
'pynvml',
2020
'tqdm',

src/weathergen/datasets/anemoi_dataset.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88
# nor does it submit to any jurisdiction.
99

1010
import datetime
11+
import logging
1112

1213
import numpy as np
1314
from anemoi.datasets import open_dataset
1415

16+
_logger = logging.getLogger(__name__)
17+
1518

1619
class AnemoiDataset:
1720
"Wrapper for Anemoi dataset"
@@ -30,26 +33,26 @@ def __init__(
3033
assert len_hrs == step_hrs, "Currently only step_hrs=len_hrs is supported"
3134

3235
# open dataset to peek at it and check that it is compatible with requested parameters
33-
self.ds = open_dataset(filename)
36+
ds = open_dataset(filename)
3437

3538
# check that start and end time are within the dataset time range
3639

37-
ds_dt_start = self.ds.dates[0]
38-
ds_dt_end = self.ds.dates[-1]
40+
ds_dt_start = ds.dates[0]
41+
ds_dt_end = ds.dates[-1]
3942

4043
format_str = "%Y%m%d%H%M%S"
4144
dt_start = datetime.datetime.strptime(str(start), format_str)
4245
dt_end = datetime.datetime.strptime(str(end), format_str)
4346

4447
# TODO, TODO, TODO: we need proper alignment for the case where self.ds.frequency
4548
# is not a multiple of len_hrs
46-
self.num_steps_per_window = int((len_hrs * 3600) / self.ds.frequency.seconds)
49+
self.num_steps_per_window = int((len_hrs * 3600) / ds.frequency.seconds)
4750

4851
# open dataset
4952

5053
# caches lats and lons
51-
self.latitudes = self.ds.latitudes.astype(np.float32)
52-
self.longitudes = self.ds.longitudes.astype(np.float32)
54+
self.latitudes = ds.latitudes.astype(np.float32)
55+
self.longitudes = ds.longitudes.astype(np.float32)
5356

5457
# TODO: define in base class
5558
self.geoinfo_idx = []
@@ -59,8 +62,8 @@ def __init__(
5962
source_channels = stream_info["source"] if "source" in stream_info else None
6063
self.source_idx = np.sort(
6164
[
62-
self.ds.name_to_index[k]
63-
for i, (k, v) in enumerate(self.ds.typed_variables.items())
65+
ds.name_to_index[k]
66+
for i, (k, v) in enumerate(ds.typed_variables.items())
6467
if (
6568
not v.is_computed_forcing
6669
and not v.is_constant_in_time
@@ -75,8 +78,8 @@ def __init__(
7578
target_channels = stream_info["target"] if "target" in stream_info else None
7679
self.target_idx = np.sort(
7780
[
78-
self.ds.name_to_index[k]
79-
for i, (k, v) in enumerate(self.ds.typed_variables.items())
81+
ds.name_to_index[k]
82+
for (k, v) in ds.typed_variables.items()
8083
if (
8184
not v.is_computed_forcing
8285
and not v.is_constant_in_time
@@ -88,21 +91,20 @@ def __init__(
8891
)
8992
]
9093
)
91-
self.source_channels = [self.ds.variables[i] for i in self.source_idx]
92-
self.target_channels = [self.ds.variables[i] for i in self.target_idx]
94+
self.source_channels = [ds.variables[i] for i in self.source_idx]
95+
self.target_channels = [ds.variables[i] for i in self.target_idx]
9396

9497
self.properties = {
9598
"stream_id": 0,
9699
}
97-
self.mean = self.ds.statistics["mean"]
98-
self.stdev = self.ds.statistics["stdev"]
100+
self.mean = ds.statistics["mean"]
101+
self.stdev = ds.statistics["stdev"]
99102

100103
# set dataset to None when no overlap with time range
101104
if dt_start >= ds_dt_end or dt_end <= ds_dt_start:
102105
self.ds = None
103-
return
104-
105-
self.ds = open_dataset(self.ds, frequency=str(step_hrs) + "h", start=dt_start, end=dt_end)
106+
else:
107+
self.ds = open_dataset(ds, frequency=str(step_hrs) + "h", start=dt_start, end=dt_end)
106108

107109
def __len__(self):
108110
"Length of dataset"
@@ -140,8 +142,10 @@ def _get(
140142
)
141143

142144
# extract number of time steps and collapse ensemble dimension
145+
143146
data = self.ds[idx : idx + self.num_steps_per_window][:, :, 0]
144-
# extract channels
147+
148+
# extract channels
145149
data = (
146150
data[:, channels_idx].transpose([0, 2, 1]).reshape((data.shape[0] * data.shape[2], -1))
147151
)

src/weathergen/datasets/batchifyer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
r3tos2,
2424
s2tor3,
2525
)
26+
from weathergen.utils.logger import init_loggers
2627

2728

2829
def encode_times_source(times, time_win) -> torch.tensor:
@@ -300,6 +301,7 @@ def batchify_source(
300301
time_win,
301302
normalizer,
302303
):
304+
init_loggers()
303305
si = stream_info
304306
token_size = si["token_size"]
305307
is_diagnostic = si["diagnostic"] if "diagnostic" in stream_info else False

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
compute_offsets_scatter_embed,
2525
compute_source_cell_lens,
2626
)
27-
from weathergen.utils.logger import logger
27+
from weathergen.utils.logger import init_loggers, logger
2828

2929

3030
class MultiStreamDataSampler(torch.utils.data.IterableDataset):
@@ -245,7 +245,7 @@ def __iter__(self):
245245
len : number of batch items
246246
len[*] : number of streams
247247
"""
248-
248+
init_loggers()
249249
iter_start, iter_end = self.worker_workset()
250250

251251
# create new shuffling

src/weathergen/run_train.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
2+
The entry point to train the weathergen model.
3+
"""
4+
5+
# For profiling tools, the entry point cannot be in an __init__.py file.
6+
from weathergen import train
7+
8+
if __name__ == "__main__":
9+
train()

src/weathergen/train/trainer_base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pynvml
1717
import torch
1818
import torch.distributed as dist
19+
import torch.multiprocessing
1920
import torch.utils.data.distributed
2021
import yaml
2122

@@ -31,9 +32,19 @@ def __init__(self):
3132

3233
###########################################
3334
@staticmethod
34-
def init_torch(use_cuda=True, num_accs_per_task=1):
35+
def init_torch(use_cuda=True, num_accs_per_task=1, multiprocessing_method="fork"):
36+
"""
37+
Initialize torch, set device and multiprocessing method.
38+
39+
NOTE: If using the Nvidia profiler, the multiprocessing method must be set to "spawn".
40+
The default for linux systems is "fork", which prevents traces from being generated with DDP.
41+
"""
3542
torch.set_printoptions(linewidth=120)
3643

44+
# The "spawn" start method is required by the Nvidia profiler to properly trace events in worker processes.
45+
# This may cause issues with logging. Alternative: "fork"
46+
torch.multiprocessing.set_start_method(multiprocessing_method, force=True)
47+
3748
torch.backends.cuda.matmul.allow_tf32 = True
3849

3950
use_cuda = torch.cuda.is_available()

src/weathergen/utils/logger.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import logging
1111
import os
1212
import pathlib
13+
from functools import cache
1314

1415

1516
class RelPathFormatter(logging.Formatter):
@@ -23,11 +24,14 @@ def format(self, record):
2324
return super().format(record)
2425

2526

27+
@cache
2628
def init_loggers():
2729
"""
2830
Initialize the logger for the package.
2931
3032
WARNING: this function resets all the logging handlers.
33+
34+
This function is cached, so its body runs only once per process; it can therefore be called repeatedly in multiprocessing pipelines.
3135
"""
3236
formatter = RelPathFormatter("%(pathname)s:%(lineno)d : %(levelname)-8s : %(message)s")
3337
for package in ["obslearn", "weathergen"]:

uv.lock

Lines changed: 26 additions & 26 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)