Skip to content

Commit 02be8a4

Browse files
feat(truss): Add Metrics watching to truss train (#1574)
* metrics start
* working
* TrainingPoller
* small updates
* small refactor
* precommit
* remove storage
* rename common func
* dynamic coloring
* offset minutes
* move files, offset minutes
* i think this is it
* add unit test with heavy mocking
* pr review
* test train core
* remove unnecessary codes
1 parent a4ed096 commit 02be8a4

File tree

9 files changed

+395
-80
lines changed

9 files changed

+395
-80
lines changed

truss/cli/cli.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from rich.console import Console
2222

2323
import truss
24+
import truss.cli.train.core as train_cli
2425
from truss.base.constants import (
2526
PRODUCTION_ENVIRONMENT_NAME,
2627
TRTLLM_MIN_MEMORY_REQUEST_GI,
@@ -32,12 +33,6 @@
3233
from truss.cli.logs import utils as cli_log_utils
3334
from truss.cli.logs.model_log_watcher import ModelDeploymentLogWatcher
3435
from truss.cli.logs.training_log_watcher import TrainingLogWatcher
35-
from truss.cli.train import (
36-
get_args_for_logs,
37-
get_args_for_stop,
38-
stop_all_jobs,
39-
view_training_details,
40-
)
4136
from truss.remote.baseten.core import (
4237
ACTIVE_STATUS,
4338
DEPLOYING_STATUSES,
@@ -925,7 +920,9 @@ def push_training_job(config: Path, remote: Optional[str], tail: bool):
925920
console.print("✨ Training job successfully created!", style="green")
926921
console.print(
927922
f"🪵 View logs for your job via "
928-
f"[cyan]`truss train logs --project-id {project_resp['id']} --job-id {job_resp['id']} [--tail]`[/cyan]"
923+
f"[cyan]`truss train logs --job-id {job_resp['id']} [--tail]`[/cyan]\n"
924+
f"🔍 View metrics for your job via "
925+
f"[cyan]`truss train metrics --job-id {job_resp['id']}`[/cyan]"
929926
)
930927

931928
if tail:
@@ -953,7 +950,9 @@ def get_job_logs(
953950
remote_provider: BasetenRemote = cast(
954951
BasetenRemote, RemoteFactory.create(remote=remote)
955952
)
956-
project_id, job_id = get_args_for_logs(console, remote_provider, project_id, job_id)
953+
project_id, job_id = train_cli.get_args_for_monitoring(
954+
console, remote_provider, project_id, job_id
955+
)
957956

958957
if not tail:
959958
logs = remote_provider.api.get_training_job_logs(project_id, job_id)
@@ -986,9 +985,9 @@ def stop_job(
986985
BasetenRemote, RemoteFactory.create(remote=remote)
987986
)
988987
if all:
989-
stop_all_jobs(console, remote_provider, project_id)
988+
train_cli.stop_all_jobs(console, remote_provider, project_id)
990989
else:
991-
project_id, job_id = get_args_for_stop(
990+
project_id, job_id = train_cli.get_args_for_stop(
992991
console, remote_provider, project_id, job_id
993992
)
994993
remote_provider.api.stop_training_job(project_id, job_id)
@@ -1016,7 +1015,27 @@ def view_training(
10161015
remote_provider: BasetenRemote = cast(
10171016
BasetenRemote, RemoteFactory.create(remote=remote)
10181017
)
1019-
view_training_details(console, remote_provider, project_id, job_id)
1018+
train_cli.view_training_details(console, remote_provider, project_id, job_id)
1019+
1020+
1021+
@train.command(name="metrics")
1022+
@click.option("--project-id", type=str, required=False, help="Project ID.")
1023+
@click.option("--job-id", type=str, required=False, help="Job ID.")
1024+
@click.option("--remote", type=str, required=False, help="Remote to use")
1025+
@log_level_option
1026+
@error_handling
1027+
def get_job_metrics(
1028+
project_id: Optional[str], job_id: Optional[str], remote: Optional[str]
1029+
):
1030+
"""Get metrics for a training job"""
1031+
1032+
if not remote:
1033+
remote = remote_cli.inquire_remote_name()
1034+
1035+
remote_provider: BasetenRemote = cast(
1036+
BasetenRemote, RemoteFactory.create(remote=remote)
1037+
)
1038+
train_cli.view_training_job_metrics(console, remote_provider, project_id, job_id)
10201039

10211040

10221041
# End Training Stuff #####################################################################

truss/cli/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
POLL_INTERVAL_SEC = 2

truss/cli/logs/base_watcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
from rich import console as rich_console
77

8+
from truss.cli.common import POLL_INTERVAL_SEC
89
from truss.cli.logs.utils import ParsedLog, parse_logs
910
from truss.remote.baseten.api import BasetenApi
1011

1112
# NB(nikhil): This helps account for (1) log processing delays (2) clock skews
1213
CLOCK_SKEW_BUFFER_MS = 10000
13-
POLL_INTERVAL_SEC = 2
1414

1515

1616
class LogWatcher(ABC):
Lines changed: 7 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,16 @@
11
import signal
2-
import time
32
from typing import Any, List, Optional
43

54
from rich import console as rich_console
65

7-
from truss.cli.logs.base_watcher import POLL_INTERVAL_SEC, LogWatcher
6+
from truss.cli.logs.base_watcher import LogWatcher
7+
from truss.cli.train.poller import TrainingPollerMixin
88
from truss.remote.baseten.api import BasetenApi
99

10-
# NB(nikhil): When a job ends, we poll for this many seconds after to capture
11-
# any trailing logs that contain information about errors.
12-
JOB_TERMINATION_GRACE_PERIOD_SEC = 10
1310

14-
JOB_STARTING_STATES = ["TRAINING_JOB_CREATED", "TRAINING_JOB_DEPLOYING"]
15-
JOB_RUNNING_STATES = ["TRAINING_JOB_RUNNING"]
16-
JOB_ENDED_STATES = [
17-
"TRAINING_JOB_COMPLETED",
18-
"TRAINING_JOB_FAILED",
19-
"TRAINING_JOB_STOPPED",
20-
]
21-
22-
23-
class TrainingLogWatcher(LogWatcher):
11+
class TrainingLogWatcher(TrainingPollerMixin, LogWatcher):
2412
project_id: str
2513
job_id: str
26-
_poll_stop_time: Optional[int] = None
27-
_current_status: Optional[str] = None
2814

2915
def __init__(
3016
self,
@@ -33,60 +19,19 @@ def __init__(
3319
job_id: str,
3420
console: rich_console.Console,
3521
):
36-
super().__init__(api, console)
37-
self.project_id = project_id
38-
self.job_id = job_id
39-
# register siging handler that instructs user on how to stop the job
22+
TrainingPollerMixin.__init__(self, api, project_id, job_id, console)
23+
LogWatcher.__init__(self, api, console)
24+
# registering the sigint allows us to provide messaging on next steps
4025
signal.signal(signal.SIGINT, self._handle_sigint)
4126

4227
def _handle_sigint(self, signum: int, frame: Any) -> None:
43-
msg = f"\n\nExiting training job logs. To stop the job, run `truss train stop --project-id {self.project_id} --job-id {self.job_id}`"
28+
msg = f"\n\nExiting training job logs. To stop the job, run `truss train stop --job-id {self.job_id}`"
4429
self.console.print(msg, style="yellow")
4530
raise KeyboardInterrupt()
4631

47-
def _get_current_job_status(self) -> str:
48-
job = self.api.get_training_job(self.project_id, self.job_id)
49-
return job["training_job"]["current_status"]
50-
51-
def before_polling(self) -> None:
52-
self._current_status = self._get_current_job_status()
53-
status_str = "Waiting for job to run, currently {current_status}..."
54-
with self.console.status(
55-
status_str.format(current_status=self._current_status), spinner="dots"
56-
) as status:
57-
while self._current_status in JOB_STARTING_STATES:
58-
time.sleep(POLL_INTERVAL_SEC)
59-
self._current_status = self._get_current_job_status()
60-
status.update(status_str.format(current_status=self._current_status))
61-
6232
def fetch_logs(
6333
self, start_epoch_millis: Optional[int], end_epoch_millis: Optional[int]
6434
) -> List[Any]:
6535
return self.api.get_training_job_logs(
6636
self.project_id, self.job_id, start_epoch_millis, end_epoch_millis
6737
)
68-
69-
def should_poll_again(self) -> bool:
70-
return self._current_status in JOB_RUNNING_STATES or self._poll_final_logs()
71-
72-
def post_poll(self) -> None:
73-
self._current_status = self._get_current_job_status()
74-
self._maybe_update_poll_stop_time(self._current_status)
75-
76-
def after_polling(self) -> None:
77-
if self._current_status == "TRAINING_JOB_COMPLETED":
78-
self.console.print("Training job completed successfully.", style="green")
79-
elif self._current_status == "TRAINING_JOB_FAILED":
80-
self.console.print("Training job failed.", style="red")
81-
elif self._current_status == "TRAINING_JOB_STOPPED":
82-
self.console.print("Training job stopped by user.", style="yellow")
83-
84-
def _poll_final_logs(self):
85-
if self._poll_stop_time is None:
86-
return False
87-
88-
return int(time.time()) <= self._poll_stop_time
89-
90-
def _maybe_update_poll_stop_time(self, current_status: str) -> None:
91-
if current_status not in JOB_RUNNING_STATES and self._poll_stop_time is None:
92-
self._poll_stop_time = int(time.time()) + JOB_TERMINATION_GRACE_PERIOD_SEC

truss/cli/train.py renamed to truss/cli/train/core.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from InquirerPy import inquirer
66
from rich.console import Console
77

8+
from truss.cli.train.metrics_watcher import MetricsWatcher
89
from truss.remote.baseten.remote import BasetenRemote
910

1011
ACTIVE_JOB_STATUSES = [
@@ -38,7 +39,7 @@ def get_args_for_stop(
3839
return project_id, job_id
3940

4041

41-
def get_args_for_logs(
42+
def get_args_for_monitoring(
4243
console: Console,
4344
remote_provider: BasetenRemote,
4445
project_id: Optional[str],
@@ -49,12 +50,12 @@ def get_args_for_logs(
4950
project_id=project_id, job_id=job_id
5051
)
5152
if not jobs:
52-
raise click.UsageError("Unable to get logs. No jobs found.")
53+
raise click.UsageError("No jobs found.")
5354
if len(jobs) > 1:
5455
sorted_jobs = sorted(jobs, key=lambda x: x["created_at"], reverse=True)
5556
job = sorted_jobs[0]
5657
console.print(
57-
f"Multiple jobs found. Showing logs for the most recently created job: {job['id']}",
58+
f"Multiple jobs found. Showing the most recently created job: {job['id']}",
5859
style="yellow",
5960
)
6061
else:
@@ -187,3 +188,19 @@ def stop_all_jobs(
187188
for job in active_jobs:
188189
remote_provider.api.stop_training_job(job["training_project"]["id"], job["id"])
189190
console.print("Training jobs stopped successfully.", style="green")
191+
192+
193+
def view_training_job_metrics(
194+
console: Console,
195+
remote_provider: BasetenRemote,
196+
project_id: Optional[str],
197+
job_id: Optional[str],
198+
):
199+
"""
200+
view_training_job_metrics shows a list of metrics for a training job.
201+
"""
202+
project_id, job_id = get_args_for_monitoring(
203+
console, remote_provider, project_id, job_id
204+
)
205+
metrics_display = MetricsWatcher(remote_provider.api, project_id, job_id, console)
206+
metrics_display.watch()

truss/cli/train/metrics_watcher.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import signal
2+
import time
3+
import traceback
4+
from typing import Any, Dict, List, Optional, Tuple
5+
6+
from rich.console import Console
7+
from rich.live import Live
8+
from rich.table import Table
9+
from rich.text import Text
10+
11+
from truss.cli.train.poller import TrainingPollerMixin
12+
from truss.remote.baseten.api import BasetenApi
13+
14+
METRICS_POLL_INTERVAL_SEC = 30
15+
16+
17+
class MetricsWatcher(TrainingPollerMixin):
18+
live: Optional[Live]
19+
20+
def __init__(self, api: BasetenApi, project_id: str, job_id: str, console: Console):
21+
super().__init__(api, project_id, job_id, console)
22+
23+
self.live = None
24+
signal.signal(signal.SIGINT, self._handle_sigint)
25+
26+
def _handle_sigint(self, signum: int, frame: Any) -> None:
27+
if self.live:
28+
self.live.stop()
29+
msg = f"\n\nExiting training job metrics. To stop the job, run `truss train stop --job-id {self.job_id}`"
30+
self.console.print(msg, style="yellow")
31+
raise KeyboardInterrupt()
32+
33+
def _format_bytes(self, bytes_val: float) -> Tuple[str, str]:
34+
"""Convert bytes to human readable format"""
35+
color_map = {"MB": "green", "GB": "cyan", "TB": "magenta"}
36+
unit = "MB"
37+
if bytes_val > 1024 * 1024 * 1024 * 1024:
38+
unit = "TB"
39+
elif bytes_val > 1024 * 1024 * 1024:
40+
unit = "GB"
41+
42+
if unit == "MB":
43+
return f"{bytes_val / (1024 * 1024):.2f} MB", color_map[unit]
44+
elif unit == "GB":
45+
return f"{bytes_val / (1024 * 1024 * 1024):.2f} GB", color_map[unit]
46+
return f"{bytes_val:.2f} bytes", color_map[unit]
47+
48+
def _get_latest_metric(self, metrics: List[Dict]) -> Optional[float]:
49+
"""Get the most recent metric value"""
50+
if not metrics:
51+
return None
52+
return metrics[-1].get("value")
53+
54+
def create_metrics_table(self, metrics_data: Dict) -> Table:
55+
"""Create a Rich table with the metrics"""
56+
table = Table(title="Training Job Metrics")
57+
table.add_column("Metric")
58+
table.add_column("Value")
59+
60+
# Add timestamp if available
61+
cpu_usage_data = metrics_data.get("cpu_usage", [])
62+
if cpu_usage_data and len(cpu_usage_data) > 0:
63+
latest_timestamp = cpu_usage_data[-1].get("timestamp")
64+
if latest_timestamp:
65+
table.add_row("Timestamp", latest_timestamp)
66+
table.add_section()
67+
68+
# CPU metrics
69+
cpu_usage = self._get_latest_metric(metrics_data.get("cpu_usage", []))
70+
if cpu_usage is not None:
71+
table.add_row("CPU Usage", f"{cpu_usage:.2f} cores")
72+
73+
cpu_memory = self._get_latest_metric(
74+
metrics_data.get("cpu_memory_usage_bytes", [])
75+
)
76+
if cpu_memory is not None:
77+
formatted_value, color = self._format_bytes(cpu_memory)
78+
table.add_row("CPU Memory", Text(formatted_value, style=color))
79+
80+
# Add separator after CPU metrics
81+
table.add_section()
82+
83+
# GPU metrics - grouped by GPU ID
84+
gpu_metrics = metrics_data.get("gpu_utilization", {})
85+
gpu_memory = metrics_data.get("gpu_memory_usage_bytes", {})
86+
87+
for gpu_id in sorted(set(gpu_metrics.keys()) | set(gpu_memory.keys())):
88+
# Add GPU utilization
89+
latest_util = self._get_latest_metric(gpu_metrics.get(gpu_id, []))
90+
if latest_util is not None:
91+
table.add_row(f"GPU {gpu_id} Usage", f"{latest_util * 100:.1f}%")
92+
93+
# Add GPU memory right after its utilization
94+
latest_memory = self._get_latest_metric(gpu_memory.get(gpu_id, []))
95+
if latest_memory is not None:
96+
formatted_value, color = self._format_bytes(latest_memory)
97+
table.add_row(
98+
f"GPU {gpu_id} Memory", Text(formatted_value, style=color)
99+
)
100+
101+
# Add separator after each GPU's metrics (except for the last one)
102+
if gpu_id != max(set(gpu_metrics.keys()) | set(gpu_memory.keys())):
103+
table.add_section()
104+
105+
# Add separator before storage metrics
106+
if gpu_metrics or gpu_memory:
107+
table.add_section()
108+
109+
return table
110+
111+
def watch(self, refresh_rate: int = METRICS_POLL_INTERVAL_SEC):
112+
"""Display continuously updating metrics"""
113+
self.before_polling()
114+
with Live(auto_refresh=False) as live:
115+
self.live = live
116+
while True:
117+
# our first instance of fetching metrics passes no explicit time range. We do this so that we can fetch metrics
118+
# for inactive jobs, using the job's completion time to set the time range.
119+
# Subsequent queries will fetch only the most recent data to avoid unnecessary load on VM
120+
metrics = self.api.get_training_job_metrics(
121+
self.project_id, self.job_id
122+
)
123+
try:
124+
# range of one minute since we only want the last recording
125+
table = self.create_metrics_table(metrics)
126+
live.update(table, refresh=True)
127+
if not self.should_poll_again():
128+
live.stop()
129+
break
130+
time.sleep(refresh_rate)
131+
end_epoch_millis = int(time.time() * 1000)
132+
start_epoch_millis = end_epoch_millis - 60 * 1000
133+
metrics = self.api.get_training_job_metrics(
134+
self.project_id,
135+
self.job_id,
136+
end_epoch_millis=end_epoch_millis,
137+
start_epoch_millis=start_epoch_millis,
138+
)
139+
self.post_poll()
140+
except Exception as e:
141+
live.stop()
142+
self.console.print(
143+
f"Error fetching metrics: {e}: {traceback.format_exc()}",
144+
style="red",
145+
)
146+
break
147+
self.after_polling()

0 commit comments

Comments (0)