4
4
import os
5
5
import warnings
6
6
from enum import Enum
7
- from typing import TYPE_CHECKING , Annotated , Any , Dict , Literal , Optional
7
+ from typing import TYPE_CHECKING , Annotated , Dict , Literal , Optional
8
8
9
9
from huggingface_hub .errors import HFValidationError
10
10
from huggingface_hub .utils import validate_repo_id
26
26
# Feature flag read once at import time: it is True only when the environment
# variable is set to the exact string "True"; unset or any other value is False.
ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION = (
    os.getenv("ENGINE_BUILDER_TRUSS_RUNTIME_MIGRATION", "False") == "True"
)
29
# Pick the base class shared by every model in this module. Prefer truss's
# ConfigModel; when truss.base cannot be imported (fallback for briton),
# use BaseModel instead.
try:
    from truss.base import custom_types
except ImportError:
    # fallback for briton
    PydanticTrTBaseModel = BaseModel  # type: ignore[assignment,misc]
else:
    PydanticTrTBaseModel = custom_types.ConfigModel
29
36
30
37
31
38
class TrussTRTLLMModel (str , Enum ):
@@ -54,13 +61,13 @@ class TrussTRTLLMQuantizationType(str, Enum):
54
61
FP4_KV = "fp4_kv"
55
62
56
63
57
class TrussTRTLLMPluginConfiguration(PydanticTrTBaseModel):
    # Boolean plugin switches; every field has a default, so the model can be
    # constructed without arguments.

    # On by default.
    paged_kv_cache: bool = True
    # On by default.
    use_paged_context_fmha: bool = True
    # Off by default.
    use_fp8_context_fmha: bool = False
