Skip to content

Commit 8c28b26

Browse files
Merge pull request #1651 from basetenlabs/bump-version-0.9.92
Release 0.9.92
2 parents 07586c8 + 07252fd commit 8c28b26

File tree

10 files changed

+104
-183
lines changed

10 files changed

+104
-183
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.9.91"
3+
version = "0.9.92"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/base/trt_llm_config.py

Lines changed: 35 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import warnings
66
from enum import Enum
7-
from typing import TYPE_CHECKING, Annotated, Any, Dict, Literal, Optional
7+
from typing import TYPE_CHECKING, Annotated, Dict, Literal, Optional
88

99
from huggingface_hub.errors import HFValidationError
1010
from huggingface_hub.utils import validate_repo_id
@@ -26,6 +26,13 @@
2626
ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION = (
2727
os.environ.get("ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION", "False") == "True"
2828
)
29+
try:
30+
from truss.base import custom_types
31+
32+
PydanticTrTBaseModel = custom_types.ConfigModel
33+
except ImportError:
34+
# fallback for briton
35+
PydanticTrTBaseModel = BaseModel # type: ignore[assignment,misc]
2936

3037

3138
class TrussTRTLLMModel(str, Enum):
@@ -54,13 +61,13 @@ class TrussTRTLLMQuantizationType(str, Enum):
5461
FP4_KV = "fp4_kv"
5562

5663

57-
class TrussTRTLLMPluginConfiguration(BaseModel):
64+
class TrussTRTLLMPluginConfiguration(PydanticTrTBaseModel):
5865
paged_kv_cache: bool = True
5966
use_paged_context_fmha: bool = True
6067
use_fp8_context_fmha: bool = False
6168

6269

