Skip to content

Commit 07586c8

Browse files
Merge pull request #1646 from basetenlabs/bump-version-0.9.91
Release 0.9.91
2 parents 4f159ef + 89d7637 commit 07586c8

File tree

11 files changed

+57
-51
lines changed

11 files changed

+57
-51
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.9.90"
3+
version = "0.9.91"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss-train/tests/import/deploy_checkpoints_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"HF_TOKEN": definitions.SecretReference(name="hf_access_token")
1313
}
1414
),
15-
checkpoint_details=definitions.CheckpointDetails(
15+
checkpoint_details=definitions.CheckpointList(
1616
base_model_id="unsloth/gemma-3-1b-it",
1717
checkpoints=[
1818
definitions.Checkpoint(

truss-train/tests/import/project_with_deploy_checkpoints_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"HF_TOKEN": definitions.SecretReference(name="hf_access_token")
1313
}
1414
),
15-
checkpoint_details=definitions.CheckpointDetails(
15+
checkpoint_details=definitions.CheckpointList(
1616
base_model_id="unsloth/gemma-3-1b-it",
1717
checkpoints=[
1818
definitions.Checkpoint(

truss-train/truss_train/definitions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,14 @@ def to_truss_config(self) -> truss_config.Checkpoint:
8787
)
8888

8989

90-
class CheckpointDetails(custom_types.SafeModel):
90+
class CheckpointList(custom_types.SafeModel):
9191
download_folder: str = truss_config.DEFAULT_TRAINING_CHECKPOINT_FOLDER
9292
base_model_id: Optional[str] = None
9393
checkpoints: List[Checkpoint] = []
9494

95-
def to_truss_config(self) -> truss_config.CheckpointConfiguration:
95+
def to_truss_config(self) -> truss_config.CheckpointList:
9696
checkpoints = [checkpoint.to_truss_config() for checkpoint in self.checkpoints]
97-
return truss_config.CheckpointConfiguration(
97+
return truss_config.CheckpointList(
9898
checkpoints=checkpoints, download_folder=self.download_folder
9999
)
100100

@@ -104,7 +104,7 @@ class DeployCheckpointsRuntime(custom_types.SafeModel):
104104

105105

106106
class DeployCheckpointsConfig(custom_types.SafeModel):
107-
checkpoint_details: Optional[CheckpointDetails] = None
107+
checkpoint_details: Optional[CheckpointList] = None
108108
model_name: Optional[str] = None
109109
deployment_name: Optional[str] = None
110110
runtime: Optional[DeployCheckpointsRuntime] = None

truss/base/truss_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ class Checkpoint(custom_types.ConfigModel):
545545
name: str
546546

547547

548-
class CheckpointConfiguration(custom_types.ConfigModel):
548+
class CheckpointList(custom_types.ConfigModel):
549549
download_folder: str = DEFAULT_TRAINING_CHECKPOINT_FOLDER
550550
checkpoints: list[Checkpoint] = pydantic.Field(default_factory=list)
551551

@@ -586,7 +586,7 @@ class TrussConfig(custom_types.ConfigModel):
586586
trt_llm: Optional[trt_llm_config.TRTLLMConfiguration] = None
587587

588588
# deploying from checkpoint
589-
training_checkpoints: Optional[CheckpointConfiguration] = None
589+
training_checkpoints: Optional[CheckpointList] = None
590590

591591
# Internal / Legacy.
592592
input_type: str = "Any"

truss/cli/cli.py

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,6 @@ def cleanup() -> None:
361361
@click.option("-n", "--name", type=click.STRING)
362362
@click.option(
363363
"--python-config/--no-python-config",
364-
type=bool,
365364
default=False,
366365
help="Uses the code first tooling to build models.",
367366
)
@@ -453,7 +452,6 @@ def _extract_request_data(data: Optional[str], file: Optional[Path]):
453452
)
454453
@click.option(
455454
"--published",
456-
type=bool,
457455
is_flag=True,
458456
required=False,
459457
default=False,
@@ -580,7 +578,6 @@ def run_python(script, target_directory):
580578
@click.option("--model-name", type=str, required=False, help="Name of the model")
581579
@click.option(
582580
"--publish",
583-
type=bool,
584581
is_flag=True,
585582
required=False,
586583
default=False,
@@ -592,7 +589,6 @@ def run_python(script, target_directory):
592589
)
593590
@click.option(
594591
"--promote",
595-
type=bool,
596592
is_flag=True,
597593
required=False,
598594
default=False,
@@ -613,7 +609,6 @@ def run_python(script, target_directory):
613609
)
614610
@click.option(
615611
"--preserve-previous-production-deployment",
616-
type=bool,
617612
is_flag=True,
618613
required=False,
619614
default=False,
@@ -625,15 +620,13 @@ def run_python(script, target_directory):
625620
)
626621
@click.option(
627622
"--trusted",
628-
type=bool,
629623
is_flag=True,
630624
required=False,
631625
default=None,
632626
help="[DEPRECATED] All models are trusted by default.",
633627
)
634628
@click.option(
635629
"--disable-truss-download",
636-
type=bool,
637630
is_flag=True,
638631
required=False,
639632
default=False,
@@ -650,7 +643,6 @@ def run_python(script, target_directory):
650643
)
651644
@click.option(
652645
"--wait/--no-wait",
653-
type=bool,
654646
is_flag=True,
655647
required=False,
656648
default=False,
@@ -667,16 +659,14 @@ def run_python(script, target_directory):
667659
)
668660
@click.option(
669661
"--include-git-info",
670-
type=bool,
671662
is_flag=True,
672663
required=False,
673664
default=False,
674665
help=_INCLUDE_GIT_INFO_DOC,
675666
)
676-
@click.option("--tail", type=bool, is_flag=True)
667+
@click.option("--tail", is_flag=True)
677668
@click.option(
678669
"--preserve-env-instance-type/--no-preserve-env-instance-type",
679-
type=bool,
680670
is_flag=True,
681671
required=False,
682672
default=None,
@@ -879,7 +869,7 @@ def push(
879869
@click.option("--remote", type=str, required=False)
880870
@click.option("--model-id", type=str, required=True)
881871
@click.option("--deployment-id", type=str, required=True)
882-
@click.option("--tail", type=bool, is_flag=True, help="Tail for ongoing logs.")
872+
@click.option("--tail", is_flag=True, help="Tail for ongoing logs.")
883873
@common_options()
884874
def model_logs(
885875
remote: Optional[str], model_id: str, deployment_id: str, tail: bool = False
@@ -1212,13 +1202,11 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
12121202
)
12131203
@click.option(
12141204
"--publish/--no-publish",
1215-
type=bool,
12161205
default=False,
12171206
help="Create chainlets as published deployments.",
12181207
)
12191208
@click.option(
12201209
"--promote/--no-promote",
1221-
type=bool,
12221210
default=False,
12231211
help="Replace production chainlets with newly deployed chainlets.",
12241212
)
@@ -1233,13 +1221,11 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
12331221
)
12341222
@click.option(
12351223
"--wait/--no-wait",
1236-
type=bool,
12371224
default=True,
12381225
help="Wait until all chainlets are ready (or deployment failed).",
12391226
)
12401227
@click.option(
12411228
"--watch/--no-watch",
1242-
type=bool,
12431229
default=False,
12441230
help=(
12451231
"Watches the chains source code and applies live patches. Using this option "
@@ -1250,7 +1236,6 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
12501236
)
12511237
@click.option(
12521238
"--dryrun",
1253-
type=bool,
12541239
default=False,
12551240
is_flag=True,
12561241
help="Produces only generated files, but doesn't deploy anything.",
@@ -1274,7 +1259,6 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
12741259
)
12751260
@click.option(
12761261
"--include-git-info",
1277-
type=bool,
12781262
is_flag=True,
12791263
required=False,
12801264
default=False,
@@ -1554,9 +1538,7 @@ def train():
15541538
@train.command(name="push")
15551539
@click.argument("config", type=Path, required=True)
15561540
@click.option("--remote", type=str, required=False, help="Remote to use")
1557-
@click.option(
1558-
"--tail", type=bool, is_flag=True, help="Tail for status + logs after push."
1559-
)
1541+
@click.option("--tail", is_flag=True, help="Tail for status + logs after push.")
15601542
@common_options()
15611543
def push_training_job(config: Path, remote: Optional[str], tail: bool):
15621544
"""Run a training job"""
@@ -1600,7 +1582,7 @@ def push_training_job(config: Path, remote: Optional[str], tail: bool):
16001582
@click.option("--remote", type=str, required=False, help="Remote to use")
16011583
@click.option("--project-id", type=str, required=False, help="Project ID.")
16021584
@click.option("--job-id", type=str, required=False, help="Job ID.")
1603-
@click.option("--tail", type=bool, is_flag=True, help="Tail for ongoing logs.")
1585+
@click.option("--tail", is_flag=True, help="Tail for ongoing logs.")
16041586
@common_options()
16051587
def get_job_logs(
16061588
remote: Optional[str], project_id: Optional[str], job_id: Optional[str], tail: bool
@@ -1632,7 +1614,7 @@ def get_job_logs(
16321614
@train.command(name="stop")
16331615
@click.option("--project-id", type=str, required=False, help="Project ID.")
16341616
@click.option("--job-id", type=str, required=False, help="Job ID.")
1635-
@click.option("--all", type=bool, is_flag=True, help="Stop all running jobs.")
1617+
@click.option("--all", is_flag=True, help="Stop all running jobs.")
16361618
@click.option("--remote", type=str, required=False, help="Remote to use")
16371619
@common_options()
16381620
def stop_job(
@@ -1708,10 +1690,7 @@ def get_job_metrics(
17081690
help="path to a python file that defines a DeployCheckpointsConfig",
17091691
)
17101692
@click.option(
1711-
"--dry-run",
1712-
type=bool,
1713-
is_flag=True,
1714-
help="Generate a truss config without deploying",
1693+
"--dry-run", is_flag=True, help="Generate a truss config without deploying"
17151694
)
17161695
@click.option("--remote", type=str, required=False, help="Remote to use")
17171696
@common_options()

truss/cli/train/deploy_checkpoints.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from truss_train.definitions import (
2222
DEFAULT_LORA_RANK,
2323
Checkpoint,
24-
CheckpointDetails,
24+
CheckpointList,
2525
Compute,
2626
DeployCheckpointsConfig,
2727
DeployCheckpointsRuntime,
@@ -156,10 +156,10 @@ def _render_vllm_lora_truss_config(
156156
def _get_checkpoint_details(
157157
console: Console,
158158
remote_provider: BasetenRemote,
159-
checkpoint_details: Optional[CheckpointDetails],
159+
checkpoint_details: Optional[CheckpointList],
160160
project_id: Optional[str],
161161
job_id: Optional[str],
162-
) -> CheckpointDetails:
162+
) -> CheckpointList:
163163
if checkpoint_details and checkpoint_details.checkpoints:
164164
return _process_user_provided_checkpoints(checkpoint_details, remote_provider)
165165
else:
@@ -171,16 +171,16 @@ def _get_checkpoint_details(
171171
def _prompt_user_for_checkpoint_details(
172172
console: Console,
173173
remote_provider: BasetenRemote,
174-
checkpoint_details: Optional[CheckpointDetails],
174+
checkpoint_details: Optional[CheckpointList],
175175
project_id: Optional[str],
176176
job_id: Optional[str],
177-
) -> CheckpointDetails:
177+
) -> CheckpointList:
178178
project_id, job_id = get_most_recent_job(
179179
console, remote_provider, project_id, job_id
180180
)
181181
response_checkpoints = _fetch_checkpoints(remote_provider, project_id, job_id)
182182
if not checkpoint_details:
183-
checkpoint_details = CheckpointDetails()
183+
checkpoint_details = CheckpointList()
184184

185185
# first, gather all checkpoint ids the user wants to deploy
186186
# allow the user to select a checkpoint
@@ -196,8 +196,8 @@ def _prompt_user_for_checkpoint_details(
196196

197197

198198
def _process_user_provided_checkpoints(
199-
checkpoint_details: CheckpointDetails, remote_provider: BasetenRemote
200-
) -> CheckpointDetails:
199+
checkpoint_details: CheckpointList, remote_provider: BasetenRemote
200+
) -> CheckpointList:
201201
# check if the user-provided checkpoint details are valid. Fill in missing values.
202202
checkpoints_by_training_job_id = {}
203203
for checkpoint in checkpoint_details.checkpoints:
@@ -342,9 +342,10 @@ def _get_runtime(
342342
hf_secret_name = _get_hf_secret_name(
343343
console, runtime.environment_variables.get(HF_TOKEN_ENVVAR_NAME)
344344
)
345-
runtime.environment_variables[HF_TOKEN_ENVVAR_NAME] = SecretReference(
346-
name=hf_secret_name
347-
)
345+
if hf_secret_name:
346+
runtime.environment_variables[HF_TOKEN_ENVVAR_NAME] = SecretReference(
347+
name=hf_secret_name
348+
)
348349
return runtime
349350

350351

truss/cli/train/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Optional
44

55
from truss_train.definitions import (
6-
CheckpointDetails,
6+
CheckpointList,
77
Compute,
88
DeployCheckpointsConfig,
99
DeployCheckpointsRuntime,
@@ -23,7 +23,7 @@ class DeployCheckpointsConfigComplete(DeployCheckpointsConfig):
2323
removes the optional fileds. This helps provide type safety internal handling.
2424
"""
2525

26-
checkpoint_details: CheckpointDetails
26+
checkpoint_details: CheckpointList
2727
model_name: str
2828
deployment_name: str
2929
runtime: DeployCheckpointsRuntime

truss/remote/baseten/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def _post_graphql_query(self, query: str, variables: Optional[dict] = None) -> d
109109
if errors:
110110
message = errors[0]["message"]
111111
error_code = errors[0].get("extensions", {}).get("code")
112+
if errors[0].get("extensions", {}).get("description") is not None:
113+
message = errors[0].get("extensions", {}).get("description")
112114
raise ApiError(message, error_code)
113115

114116
return resp_dict

truss/tests/cli/train/test_deploy_checkpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def deploy_checkpoints_mock_checkbox(create_mock_prompt):
8383

8484
def test_render_vllm_lora_truss_config():
8585
deploy_config = DeployCheckpointsConfigComplete(
86-
checkpoint_details=definitions.CheckpointDetails(
86+
checkpoint_details=definitions.CheckpointList(
8787
checkpoints=[
8888
definitions.Checkpoint(
8989
id="checkpoint-1",
@@ -181,7 +181,7 @@ def test_prepare_checkpoint_deploy_complete_config(
181181
):
182182
# Create complete config with all fields specified
183183
complete_config = definitions.DeployCheckpointsConfig(
184-
checkpoint_details=definitions.CheckpointDetails(
184+
checkpoint_details=definitions.CheckpointList(
185185
checkpoints=[
186186
definitions.Checkpoint(
187187
id="checkpoint-1",

truss/tests/remote/baseten/test_api.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ def mock_graphql_error_response():
3434
return response
3535

3636

37+
def mock_graphql_error_response_with_description():
38+
response = Response()
39+
response.status_code = 200
40+
response.json = mock.Mock(
41+
return_value={
42+
"errors": [
43+
{"message": "error", "extensions": {"description": "descriptive_error"}}
44+
]
45+
}
46+
)
47+
return response
48+
49+
3750
def mock_unsuccessful_response():
3851
response = Response()
3952
response.status_code = 400
@@ -141,6 +154,17 @@ def test_post_graphql_query_error(mock_post, baseten_api):
141154
baseten_api._post_graphql_query("sample_query_string")
142155

143156

157+
@mock.patch(
158+
"requests.post", return_value=mock_graphql_error_response_with_description()
159+
)
160+
def test_post_graphql_query_error_with_description(mock_post, baseten_api):
161+
with pytest.raises(ApiError) as exc_info:
162+
baseten_api._post_graphql_query("sample_query_string")
163+
164+
exception = exc_info.value
165+
assert str(exception) == "descriptive_error"
166+
167+
144168
@mock.patch("requests.post", return_value=mock_unsuccessful_response())
145169
def test_post_requests_error(mock_post, baseten_api):
146170
with pytest.raises(requests.exceptions.HTTPError):

0 commit comments

Comments
 (0)