61
68
62
69
63
- class TrussTRTQuantizationConfiguration (BaseModel ):
70
+ class TrussTRTQuantizationConfiguration (PydanticTrTBaseModel ):
64
71
"""Configuration for quantization of TRT models
65
72
66
73
Args:
@@ -96,7 +103,7 @@ class CheckpointSource(str, Enum):
96
103
REMOTE_URL = "REMOTE_URL"
97
104
98
105
99
- class CheckpointRepository (BaseModel ):
106
+ class CheckpointRepository (PydanticTrTBaseModel ):
100
107
source : CheckpointSource
101
108
repo : str
102
109
revision : Optional [str ] = None
@@ -125,7 +132,7 @@ class TrussSpecDecMode(str, Enum):
125
132
LOOKAHEAD_DECODING = "LOOKAHEAD_DECODING"
126
133
127
134
128
- class TrussTRTLLMRuntimeConfiguration (BaseModel ):
135
+ class TrussTRTLLMRuntimeConfiguration (PydanticTrTBaseModel ):
129
136
kv_cache_free_gpu_mem_fraction : float = 0.9
130
137
kv_cache_host_memory_bytes : Optional [Annotated [int , Field (strict = True , ge = 1 )]] = (
131
138
None
@@ -144,12 +151,12 @@ class TrussTRTLLMRuntimeConfiguration(BaseModel):
144
151
] = None
145
152
146
153
147
class TrussTRTLLMLoraConfiguration(PydanticTrTBaseModel):
    # LoRA settings; both fields have defaults.

    # Defaults to 64.
    max_lora_rank: int = 64
    # Defaults to an empty list.
    lora_target_modules: list[str] = []
150
157
151
158
152
- class TrussTRTLLMBuildConfiguration (BaseModel ):
159
+ class TrussTRTLLMBuildConfiguration (PydanticTrTBaseModel ):
153
160
base_model : TrussTRTLLMModel = TrussTRTLLMModel .DECODER
154
161
max_seq_len : Optional [Annotated [int , Field (strict = True , ge = 1 , le = 1048576 )]] = None
155
162
max_batch_size : Annotated [int , Field (strict = True , ge = 1 , le = 2048 )] = 256
@@ -302,8 +309,14 @@ def max_draft_len(self) -> Optional[int]:
302
309
return self .speculator .num_draft_tokens
303
310
return None
304
311
312
@property
def parsed_trt_llm_build_configs(self) -> list["TrussTRTLLMBuildConfiguration"]:
    """Return this build config, followed by the speculator's build config if any.

    The result always starts with ``self``. ``self.speculator.build`` is
    appended only when both the speculator and its ``build`` are truthy.
    """
    configs = [self]
    speculator = self.speculator
    if speculator and speculator.build:
        configs.append(speculator.build)
    return configs
317
+
305
318
306
- class TrussSpeculatorConfiguration (BaseModel ):
319
+ class TrussSpeculatorConfiguration (PydanticTrTBaseModel ):
307
320
speculative_decoding_mode : TrussSpecDecMode = TrussSpecDecMode .DRAFT_EXTERNAL
308
321
num_draft_tokens : Optional [Annotated [int , Field (strict = True , ge = 1 )]] = None
309
322
checkpoint_repository : Optional [CheckpointRepository ] = None
@@ -408,7 +421,7 @@ def resolved_checkpoint_repository(self) -> CheckpointRepository:
408
421
)
409
422
410
423
411
- class VersionsOverrides (BaseModel ):
424
+ class VersionsOverrides (PydanticTrTBaseModel ):
412
425
# If an override is specified, it takes precedence over the backend's current
413
426
# default version. The version is used to create a full image ref and should look
414
427
# like a semver, e.g. for the briton the version `0.17.0-fd30ac1` could be specified
@@ -418,8 +431,16 @@ class VersionsOverrides(BaseModel):
418
431
briton_version : Optional [str ] = None
419
432
bei_version : Optional [str ] = None
420
433
434
@model_validator(mode="before")
@classmethod
def version_must_start_with_number(cls, data):
    """Reject version overrides that do not begin with a digit.

    Runs before field validation on the raw input. For each of
    ``engine_builder_version``, ``briton_version`` and ``bei_version``, a
    string value must be non-empty and start with a digit.

    Args:
        data: Raw input passed to the model. Only ``dict`` input is
            inspected; anything else is returned untouched.

    Returns:
        ``data``, unmodified.

    Raises:
        ValueError: If one of the version strings is empty or does not
            start with a digit.
    """
    # A "before" validator sees whatever the caller passed in, which is not
    # guaranteed to be a dict; only dict input has keys to inspect.
    if not isinstance(data, dict):
        return data
    for field in ("engine_builder_version", "briton_version", "bei_version"):
        v = data.get(field)
        # None means "no override". Non-string values are left for the
        # field's own type validation instead of being indexed here.
        if isinstance(v, str) and not v[:1].isdigit():
            # `field` is already the field name (a str); it has no `.name`.
            raise ValueError(f"{field} must start with a number")
    return data
441
+
421
442
422
- class ImageVersions (BaseModel ):
443
+ class ImageVersions (PydanticTrTBaseModel ):
423
444
# Required versions for patching truss config during docker build setup.
424
445
# The schema of this model must be such that it can parse the values serialized
425
446
# from the backend. The inserted values are full image references, resolved using
@@ -428,50 +449,16 @@ class ImageVersions(BaseModel):
428
449
briton_image : str
429
450
430
451
431
- class TRTLLMConfiguration (BaseModel ):
432
- runtime : TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration ()
452
+ class TRTLLMConfiguration (PydanticTrTBaseModel ):
433
453
build : TrussTRTLLMBuildConfiguration
454
+ runtime : TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration ()
434
455
# If versions are not set, the baseten backend will insert current defaults.
435
456
version_overrides : VersionsOverrides = VersionsOverrides ()
436
457
437
458
def model_post_init(self, __context):
    """Run the post-construction hooks, in order.

    Calls ``add_bei_default_route`` first and ``chunked_context_fix``
    second; the ``__context`` argument is not used.
    """
    # Step 1: BEI default route.
    self.add_bei_default_route()
    # Step 2: chunked-context consistency check.
    self.chunked_context_fix()
440
461
441
- @model_validator (mode = "before" )
442
- @classmethod
443
- def migrate_runtime_fields (cls , data : Any ) -> Any :
444
- extra_runtime_fields = {}
445
- valid_build_fields = {}
446
- if isinstance (data .get ("build" ), dict ):
447
- for key , value in data .get ("build" ).items ():
448
- if key in TrussTRTLLMBuildConfiguration .__annotations__ :
449
- valid_build_fields [key ] = value
450
- else :
451
- if key in TrussTRTLLMRuntimeConfiguration .__annotations__ :
452
- logger .warning (f"Found runtime.{ key } : { value } in build config" )
453
- extra_runtime_fields [key ] = value
454
- if extra_runtime_fields :
455
- logger .warning (
456
- f"Found extra fields { list (extra_runtime_fields .keys ())} in build configuration, unspecified runtime fields will be configured using these values."
457
- " This configuration of deprecated fields is scheduled for removal, please upgrade to the latest truss version and update configs according to https://docs.baseten.co/performance/engine-builder-config."
458
- )
459
- if data .get ("runtime" ):
460
- data .get ("runtime" ).update (
461
- {
462
- k : v
463
- for k , v in extra_runtime_fields .items ()
464
- if k not in data .get ("runtime" )
465
- }
466
- )
467
- else :
468
- data .update (
469
- {"runtime" : {k : v for k , v in extra_runtime_fields .items ()}}
470
- )
471
- data .update ({"build" : valid_build_fields })
472
- return data
473
- return data
474
-
475
462
def chunked_context_fix (self : "TRTLLMConfiguration" ) -> "TRTLLMConfiguration" :
476
463
"""check if there is an error wrt. runtime.enable_chunked_context"""
477
464
if (
@@ -482,16 +469,8 @@ def chunked_context_fix(self: "TRTLLMConfiguration") -> "TRTLLMConfiguration":
482
469
and self .build .plugin_configuration .paged_kv_cache
483
470
)
484
471
):
485
- logger .warning (
486
- "If trt_llm.runtime.enable_chunked_context is True, then trt_llm.build.plugin_configuration.use_paged_context_fmha and trt_llm.build.plugin_configuration.paged_kv_cache should be True. "
487
- "Setting trt_llm.build.plugin_configuration.use_paged_context_fmha and trt_llm.build.plugin_configuration.paged_kv_cache to True."
488
- )
489
- self .build = self .build .model_copy (
490
- update = {
491
- "plugin_configuration" : self .build .plugin_configuration .model_copy (
492
- update = {"use_paged_context_fmha" : True , "paged_kv_cache" : True }
493
- )
494
- }
472
+ raise ValueError (
473
+ "If trt_llm.runtime.enable_chunked_context is True, then trt_llm.build.plugin_configuration.use_paged_context_fmha and trt_llm.build.plugin_configuration.paged_kv_cache need to be True. "
495
474
)
496
475
497
476
return self
0 commit comments