Skip to content

Commit 79a4064

Browse files
Merge pull request #1631 from basetenlabs/bump-version-0.9.88
Release 0.9.88
2 parents 17925a3 + b654774 commit 79a4064

30 files changed

+2750
-1262
lines changed

docs/examples/frameworks/pytorch.mdx

Lines changed: 0 additions & 4 deletions
This file was deleted.

docs/examples/frameworks/sklearn.mdx

Lines changed: 0 additions & 4 deletions
This file was deleted.

docs/examples/frameworks/tensorflow.mdx

Lines changed: 0 additions & 4 deletions
This file was deleted.

docs/examples/frameworks/xgboost.mdx

Lines changed: 0 additions & 4 deletions
This file was deleted.

poetry.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.9.87"
3+
version = "0.9.88"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"
@@ -102,6 +102,7 @@ rich = { version = "^13.4.2", optional = false }
102102
rich-click = { version = "^1.6.1", optional = false }
103103
ruff = { version = ">=0.4.8", optional = false } # Not a dev dep, needed for chains code gen.
104104
tenacity = { version = "^8.0.1", optional = false }
105+
tomlkit = { version = ">=0.13.2", optional = false }
105106
watchfiles = { version = "^0.19.0", optional = false }
106107

107108

@@ -127,6 +128,7 @@ rich = { components = "other" }
127128
rich-click = { components = "other" }
128129
ruff = { components = "other" }
129130
tenacity = { components = "other" }
131+
tomlkit = { components = "other" }
130132
watchfiles = { components = "other" }
131133

132134
[tool.poetry.group.dev.dependencies]
@@ -146,7 +148,6 @@ pytest-check = "^2.4.1"
146148
pytest-cov = "^3.0.0"
147149
pytest-split = ">=0.10.0"
148150
requests-mock = ">=1.11.0"
149-
tomlkit = ">=0.12"
150151
types-PyYAML = "^6.0.12.12"
151152
types-aiofiles = ">=24.1.0"
152153
types-requests = "==2.31.0.2"

truss-chains/truss_chains/deployment/deployment_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from truss.remote.baseten import remote as b10_remote
3333
from truss.remote.baseten import service as b10_service
3434
from truss.truss_handle import truss_handle
35-
from truss.util import log_utils
35+
from truss.util import log_utils, user_config
3636
from truss.util import path as truss_path
3737
from truss_chains import framework, private_types, public_types
3838
from truss_chains.deployment import code_gen
@@ -472,7 +472,7 @@ def _create_baseten_chain(
472472
remote_factory.RemoteFactory.create(remote=baseten_options.remote),
473473
)
474474

475-
if remote_provider.include_git_info or baseten_options.include_git_info:
475+
if user_config.settings.include_git_info or baseten_options.include_git_info:
476476
truss_user_env = b10_types.TrussUserEnv.collect_with_git_info(
477477
baseten_options.working_dir
478478
)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from truss.base import truss_config
2+
from truss_train import definitions
3+
4+
deploy_checkpoint = definitions.DeployCheckpointsConfig(
5+
compute=definitions.Compute(
6+
accelerator=truss_config.AcceleratorSpec(
7+
accelerator=truss_config.Accelerator.A10G, count=4
8+
)
9+
),
10+
runtime=definitions.DeployCheckpointsRuntime(
11+
environment_variables={
12+
"HF_TOKEN": definitions.SecretReference(name="hf_access_token")
13+
}
14+
),
15+
checkpoint_details=definitions.CheckpointDetails(
16+
base_model_id="unsloth/gemma-3-1b-it",
17+
checkpoints=[
18+
definitions.Checkpoint(
19+
id="checkpoint-24", name="checkpoint-24", training_job_id="lqz4pw4"
20+
),
21+
definitions.Checkpoint(
22+
id="checkpoint-42", name="checkpoint-42", training_job_id="lqz4pw4"
23+
),
24+
],
25+
),
26+
)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from truss.base import truss_config
2+
from truss_train import definitions
3+
4+
deploy_checkpoint = definitions.DeployCheckpointsConfig(
5+
compute=definitions.Compute(
6+
accelerator=truss_config.AcceleratorSpec(
7+
accelerator=truss_config.Accelerator.A10G, count=4
8+
)
9+
),
10+
runtime=definitions.DeployCheckpointsRuntime(
11+
environment_variables={
12+
"HF_TOKEN": definitions.SecretReference(name="hf_access_token")
13+
}
14+
),
15+
checkpoint_details=definitions.CheckpointDetails(
16+
base_model_id="unsloth/gemma-3-1b-it",
17+
checkpoints=[
18+
definitions.Checkpoint(
19+
id="checkpoint-24", name="checkpoint-24", training_job_id="lqz4pw4"
20+
),
21+
definitions.Checkpoint(
22+
id="checkpoint-42", name="checkpoint-42", training_job_id="lqz4pw4"
23+
),
24+
],
25+
),
26+
)
27+
runtime_config = definitions.Runtime(
28+
start_commands=["/bin/bash ./my-entrypoint.sh"],
29+
environment_variables={
30+
"FOO_VAR": "FOO_VAL",
31+
"BAR_VAR": definitions.SecretReference(name="BAR_SECRET"),
32+
},
33+
)
34+
35+
training_job = definitions.TrainingJob(
36+
image=definitions.Image(base_image="base-image"),
37+
compute=definitions.Compute(node_count=1, cpu_count=4),
38+
runtime_config=runtime_config,
39+
)
40+
41+
first_project = definitions.TrainingProject(name="first-project", job=training_job)

