Merged

Changes from 3 commits
36 changes: 28 additions & 8 deletions kubeflow/trainer/api/trainer_client.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Optional, Union
+from typing import Optional, Union, Iterator
 
 from kubeflow.trainer.constants import constants
 from kubeflow.trainer.types import types
@@ -120,8 +120,7 @@ def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
             runtime: Reference to one of the existing runtimes.
 
         Returns:
-            List: List of created TrainJobs.
-                If no TrainJob exist, an empty list is returned.
+            List of created TrainJobs. If no TrainJobs exist, an empty list is returned.
 
         Raises:
             TimeoutError: Timeout to list TrainJobs.
@@ -148,12 +147,33 @@ def get_job(self, name: str) -> types.TrainJob:
     def get_job_logs(
         self,
         name: str,
+        step: str = constants.NODE + "-0",
         follow: Optional[bool] = False,
-        step: str = constants.NODE,
-        node_rank: int = 0,
-    ) -> dict[str, str]:
-        """Get the logs from TrainJob"""
-        return self.backend.get_job_logs(name=name, follow=follow, step=step, node_rank=node_rank)
+    ) -> Iterator[str]:
+        """Get logs from a specific step of a TrainJob.
+
+        You can watch the logs in realtime as follows:
+        ```python
+        from kubeflow.trainer import TrainerClient
+
+        for logline in TrainerClient().get_job_logs(name="s8d44aa4fb6d", follow=True):
+            print(logline)
+        ```
+
+        Args:
+            name: Name of the TrainJob.
+            step: Step of the TrainJob to collect logs from, like dataset-initializer or node-0.
+            follow: Whether to stream logs in realtime as they are produced.
+
+        Returns:
+            Iterator of log lines.
+
+        Raises:
+            TimeoutError: Timeout to get a TrainJob.
+            RuntimeError: Failed to get a TrainJob.
+        """
+        return self.backend.get_job_logs(name=name, follow=follow, step=step)
 
     def wait_for_job_status(
         self,
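For short or already-finished jobs, the returned iterator can also be consumed eagerly instead of streamed. A minimal sketch, reusing the hypothetical job name from the docstring above; it mirrors the one-liner discussed in the review thread further down:

```python
from kubeflow.trainer import TrainerClient

# Join all log lines into a single string. Fine for small logs;
# for long-running jobs, prefer follow=True and iterate lazily.
print("\n".join(TrainerClient().get_job_logs(name="s8d44aa4fb6d")))
```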
7 changes: 3 additions & 4 deletions kubeflow/trainer/backends/base.py
@@ -14,7 +14,7 @@

 import abc
 
-from typing import Optional, Union
+from typing import Optional, Union, Iterator
 from kubeflow.trainer.constants import constants
 from kubeflow.trainer.types import types
 
@@ -47,9 +47,8 @@ def get_job_logs(
         self,
         name: str,
         follow: Optional[bool] = False,
-        step: str = constants.NODE,
-        node_rank: int = 0,
-    ) -> dict[str, str]:
+        step: str = constants.NODE + "-0",
+    ) -> Iterator[str]:
         raise NotImplementedError()
 
     def wait_for_job_status(
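Any concrete backend now satisfies this contract by yielding lines. A minimal, hypothetical in-memory backend sketch to illustrate the iterator contract (the class and its wiring are illustrative assumptions, not part of this PR):

```python
from typing import Iterator, Optional

from kubeflow.trainer.constants import constants


class InMemoryBackend:
    """Toy backend that replays pre-recorded logs, for illustration only."""

    def __init__(self, logs: dict[str, list[str]]):
        self.logs = logs

    def get_job_logs(
        self,
        name: str,
        follow: Optional[bool] = False,
        step: str = constants.NODE + "-0",
    ) -> Iterator[str]:
        # Yield one line at a time so callers can stream with a plain for-loop.
        yield from self.logs.get(name, [])
```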
94 changes: 25 additions & 69 deletions kubeflow/trainer/backends/kubernetes/backend.py
@@ -15,12 +15,12 @@
 import copy
 import logging
 import multiprocessing
-import queue
 import random
 import string
 import time
 import uuid
-from typing import Optional, Union
+from typing import Optional, Union, Iterator
+import re
 
 from kubeflow.trainer.constants import constants
 from kubeflow.trainer.types import types
@@ -173,7 +173,7 @@ def print_packages():
         )
 
         self.wait_for_job_status(job_name)
-        print(self.get_job_logs(job_name)["node-0"])
+        print(self.get_job_logs(job_name))
Contributor:
This will print a generator object, not logs, right?
I think we should instead do:

for line in self.get_job_logs(job_name):
    print(line, end="")

Contributor:
We also need to update other references, for example the README.

Contributor:
Would print(*self.get_job_logs(job_name), sep="\n") work?

Otherwise, would there be a way to override the string representation of the returned iterator?

Contributor:
Hmm, it could, but I think we might run out of memory :) And there would be no streaming, because we have to wait for the iterator to finish, which works against the purpose of an iterator.

Why not just use this?

for line in self.get_job_logs(job_name):
    print(line, end="")

Contributor:
I agree, and I prefer the streaming approach; it was really to get a one-liner :)

Member Author:
> Hmm, it could, but I think we might run out of memory :)

Why would we run out of memory? Since we just print the pip list + nvidia-smi output, the log would be small.

Users can do something like this if they don't want to write a loop:

print("\n".join(TrainerClient().get_job_logs(name=job_id)))

Contributor:
> Why would we run out of memory? Since we just print the pip list + nvidia-smi output, the log would be small.

Yeah, I meant when logs are large.

Contributor:
Alternatively, we could have two APIs -- one for streaming and another for returning a complete string, similar to what Ray does.

@andreyvelich @astefanutti any thoughts on that?

Member Author:
I would suggest that we consolidate this in the single get_job_logs() API, similar to how we consolidated BuiltinTrainer and CustomTrainer into the train() API. I don't see much value in separating them, since it is better to return Iterator[str] for both.

Contributor:
SGTM

         self.delete_job(job_name)
 
     def train(
@@ -328,92 +328,48 @@ def get_job_logs(
         self,
         name: str,
         follow: Optional[bool] = False,
-        step: str = constants.NODE,
-        node_rank: int = 0,
-    ) -> dict[str, str]:
-        """Get the logs from TrainJob"""
-
+        step: str = constants.NODE + "-0",
+    ) -> Iterator[str]:
+        """Get the TrainJob logs"""
Contributor:
Do we need this docstring here?

Member Author:
Not necessary, but it is a useful reminder of how this API is used, for developers and AI tools 🙂

Contributor:
I see, sounds good to me!
         # Get the TrainJob Pod name.
         pod_name = None
         for c in self.get_job(name).steps:
-            if c.status != constants.POD_PENDING:
-                if c.name == step or c.name == f"{step}-{node_rank}":
-                    pod_name = c.pod_name
-                    break
+            if c.status != constants.POD_PENDING and c.name == step:
+                pod_name = c.pod_name
+                break
         if pod_name is None:
-            return {}
-
-        # Dict where key is the Pod type and value is the Pod logs.
-        logs_dict = {}
-
-        # TODO (andreyvelich): Potentially, refactor this.
-        # Support logging of multiple Pods.
-        # TODO (andreyvelich): Currently, follow is supported only for node container.
-        if follow and step == constants.NODE:
-            log_streams = []
-            log_streams.append(
-                watch.Watch().stream(
+            return iter([])
Contributor:
Shall we raise RuntimeError or log a warning in this case?

Member Author:
Not sure we should raise an exception here, since at this stage the TrainJob has not yet produced logs, so we just return empty logs.
@astefanutti thoughts?

Contributor:
I see, makes sense to me to return an empty iterator then.

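One subtlety worth noting about the line above: because the rewritten method body contains yield, the whole function is a generator, so the value of return iter([]) is discarded during normal iteration; the call still correctly produces an empty iterator, and a bare return would behave identically. A self-contained sketch of that Python behavior:

```python
def gen():
    missing = True
    if missing:
        return iter([])  # the returned value is ignored; the generator just stops
    yield "never reached"

assert list(gen()) == []  # empty iterator, as intended
```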

+        try:
+            if follow:
+                log_stream = watch.Watch().stream(
                     self.core_api.read_namespaced_pod_log,
                     name=pod_name,
                     namespace=self.namespace,
-                    container=constants.NODE,
+                    container=re.sub(r"-\d+$", "", step),  # Remove the number for the node step.
Contributor:
nit: we call this twice; we could just compute it once and reuse it.
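A possible shape for that refactor, as a sketch only: compute the container name once before the follow branch and reuse it in both read paths.

```python
import re

step = "node-0"  # example step name; initializer steps are left unchanged
container = re.sub(r"-\d+$", "", step)  # -> "node", computed once and reused
```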

                     follow=True,
                 )
-            )
-            finished = [False] * len(log_streams)
-
-            # Create thread and queue per stream, for non-blocking iteration.
-            log_queue_pool = utils.get_log_queue_pool(log_streams)
-
-            # Iterate over every watching pods' log queue
-            while True:
-                for index, log_queue in enumerate(log_queue_pool):
-                    if all(finished):
+                # Stream logs incrementally
+                for logline in log_stream:
Contributor:
Are we sure each item is an entire line?

Member Author:
Yes, log_stream is a <generator object Watch.stream at 0x1073c6340> object, and we can return it line by line.

Contributor (@astefanutti, Sep 3, 2025):
I've checked, and it does indeed yield items line by line: https://github.com/kubernetes-client/python/blob/6e7c539f52dec4e993d2c32a4408920d8522f47e/kubernetes/base/watch/watch.py#L54-L83

I wasn't sure whether we had to do it ourselves or not.
+                    if logline is None:
Contributor:
Can this actually yield None?

Member Author:
I think we can just do this:

  if logline is None:
      return

Contributor:
Hmm, IIUC it never yields None: https://github.com/kubernetes-client/python/blob/master/kubernetes/base/watch/watch.py#L213-L216

We could just do:

for logline in log_stream:
    yield logline
+                        break
-                        break
-                    if finished[index]:
-                        continue
-                    # grouping the every 50 log lines of the same pod.
-                    for _ in range(50):
-                        try:
-                            logline = log_queue.get(timeout=1)
-                            if logline is None:
-                                finished[index] = True
-                                break
-                            # Print logs to the StdOut and update results dict.
-                            print(f"[{step}-{node_rank}]: {logline}")
-                            logs_dict[f"{step}-{node_rank}"] = (
-                                logs_dict.get(f"{step}-{node_rank}", "") + logline + "\n"
-                            )
-                        except queue.Empty:
-                            break
-                if all(finished):
-                    return logs_dict
-
-        try:
-            if step == constants.DATASET_INITIALIZER:
-                logs_dict[constants.DATASET_INITIALIZER] = self.core_api.read_namespaced_pod_log(
-                    name=pod_name,
-                    namespace=self.namespace,
-                    container=constants.DATASET_INITIALIZER,
-                )
-            elif step == constants.MODEL_INITIALIZER:
-                logs_dict[constants.MODEL_INITIALIZER] = self.core_api.read_namespaced_pod_log(
-                    name=pod_name,
-                    namespace=self.namespace,
-                    container=constants.MODEL_INITIALIZER,
-                )
+                    yield logline  # type:ignore
             else:
-                logs_dict[f"{step}-{node_rank}"] = self.core_api.read_namespaced_pod_log(
+                logs = self.core_api.read_namespaced_pod_log(
                     name=pod_name,
                     namespace=self.namespace,
-                    container=constants.NODE,
+                    container=re.sub(r"-\d+$", "", step),  # Remove the number for the node step.
                 )
+
+                for line in logs.splitlines():
+                    yield line
+
         except Exception as e:
             raise RuntimeError(
                 f"Failed to read logs for the pod {self.namespace}/{pod_name}"
             ) from e
-
-        return logs_dict

     def wait_for_job_status(
         self,
         name: str,
12 changes: 6 additions & 6 deletions kubeflow/trainer/backends/kubernetes/backend_test.py
@@ -917,14 +917,12 @@ def test_list_jobs(trainer_client, test_case):
name="valid flow with all defaults",
expected_status=SUCCESS,
config={"name": BASIC_TRAIN_JOB_NAME},
expected_output={
"node-0": "test log content",
},
expected_output=["test log content"],
),
TestCase(
name="runtime error when getting logs",
expected_status=FAILED,
config={"name": RUNTIME},
config={"name": BASIC_TRAIN_JOB_NAME, "namespace": FAIL_LOGS},
expected_error=RuntimeError,
),
],
@@ -933,10 +931,12 @@ def test_get_job_logs(trainer_client, test_case):
"""Test TrainerClient.get_job_logs with basic success path."""
print("Executing test:", test_case.name)
try:
trainer_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
Contributor:
This should be from backend, right?

Suggested change:
-        trainer_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)
+        trainer_client.backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE)

Member Author:
Actually, trainer_client here is the Kubernetes backend, not TrainerClient():

yield KubernetesBackend(KubernetesBackendConfig())

Let me rename it.

Contributor:
Ahh I see, thank you!

         logs = trainer_client.get_job_logs(test_case.config.get("name"))
+        # Convert iterator to list for comparison.
+        logs_list = list(logs)
         assert test_case.expected_status == SUCCESS
-        assert logs == test_case.expected_output
-
+        assert logs_list == test_case.expected_output
     except Exception as e:
         assert type(e) is test_case.expected_error
     print("test execution complete")
21 changes: 0 additions & 21 deletions kubeflow/trainer/utils/utils.py
@@ -14,9 +14,7 @@

 import inspect
 import os
-import queue
 import textwrap
-import threading
 from typing import Any, Callable, Optional
 from urllib.parse import urlparse
 
@@ -571,22 +569,3 @@ def get_model_initializer(
     )
 
     return model_initializer
-
-
-def wrap_log_stream(q: queue.Queue, log_stream: Any):
-    while True:
-        try:
-            logline = next(log_stream)
-            q.put(logline)
-        except StopIteration:
-            q.put(None)
-            return
-
-
-def get_log_queue_pool(log_streams: list[Any]) -> list[queue.Queue]:
-    pool = []
-    for log_stream in log_streams:
-        q = queue.Queue(maxsize=100)
-        pool.append(q)
-        threading.Thread(target=wrap_log_stream, args=(q, log_stream)).start()
-    return pool
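For context on this removal: these helpers fanned each watch stream into a bounded queue on its own thread so several pods could be polled without blocking. With get_job_logs now yielding from a single stream, none of that machinery is needed. A simplified sketch of the new flow, using a hypothetical stand-in for the stream object:

```python
from typing import Any, Iterator


def stream_logs(log_stream: Any) -> Iterator[str]:
    # A plain generator over one line-oriented stream:
    # no threads, queues, or per-pod bookkeeping.
    yield from log_stream


# Each line reaches the caller as soon as the stream produces it.
for line in stream_logs(iter(["line 1", "line 2"])):
    print(line)
```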