
Commit 19f844e

Deprecate cubed.extensions and move to cubed.diagnostics (#533)
* Rename `cubed.extensions` to `cubed.diagnostics`
* Deprecate `cubed.extensions`
* Update tests, examples and notebooks to use `cubed.diagnostics`
1 parent ceddb7f commit 19f844e
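
For users of the old module path, the change is a one-line import update. A minimal migration sketch (the old `cubed.extensions` path is assumed to keep working with a deprecation warning, per the commit message):

# Deprecated import path (pre-#533):
from cubed.extensions.rich import RichProgressBar

# New import path introduced by this commit:
from cubed.diagnostics.rich import RichProgressBar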

20 files changed (+463, -420 lines)

cubed/diagnostics/__init__.py

Whitespace-only changes.

cubed/diagnostics/history.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
from dataclasses import asdict
from pathlib import Path

import pandas as pd

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback


class HistoryCallback(Callback):
    def on_compute_start(self, event):
        plan = []
        for name, node in visit_nodes(event.dag, event.resume):
            primitive_op = node["primitive_op"]
            plan.append(
                dict(
                    name=name,
                    op_name=node["op_name"],
                    projected_mem=primitive_op.projected_mem,
                    reserved_mem=primitive_op.reserved_mem,
                    num_tasks=primitive_op.num_tasks,
                )
            )

        self.plan = plan
        self.events = []

    def on_task_end(self, event):
        self.events.append(asdict(event))

    def on_compute_end(self, event):
        self.plan_df = pd.DataFrame(self.plan)
        self.events_df = pd.DataFrame(self.events)
        history_path = Path(f"history/{event.compute_id}")
        history_path.mkdir(parents=True, exist_ok=True)
        self.plan_df_path = history_path / "plan.csv"
        self.events_df_path = history_path / "events.csv"
        self.stats_df_path = history_path / "stats.csv"
        self.plan_df.to_csv(self.plan_df_path, index=False)
        self.events_df.to_csv(self.events_df_path, index=False)

        self.stats_df = analyze(self.plan_df, self.events_df)
        self.stats_df.to_csv(self.stats_df_path, index=False)


def analyze(plan_df, events_df):
    # convert memory to MB
    plan_df["projected_mem_mb"] = plan_df["projected_mem"] / 1_000_000
    plan_df["reserved_mem_mb"] = plan_df["reserved_mem"] / 1_000_000
    plan_df = plan_df[
        [
            "name",
            "op_name",
            "projected_mem_mb",
            "reserved_mem_mb",
            "num_tasks",
        ]
    ]
    events_df["peak_measured_mem_start_mb"] = (
        events_df["peak_measured_mem_start"] / 1_000_000
    )
    events_df["peak_measured_mem_end_mb"] = (
        events_df["peak_measured_mem_end"] / 1_000_000
    )
    events_df["peak_measured_mem_delta_mb"] = (
        events_df["peak_measured_mem_end_mb"] - events_df["peak_measured_mem_start_mb"]
    )

    # find per-array stats
    df = events_df.groupby("name", as_index=False).agg(
        {
            "peak_measured_mem_start_mb": ["min", "mean", "max"],
            "peak_measured_mem_end_mb": ["max"],
            "peak_measured_mem_delta_mb": ["min", "mean", "max"],
        }
    )

    # flatten multi-index
    df.columns = ["_".join(a).rstrip("_") for a in df.columns.to_flat_index()]
    df = df.merge(plan_df, on="name")

    def projected_mem_utilization(row):
        return row["peak_measured_mem_end_mb_max"] / row["projected_mem_mb"]

    df["projected_mem_utilization"] = df.apply(
        lambda row: projected_mem_utilization(row), axis=1
    )
    df = df[
        [
            "name",
            "op_name",
            "num_tasks",
            "peak_measured_mem_start_mb_max",
            "peak_measured_mem_end_mb_max",
            "peak_measured_mem_delta_mb_max",
            "projected_mem_mb",
            "reserved_mem_mb",
            "projected_mem_utilization",
        ]
    ]

    return df
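
A hedged usage sketch (assumes the `callbacks=` keyword on `compute()` and the `cubed.array_api` namespace, neither of which this commit changes): attach a `HistoryCallback`, run a computation, then read the aggregated stats.

import cubed
import cubed.array_api as xp
from cubed.diagnostics.history import HistoryCallback

# Illustrative spec values; adjust work_dir/allowed_mem for your environment.
spec = cubed.Spec(work_dir="tmp", allowed_mem="100MB")
a = xp.ones((1000, 1000), chunks=(100, 100), spec=spec)
b = xp.sum(a + 1)

hist = HistoryCallback()
b.compute(callbacks=[hist])

# plan.csv, events.csv and stats.csv are written under history/<compute_id>/;
# the same data is kept in memory as DataFrames:
print(hist.stats_df[["name", "projected_mem_mb", "projected_mem_utilization"]])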

cubed/diagnostics/mem_warn.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import warnings
from collections import Counter

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback


class MemoryWarningCallback(Callback):
    def on_compute_start(self, event):
        # store ops keyed by name
        self.ops = {}
        for name, node in visit_nodes(event.dag, event.resume):
            primitive_op = node["primitive_op"]
            self.ops[name] = primitive_op

        # count number of times each op exceeds allowed mem
        self.counter = Counter()

    def on_task_end(self, event):
        allowed_mem = self.ops[event.name].allowed_mem
        if (
            event.peak_measured_mem_end is not None
            and event.peak_measured_mem_end > allowed_mem
        ):
            self.counter.update({event.name: 1})

    def on_compute_end(self, event):
        if sum(self.counter.values()) > 0:
            exceeded = [
                f"{k} ({v}/{self.ops[k].num_tasks})" for k, v in self.counter.items()
            ]
            warnings.warn(
                f"Peak memory usage exceeded allowed_mem when running tasks: {', '.join(exceeded)}",
                UserWarning,
            )
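
Usage follows the same pattern; a sketch under the same `callbacks=` assumption as above. The warning only fires when a task's measured peak memory exceeds `allowed_mem`, so a quiet run means no task went over budget.

import warnings

from cubed.diagnostics.mem_warn import MemoryWarningCallback

mem_warn = MemoryWarningCallback()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    b.compute(callbacks=[mem_warn])  # b from the HistoryCallback sketch above

# Each warning lists the offending ops as "name (tasks_exceeding/num_tasks)".
for w in caught:
    print(w.message)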

cubed/diagnostics/rich.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
import logging
import sys
from contextlib import contextmanager

from rich.console import RenderableType
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    Task,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
)
from rich.text import Text

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback


class RichProgressBar(Callback):
    """Rich progress bar for a computation."""

    def on_compute_start(self, event):
        # Set the pulse_style to the background colour to disable pulsing,
        # since Rich will pulse all non-started bars.
        logger_aware_progress = LoggerAwareProgress(
            SpinnerWhenRunningColumn(),
            TextColumn("[progress.description]{task.description}"),
            LeftJustifiedMofNCompleteColumn(),
            BarColumn(bar_width=None, pulse_style="bar.back"),
            TaskProgressColumn(
                text_format="[progress.percentage]{task.percentage:>3.1f}%"
            ),
            TimeElapsedColumn(),
            logger=logging.getLogger(),
        )
        progress = logger_aware_progress.__enter__()

        progress_tasks = {}
        for name, node in visit_nodes(event.dag, event.resume):
            num_tasks = node["primitive_op"].num_tasks
            op_display_name = node["op_display_name"].replace("\n", " ")
            progress_task = progress.add_task(
                f"{op_display_name}", start=False, total=num_tasks
            )
            progress_tasks[name] = progress_task

        self.logger_aware_progress = logger_aware_progress
        self.progress = progress
        self.progress_tasks = progress_tasks

    def on_compute_end(self, event):
        self.logger_aware_progress.__exit__(None, None, None)

    def on_operation_start(self, event):
        self.progress.start_task(self.progress_tasks[event.name])

    def on_task_end(self, event):
        self.progress.update(
            self.progress_tasks[event.name], advance=event.num_tasks, refresh=True
        )


class SpinnerWhenRunningColumn(SpinnerColumn):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # Override so spinner is not shown when bar has not yet started
    def render(self, task: "Task") -> RenderableType:
        text = (
            self.finished_text
            if not task.started or task.finished
            else self.spinner.render(task.get_time())
        )
        return text


class LeftJustifiedMofNCompleteColumn(MofNCompleteColumn):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def render(self, task: "Task") -> Text:
        """Show completed/total."""
        completed = int(task.completed)
        total = int(task.total) if task.total is not None else "?"
        total_width = len(str(total))
        return Text(
            f"{completed}{self.separator}{total}".ljust(total_width + 1 + total_width),
            style="progress.download",
        )


# Based on CustomProgress from https://github.com/Textualize/rich/discussions/1578
@contextmanager
def LoggerAwareProgress(*args, **kwargs):
    """Wrapper around rich.progress.Progress to manage logging output to stderr."""
    try:
        __logger = kwargs.pop("logger", None)
        streamhandlers = [
            x for x in __logger.root.handlers if type(x) is logging.StreamHandler
        ]

        with Progress(*args, **kwargs) as progress:
            for handler in streamhandlers:
                __prior_stderr = handler.stream
                handler.setStream(sys.stderr)

            yield progress

    finally:
        streamhandlers = [
            x for x in __logger.root.handlers if type(x) is logging.StreamHandler
        ]
        for handler in streamhandlers:
            handler.setStream(__prior_stderr)
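
A one-liner in use (same assumptions as the earlier sketches); each operation in the plan gets its own row, started only when its `on_operation_start` event fires:

from cubed.diagnostics.rich import RichProgressBar

progress = RichProgressBar()
b.compute(callbacks=[progress])  # b from the HistoryCallback sketch above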

cubed/diagnostics/timeline.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
import os
import time
from dataclasses import asdict
from typing import Optional

import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import pylab
import seaborn as sns

from cubed.runtime.types import Callback

sns.set_style("whitegrid")
pylab.switch_backend("Agg")


class TimelineVisualizationCallback(Callback):
    def __init__(self, format: Optional[str] = None) -> None:
        self.format = format

    def on_compute_start(self, event):
        self.start_tstamp = time.time()
        self.stats = []

    def on_task_end(self, event):
        self.stats.append(asdict(event))

    def on_compute_end(self, event):
        end_tstamp = time.time()
        dst = f"history/{event.compute_id}"
        format = self.format
        create_timeline(self.stats, self.start_tstamp, end_tstamp, dst, format)


# copy of lithops function of the same name, and modified for different field names
def create_timeline(stats, start_tstamp, end_tstamp, dst=None, format=None):
    stats_df = pd.DataFrame(stats)

    stats_df = stats_df.sort_values(by=["task_create_tstamp", "name"], ascending=True)

    total_calls = len(stats_df)

    palette = sns.color_palette("deep", 6)

    fig = pylab.figure(figsize=(10, 6))
    ax = fig.add_subplot(1, 1, 1)

    y = np.arange(total_calls)
    point_size = 10

    fields = [
        ("task create", stats_df.task_create_tstamp - start_tstamp),
        ("function start", stats_df.function_start_tstamp - start_tstamp),
        ("function end", stats_df.function_end_tstamp - start_tstamp),
        ("task result", stats_df.task_result_tstamp - start_tstamp),
    ]

    patches = []
    for f_i, (field_name, val) in enumerate(fields):
        ax.scatter(val, y, c=[palette[f_i]], edgecolor="none", s=point_size, alpha=0.8)
        patches.append(mpatches.Patch(color=palette[f_i], label=field_name))

    ax.set_xlabel("Execution Time (sec)")
    ax.set_ylabel("Function Call")

    legend = pylab.legend(handles=patches, loc="upper right", frameon=True)
    legend.get_frame().set_facecolor("#FFFFFF")

    yplot_step = int(np.max([1, total_calls / 20]))
    y_ticks = np.arange(total_calls // yplot_step + 2) * yplot_step
    ax.set_yticks(y_ticks)
    ax.set_ylim(-0.02 * total_calls, total_calls * 1.02)
    for y in y_ticks:
        ax.axhline(y, c="k", alpha=0.1, linewidth=1)

    max_seconds = np.max(end_tstamp - start_tstamp) * 1.25
    xplot_step = max(int(max_seconds / 8), 1)
    x_ticks = np.arange(max_seconds // xplot_step + 2) * xplot_step
    ax.set_xlim(0, max_seconds)

    ax.set_xticks(x_ticks)
    for x in x_ticks:
        ax.axvline(x, c="k", alpha=0.2, linewidth=0.8)

    ax.grid(False)
    fig.tight_layout()

    if format is None:
        format = "svg"

    if dst is None:
        os.makedirs("plots", exist_ok=True)
        dst = os.path.join(
            os.getcwd(), "plots", "{}_{}".format(int(time.time()), f"timeline.{format}")
        )
    else:
        dst = os.path.expanduser(dst) if "~" in dst else dst
        dst = "{}/{}".format(os.path.realpath(dst), f"timeline.{format}")

    fig.savefig(dst)
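
And a sketch for the timeline plot (same assumptions; `format` defaults to "svg" when not given):

from cubed.diagnostics.timeline import TimelineVisualizationCallback

timeline = TimelineVisualizationCallback(format="png")
b.compute(callbacks=[timeline])  # b from the HistoryCallback sketch above
# Writes history/<compute_id>/timeline.png, plotting task create / function
# start / function end / task result timestamps for every task.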