truss-train/tests/test_loader.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,21 @@ def test_import_requires_training_project():
1111
job_src = TEST_ROOT / "import" / "config_without_training_project.py"
1212
match = r"No `.+` was found."
1313
with pytest.raises(ValueError, match=match):
14-
with loader.import_target(job_src):
14+
with loader.import_training_project(job_src):
1515
pass
1616

1717

1818
def test_import_requires_single_training_project():
1919
job_src = TEST_ROOT / "import" / "config_with_multiple_training_projects.py"
2020
match = r"Multiple `.+`s were found."
2121
with pytest.raises(ValueError, match=match):
22-
with loader.import_target(job_src):
22+
with loader.import_training_project(job_src):
2323
pass
2424

2525

2626
def test_import_with_single_training_project():
2727
job_src = TEST_ROOT / "import" / "config_with_single_training_project.py"
28-
with loader.import_target(job_src) as training_project:
28+
with loader.import_training_project(job_src) as training_project:
2929
assert training_project.name == "first-project"
3030
assert training_project.job.compute.cpu_count == 4
3131

@@ -34,5 +34,44 @@ def test_import_directory_fails():
3434
job_src = TEST_ROOT / "import"
3535
match = r"You must point to a python file"
3636
with pytest.raises(ImportError, match=match):
37-
with loader.import_target(job_src):
37+
with loader.import_training_project(job_src):
3838
pass
39+
40+
41+
def test_import_deploy_checkpoints_config():
42+
job_src = TEST_ROOT / "import" / "deploy_checkpoints_config.py"
43+
with loader.import_deploy_checkpoints_config(job_src) as deploy_checkpoints_config:
44+
assert len(deploy_checkpoints_config.checkpoint_details.checkpoints) == 2
45+
assert (
46+
deploy_checkpoints_config.checkpoint_details.base_model_id
47+
== "unsloth/gemma-3-1b-it"
48+
)
49+
assert (
50+
deploy_checkpoints_config.checkpoint_details.checkpoints[0].id
51+
== "checkpoint-24"
52+
)
53+
assert (
54+
deploy_checkpoints_config.checkpoint_details.checkpoints[1].id
55+
== "checkpoint-42"
56+
)
57+
58+
59+
def test_import_handles_training_project_with_deploy_checkpoints_config():
60+
job_src = TEST_ROOT / "import" / "project_with_deploy_checkpoints_config.py"
61+
with loader.import_training_project(job_src) as training_project:
62+
assert training_project.name == "first-project"
63+
assert training_project.job.compute.cpu_count == 4
64+
with loader.import_deploy_checkpoints_config(job_src) as deploy_checkpoints_config:
65+
assert len(deploy_checkpoints_config.checkpoint_details.checkpoints) == 2
66+
assert (
67+
deploy_checkpoints_config.checkpoint_details.base_model_id
68+
== "unsloth/gemma-3-1b-it"
69+
)
70+
assert (
71+
deploy_checkpoints_config.checkpoint_details.checkpoints[0].id
72+
== "checkpoint-24"
73+
)
74+
assert (
75+
deploy_checkpoints_config.checkpoint_details.checkpoints[1].id
76+
== "checkpoint-42"
77+
)

truss-train/truss_train/definitions.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from truss.base import custom_types, truss_config
66

7+
DEFAULT_LORA_RANK = 16
8+
79

810
class SecretReference(custom_types.SafeModel):
911
name: str
@@ -24,6 +26,18 @@ def model_dump(self, *args, **kwargs):
2426
}
2527
return data
2628

29+
def to_truss_config(self) -> truss_config.Resources:
30+
if self.accelerator:
31+
return truss_config.Resources(
32+
cpu=str(self.cpu_count),
33+
memory=self.memory,
34+
accelerator=self.accelerator,
35+
node_count=self.node_count,
36+
)
37+
return truss_config.Resources(
38+
cpu=str(self.cpu_count), memory=self.memory, node_count=self.node_count
39+
)
40+
2741

