Skip to content

Commit 221eb14

Browse files
[ODSC-66852] Watch aqua FT job logs after creation (#1047)
2 parents 2cf5b51 + ef86c0e commit 221eb14

File tree

4 files changed

+61
-3
lines changed

4 files changed

+61
-3
lines changed

ads/aqua/finetuning/entities.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ class CreateFineTuningDetails(Serializable):
122122
The log group id for fine tuning job infrastructure.
123123
log_id: (str, optional). Defaults to `None`.
124124
The log id for fine tuning job infrastructure.
125+
watch_logs: (bool, optional). Defaults to `False`.
126+
The flag to watch the job run logs when a fine-tuning job is created.
125127
force_overwrite: (bool, optional). Defaults to `False`.
126128
Whether to force overwrite the existing file in object storage.
127129
freeform_tags: (dict, optional)
@@ -148,6 +150,7 @@ class CreateFineTuningDetails(Serializable):
148150
subnet_id: Optional[str] = None
149151
log_id: Optional[str] = None
150152
log_group_id: Optional[str] = None
153+
watch_logs: Optional[bool] = False
151154
force_overwrite: Optional[bool] = False
152155
freeform_tags: Optional[dict] = None
153156
defined_tags: Optional[dict] = None

ads/aqua/finetuning/finetuning.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import json
66
import os
7+
import time
8+
import traceback
79
from typing import Dict
810

911
from oci.data_science.models import (
@@ -149,6 +151,15 @@ def create(
149151
f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
150152
)
151153

154+
if create_fine_tuning_details.watch_logs and not (
155+
create_fine_tuning_details.log_id
156+
and create_fine_tuning_details.log_group_id
157+
):
158+
raise AquaValueError(
159+
"Logging is required for fine tuning if watch_logs is set to True. "
160+
"Please provide log_id and log_group_id with the request parameters."
161+
)
162+
152163
ft_parameters = self._get_finetuning_params(
153164
create_fine_tuning_details.ft_parameters
154165
)
@@ -422,6 +433,20 @@ def create(
422433
value=source.display_name,
423434
)
424435

436+
if create_fine_tuning_details.watch_logs:
437+
logger.info(
438+
f"Watching fine-tuning job run logs for {ft_job_run.id}. Press Ctrl+C to stop watching logs.\n"
439+
)
440+
try:
441+
ft_job_run.watch()
442+
except KeyboardInterrupt:
443+
logger.info(f"\nStopped watching logs for {ft_job_run.id}.\n")
444+
time.sleep(1)
445+
except Exception:
446+
logger.debug(
447+
f"Something unexpected occurred while watching logs.\n{traceback.format_exc()}"
448+
)
449+
425450
return AquaFineTuningSummary(
426451
id=ft_model.id,
427452
name=ft_model.display_name,

ads/cli.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#!/usr/bin/env python
2-
32
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

5+
import json
66
import logging
77
import sys
88
import traceback
99
import uuid
1010

1111
import fire
12+
from pydantic import BaseModel
1213

1314
from ads.common import logger
1415

@@ -84,7 +85,13 @@ def serialize(data):
8485
The string representation of each dataclass object.
8586
"""
8687
if isinstance(data, list):
87-
[print(str(item)) for item in data]
88+
for item in data:
89+
if isinstance(item, BaseModel):
90+
print(json.dumps(item.dict(), indent=4))
91+
else:
92+
print(str(item))
93+
elif isinstance(data, BaseModel):
94+
print(json.dumps(data.dict(), indent=4))
8895
else:
8996
print(str(data))
9097

tests/unitary/with_extras/aqua/test_finetuning.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
2323
from ads.aqua.finetuning.entities import AquaFineTuningParams
2424
from ads.jobs.ads_job import Job
25+
from ads.jobs.builders.infrastructure.dsc_job import DataScienceJobRun
2526
from ads.model.datascience_model import DataScienceModel
2627
from ads.model.model_metadata import ModelCustomMetadata
2728
from ads.aqua.common.errors import AquaValueError
@@ -49,6 +50,12 @@ def tearDownClass(cls):
4950
reload(ads.aqua)
5051
reload(ads.aqua.finetuning.finetuning)
5152

53+
@parameterized.expand(
54+
[
55+
("watch_logs", True),
56+
("watch_logs", False),
57+
]
58+
)
5259
@patch.object(Job, "run")
5360
@patch("ads.jobs.ads_job.Job.name", new_callable=PropertyMock)
5461
@patch("ads.jobs.ads_job.Job.id", new_callable=PropertyMock)
@@ -60,6 +67,8 @@ def tearDownClass(cls):
6067
@patch.object(AquaApp, "get_source")
6168
def test_create_fine_tuning(
6269
self,
70+
mock_watch_logs,
71+
mock_watch_logs_called,
6372
mock_get_source,
6473
mock_mvs_create,
6574
mock_ds_model_create,
@@ -117,6 +126,7 @@ def test_create_fine_tuning(
117126
ft_job_run.id = "test_ft_job_run_id"
118127
ft_job_run.lifecycle_details = "Job run artifact execution in progress."
119128
ft_job_run.lifecycle_state = "IN_PROGRESS"
129+
ft_job_run.watch = MagicMock()
120130
mock_job_run.return_value = ft_job_run
121131

122132
self.app.ds_client.update_model = MagicMock()
@@ -144,7 +154,20 @@ def test_create_fine_tuning(
144154
defined_tags=ft_model_defined_tags,
145155
)
146156

147-
aqua_ft_summary = self.app.create(**create_aqua_ft_details)
157+
inputs = {
158+
**create_aqua_ft_details,
159+
**{
160+
mock_watch_logs: mock_watch_logs_called,
161+
"log_id": "test_log_id",
162+
"log_group_id": "test_log_group_id",
163+
},
164+
}
165+
aqua_ft_summary = self.app.create(**inputs)
166+
167+
if mock_watch_logs_called:
168+
ft_job_run.watch.assert_called()
169+
else:
170+
ft_job_run.watch.assert_not_called()
148171

149172
assert aqua_ft_summary.to_dict() == {
150173
"console_url": f"https://cloud.oracle.com/data-science/models/{ft_model.id}?region={self.app.region}",

0 commit comments

Comments
 (0)