63-
class TrussTRTQuantizationConfiguration(BaseModel):
70+
class TrussTRTQuantizationConfiguration(PydanticTrTBaseModel):
6471
"""Configuration for quantization of TRT models
6572
6673
Args:
@@ -96,7 +103,7 @@ class CheckpointSource(str, Enum):
96103
REMOTE_URL = "REMOTE_URL"
97104

98105

99-
class CheckpointRepository(BaseModel):
106+
class CheckpointRepository(PydanticTrTBaseModel):
100107
source: CheckpointSource
101108
repo: str
102109
revision: Optional[str] = None
@@ -125,7 +132,7 @@ class TrussSpecDecMode(str, Enum):
125132
LOOKAHEAD_DECODING = "LOOKAHEAD_DECODING"
126133

127134

128-
class TrussTRTLLMRuntimeConfiguration(BaseModel):
135+
class TrussTRTLLMRuntimeConfiguration(PydanticTrTBaseModel):
129136
kv_cache_free_gpu_mem_fraction: float = 0.9
130137
kv_cache_host_memory_bytes: Optional[Annotated[int, Field(strict=True, ge=1)]] = (
131138
None
@@ -144,12 +151,12 @@ class TrussTRTLLMRuntimeConfiguration(BaseModel):
144151
] = None
145152

146153

147-
class TrussTRTLLMLoraConfiguration(BaseModel):
154+
class TrussTRTLLMLoraConfiguration(PydanticTrTBaseModel):
148155
max_lora_rank: int = 64
149156
lora_target_modules: list[str] = []
150157

151158

152-
class TrussTRTLLMBuildConfiguration(BaseModel):
159+
class TrussTRTLLMBuildConfiguration(PydanticTrTBaseModel):
153160
base_model: TrussTRTLLMModel = TrussTRTLLMModel.DECODER
154161
max_seq_len: Optional[Annotated[int, Field(strict=True, ge=1, le=1048576)]] = None
155162
max_batch_size: Annotated[int, Field(strict=True, ge=1, le=2048)] = 256
@@ -302,8 +309,14 @@ def max_draft_len(self) -> Optional[int]:
302309
return self.speculator.num_draft_tokens
303310
return None
304311

312+
@property
313+
def parsed_trt_llm_build_configs(self) -> list["TrussTRTLLMBuildConfiguration"]:
314+
if self.speculator and self.speculator.build:
315+
return [self, self.speculator.build]
316+
return [self]
317+
305318

306-
class TrussSpeculatorConfiguration(BaseModel):
319+
class TrussSpeculatorConfiguration(PydanticTrTBaseModel):
307320
speculative_decoding_mode: TrussSpecDecMode = TrussSpecDecMode.DRAFT_EXTERNAL
308321
num_draft_tokens: Optional[Annotated[int, Field(strict=True, ge=1)]] = None
309322
checkpoint_repository: Optional[CheckpointRepository] = None
@@ -408,7 +421,7 @@ def resolved_checkpoint_repository(self) -> CheckpointRepository:
408421
)
409422

410423

411-
class VersionsOverrides(BaseModel):
424+
class VersionsOverrides(PydanticTrTBaseModel):
412425
# If an override is specified, it takes precedence over the backend's current
413426
# default version. The version is used to create a full image ref and should look
414427
# like a semver, e.g. for the briton the version `0.17.0-fd30ac1` could be specified
@@ -418,8 +431,16 @@ class VersionsOverrides(BaseModel):
418431
briton_version: Optional[str] = None
419432
bei_version: Optional[str] = None
420433

434+
@model_validator(mode="before")
435+
def version_must_start_with_number(cls, data):
436+
for field in ["engine_builder_version", "briton_version", "bei_version"]:
437+
v = data.get(field)
438+
if v is not None and (not v or not v[0].isdigit()):
439+
raise ValueError(f"{field.name} must start with a number")
440+
return data
441+
421442

422-
class ImageVersions(BaseModel):
443+
class ImageVersions(PydanticTrTBaseModel):
423444
# Required versions for patching truss config during docker build setup.
424445
# The schema of this model must be such that it can parse the values serialized
425446
# from the backend. The inserted values are full image references, resolved using
@@ -428,50 +449,16 @@ class ImageVersions(BaseModel):
428449
briton_image: str
429450

430451

431-
class TRTLLMConfiguration(BaseModel):
432-
runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
452+
class TRTLLMConfiguration(PydanticTrTBaseModel):
433453
build: TrussTRTLLMBuildConfiguration
454+
runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
434455
# If versions are not set, the baseten backend will insert current defaults.
435456
version_overrides: VersionsOverrides = VersionsOverrides()
436457

437458
def model_post_init(self, __context):
438459
self.add_bei_default_route()
439460
self.chunked_context_fix()
440461

441-
@model_validator(mode="before")
442-
@classmethod
443-
def migrate_runtime_fields(cls, data: Any) -> Any:
444-
extra_runtime_fields = {}
445-
valid_build_fields = {}
446-
if isinstance(data.get("build"), dict):
447-
for key, value in data.get("build").items():
448-
if key in TrussTRTLLMBuildConfiguration.__annotations__:
449-
valid_build_fields[key] = value
450-
else:
451-
if key in TrussTRTLLMRuntimeConfiguration.__annotations__:
452-
logger.warning(f"Found runtime.{key}: {value} in build config")
453-
extra_runtime_fields[key] = value
454-
if extra_runtime_fields:
455-
logger.warning(
456-
f"Found extra fields {list(extra_runtime_fields.keys())} in build configuration, unspecified runtime fields will be configured using these values."
457-
" This configuration of deprecated fields is scheduled for removal, please upgrade to the latest truss version and update configs according to https://docs.baseten.co/performance/engine-builder-config."
458-
)
459-
if data.get("runtime"):
460-
data.get("runtime").update(
461-
{
462-
k: v
463-
for k, v in extra_runtime_fields.items()
464-
if k not in data.get("runtime")
465-
}
466-
)
467-
else:
468-
data.update(
469-
{"runtime": {k: v for k, v in extra_runtime_fields.items()}}
470-
)
471-
data.update({"build": valid_build_fields})
472-
return data
473-
return data
474-
475462
def chunked_context_fix(self: "TRTLLMConfiguration") -> "TRTLLMConfiguration":
476463
"""check if there is an error wrt. runtime.enable_chunked_context"""
477464
if (
@@ -482,16 +469,8 @@ def chunked_context_fix(self: "TRTLLMConfiguration") -> "TRTLLMConfiguration":
482469
and self.build.plugin_configuration.paged_kv_cache
483470
)
484471
):
485-
logger.warning(
486-
"If trt_llm.runtime.enable_chunked_context is True, then trt_llm.build.plugin_configuration.use_paged_context_fmha and trt_llm.build.plugin_configuration.paged_kv_cache should be True. "
487-
"Setting trt_llm.build.plugin_configuration.use_paged_context_fmha and trt_llm.build.plugin_configuration.paged_kv_cache to True."
488-
)
489-
self.build = self.build.model_copy(
490-
update={
491-
"plugin_configuration": self.build.plugin_configuration.model_copy(
492-
update={"use_paged_context_fmha": True, "paged_kv_cache": True}
493-
)
494-
}
472+
raise ValueError(
473+
"If trt_llm.runtime.enable_chunked_context is True, then trt_llm.build.plugin_configuration.use_paged_context_fmha and trt_llm.build.plugin_configuration.paged_kv_cache need to be True. "
495474
)
496475

497476
return self

truss/base/truss_config.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -611,16 +611,6 @@ class Config:
611611
def canonical_python_version(self) -> str:
612612
return to_dotted_python_version(self.python_version)
613613

614-
@property
615-
def parsed_trt_llm_build_configs(
616-
self,
617-
) -> list[trt_llm_config.TrussTRTLLMBuildConfiguration]:
618-
if self.trt_llm:
619-
if self.trt_llm.build.speculator and self.trt_llm.build.speculator.build:
620-
return [self.trt_llm.build, self.trt_llm.build.speculator.build]
621-
return [self.trt_llm.build]
622-
return []
623-
624614
def to_dict(self, verbose: bool = True) -> dict:
625615
self.runtime.sync_is_websocket() # type: ignore[operator] # This is callable.
626616
data = super().to_dict(verbose)

truss/cli/cli.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
from truss.remote.remote_factory import USER_TRUSSRC_PATH, RemoteFactory
5151
from truss.trt_llm.config_checks import (
5252
has_no_tags_trt_llm_builder,
53-
is_missing_secrets_for_trt_llm_builder,
5453
memory_updated_for_trt_llm_builder,
5554
uses_trt_llm_builder,
5655
)
@@ -153,6 +152,16 @@ def _get_required_option(ctx: click.Context, name: str) -> object:
153152
return value
154153

155154

155+
def _prepare_click_context(f: click.Command, params: dict) -> click.Context:
156+
"""create new click context for invoking a command via f.invoke(ctx)"""
157+
current_ctx = click.get_current_context()
158+
current_obj = current_ctx.find_root().obj
159+
160+
ctx = click.Context(f, obj=current_obj)
161+
ctx.params = params
162+
return ctx
163+
164+
156165
def _log_level_option(f: Callable[..., object]) -> Callable[..., object]:
157166
return click.option(
158167
"--log",
@@ -753,21 +762,20 @@ def push(
753762
console.print(live_reload_disabled_text, style="red")
754763
sys.exit(1)
755764

756-
if is_missing_secrets_for_trt_llm_builder(tr):
757-
missing_token_text = (
758-
"`hf_access_token` must be provided in secrets to build a gated model. "
759-
"Please see https://docs.baseten.co/deploy/guides/private-model for configuration instructions."
760-
)
761-
console.print(missing_token_text, style="yellow")
762765
if memory_updated_for_trt_llm_builder(tr):
763766
console.print(
764767
f"Automatically increasing memory for trt-llm builder to {TRTLLM_MIN_MEMORY_REQUEST_GI}Gi."
765768
)
766-
message_oai = has_no_tags_trt_llm_builder(tr)
769+
message_oai, raised_message_oai = has_no_tags_trt_llm_builder(tr)
767770
if message_oai:
768-
console.print(message_oai, style="red")
769-
sys.exit(1)
770-
for trt_llm_build_config in tr.spec.config.parsed_trt_llm_build_configs:
771+
console.print(message_oai, style="yellow")
772+
if raised_message_oai:
773+
console.print(message_oai, style="red")
774+
sys.exit(1)
775+
776+
for (
777+
trt_llm_build_config
778+
) in tr.spec.config.trt_llm.build.parsed_trt_llm_build_configs:
771779
if (
772780
trt_llm_build_config.quantization_type
773781
in [TrussTRTLLMQuantizationType.FP8, TrussTRTLLMQuantizationType.FP8_KV]
@@ -1722,19 +1730,18 @@ def deploy_checkpoints(
17221730
),
17231731
)
17241732

1725-
ctx = click.Context(push, obj={})
1726-
ctx.params = {
1733+
params = {
17271734
"target_directory": prepare_checkpoint_result.truss_directory,
17281735
"remote": remote,
17291736
"model_name": prepare_checkpoint_result.checkpoint_deploy_config.model_name,
17301737
"publish": True,
17311738
"deployment_name": prepare_checkpoint_result.checkpoint_deploy_config.deployment_name,
17321739
}
1740+
ctx = _prepare_click_context(push, params)
17331741
if dry_run:
17341742
console.print("--dry-run flag provided, not deploying", style="yellow")
17351743
else:
17361744
push.invoke(ctx)
1737-
17381745
train_cli.print_deploy_checkpoints_success_message(prepare_checkpoint_result)
17391746

17401747

truss/tests/conftest.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -780,45 +780,16 @@ def deprecated_trtllm_config(default_config) -> Dict[str, Any]:
780780
"base_model": "llama",
781781
"max_seq_len": 2048,
782782
"max_batch_size": 512,
783-
# start deprecated fields
784-
"kv_cache_free_gpu_mem_fraction": 0.1,
785-
"enable_chunked_context": True,
786-
"batch_scheduler_policy": TrussTRTLLMBatchSchedulerPolicy.MAX_UTILIZATION.value,
787-
"request_default_max_tokens": 10,
788-
"total_token_limit": 50,
789-
# end deprecated fields
790783
"checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
791784
"gather_all_token_logits": False,
792-
}
793-
}
794-
return trtllm_config
795-
796-
797-
@pytest.fixture
798-
def deprecated_trtllm_config_with_runtime_existing(default_config) -> Dict[str, Any]:
799-
trtllm_config = default_config
800-
trtllm_config["resources"] = {
801-
"accelerator": Accelerator.L4.value,
802-
"cpu": "1",
803-
"memory": "24Gi",
804-
"use_gpu": True,
805-
}
806-
trtllm_config["trt_llm"] = {
807-
"build": {
808-
"base_model": "llama",
809-
"max_seq_len": 2048,
810-
"max_batch_size": 512,
811-
# start deprecated fields
785+
},
786+
"runtime": {
787+
"total_token_limit": 100,
812788
"kv_cache_free_gpu_mem_fraction": 0.1,
813789
"enable_chunked_context": True,
814-
"batch_scheduler_policy": TrussTRTLLMBatchSchedulerPolicy.MAX_UTILIZATION.value,
790+
"batch_scheduler_policy": "max_utilization",
815791
"request_default_max_tokens": 10,
816-
"total_token_limit": 50,
817-
# end deprecated fields
818-
"checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
819-
"gather_all_token_logits": False,
820792
},
821-
"runtime": {"total_token_limit": 100},
822793
}
823794
return trtllm_config
824795

truss/tests/test_data/test_streaming_truss_with_tracing/config.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ requirements: []
3030
requirements_file: null
3131
resources:
3232
accelerator: null
33-
cpu: '1'
33+
cpu: "1"
3434
memory: 2Gi
3535
use_gpu: false
3636
runtime:
@@ -46,7 +46,6 @@ runtime:
4646
kind: websocket
4747
truss_server_version_override: null
4848
secrets: {}
49-
spec_version: '2.0'
49+
spec_version: "2.0"
5050
system_packages: []
51-
trt_llm: null
5251
use_local_src: false

truss/tests/test_data/test_trt_llm_truss/config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
model_name: Test
2+
model_metadata:
3+
tags:
4+
- openai-compatible
25
resources:
36
accelerator: A100
47
use_gpu: True

0 commit comments

Comments
 (0)