2842
class CheckpointingConfig(custom_types.SafeModel):
2943
enabled: bool = False
@@ -57,3 +71,41 @@ class TrainingProject(custom_types.SafeModel):
5771
# TrainingProject is the wrapper around project config and job config. However, we exclude job
5872
# in serialization so just TrainingProject metadata is included in API requests.
5973
job: TrainingJob = pydantic.Field(exclude=True)
74+
75+
76+
class Checkpoint(custom_types.SafeModel):
77+
training_job_id: str
78+
id: str
79+
name: str
80+
lora_rank: Optional[int] = (
81+
None # lora rank will be fetched through the API if available.
82+
)
83+
84+
def to_truss_config(self) -> truss_config.Checkpoint:
85+
return truss_config.Checkpoint(
86+
id=f"{self.training_job_id}/{self.id}", name=self.id
87+
)
88+
89+
90+
class CheckpointDetails(custom_types.SafeModel):
91+
download_folder: str = truss_config.DEFAULT_TRAINING_CHECKPOINT_FOLDER
92+
base_model_id: Optional[str] = None
93+
checkpoints: List[Checkpoint] = []
94+
95+
def to_truss_config(self) -> truss_config.CheckpointConfiguration:
96+
checkpoints = [checkpoint.to_truss_config() for checkpoint in self.checkpoints]
97+
return truss_config.CheckpointConfiguration(
98+
checkpoints=checkpoints, download_folder=self.download_folder
99+
)
100+
101+
102+
class DeployCheckpointsRuntime(custom_types.SafeModel):
103+
environment_variables: Dict[str, Union[str, SecretReference]] = {}
104+
105+
106+
class DeployCheckpointsConfig(custom_types.SafeModel):
107+
checkpoint_details: Optional[CheckpointDetails] = None
108+
model_name: Optional[str] = None
109+
deployment_name: Optional[str] = None
110+
runtime: Optional[DeployCheckpointsRuntime] = None
111+
compute: Optional[Compute] = None

truss-train/truss_train/loader.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,31 @@
22
import importlib.util
33
import os
44
import pathlib
5-
from typing import Iterator
5+
from typing import Iterator, Type, TypeVar
66

77
from truss_train import definitions
88

9+
T = TypeVar("T")
10+
11+
12+
@contextlib.contextmanager
13+
def import_training_project(
14+
module_path: pathlib.Path,
15+
) -> Iterator[definitions.TrainingProject]:
16+
with import_target(module_path, definitions.TrainingProject) as project:
17+
yield project
18+
19+
20+
@contextlib.contextmanager
21+
def import_deploy_checkpoints_config(
22+
module_path: pathlib.Path,
23+
) -> Iterator[definitions.DeployCheckpointsConfig]:
24+
with import_target(module_path, definitions.DeployCheckpointsConfig) as config:
25+
yield config
26+
927

1028
@contextlib.contextmanager
11-
def import_target(module_path: pathlib.Path) -> Iterator[definitions.TrainingProject]:
29+
def import_target(module_path: pathlib.Path, target_type: Type[T]) -> Iterator[T]:
1230
module_name = module_path.stem
1331
if not os.path.isfile(module_path):
1432
raise ImportError(
@@ -24,13 +42,11 @@ def import_target(module_path: pathlib.Path) -> Iterator[definitions.TrainingPro
2442
spec.loader.exec_module(module)
2543

2644
module_vars = (getattr(module, name) for name in dir(module))
27-
training_projects = [
28-
sym for sym in module_vars if isinstance(sym, definitions.TrainingProject)
29-
]
45+
target = [sym for sym in module_vars if isinstance(sym, target_type)]
3046

31-
if len(training_projects) == 0:
32-
raise ValueError(f"No `{definitions.TrainingProject}` was found.")
33-
elif len(training_projects) > 1:
34-
raise ValueError(f"Multiple `{definitions.TrainingProject}`s were found.")
47+
if len(target) == 0:
48+
raise ValueError(f"No `{target_type}` was found.")
49+
elif len(target) > 1:
50+
raise ValueError(f"Multiple `{target_type}`s were found.")
3551

36-
yield training_projects[0]
52+
yield target[0]

truss/base/constants.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
import pathlib
2-
from typing import Set
32

4-
SKLEARN = "sklearn"
5-
TENSORFLOW = "tensorflow"
6-
KERAS = "keras"
7-
XGBOOST = "xgboost"
8-
PYTORCH = "pytorch"
93
CUSTOM = "custom"
10-
HUGGINGFACE_TRANSFORMER = "huggingface_transformer"
11-
LIGHTGBM = "lightgbm"
12-
134

145
_TRUSS_ROOT = pathlib.Path(__file__).parent.parent.resolve()
156

@@ -74,7 +65,6 @@
7465
TRUSS_DIR = "truss_dir"
7566
TRUSS_HASH = "truss_hash"
7667

77-
HUGGINGFACE_TRANSFORMER_MODULE_NAME: Set[str] = set({})
7868

7969
INFERENCE_SERVER_PORT = 8080
8070

@@ -84,7 +74,7 @@
8474

8575
TRTLLM_SPEC_DEC_TARGET_MODEL_NAME = "target"
8676
TRTLLM_SPEC_DEC_DRAFT_MODEL_NAME = "draft"
87-
TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.18.1-cd81637"
77+
TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.18.1-1856fb5"
8878
# TODO: build the image so that the default path `python3` can be used - then remove here.
8979
TRTLLM_PYTHON_EXECUTABLE = "/usr/local/briton/venv/bin/python"
9080
BEI_TRTLLM_BASE_IMAGE = "baseten/bei:0.0.23"

0 commit comments

Comments
 (0)