Skip to content

Commit 4d093f2

Browse files
tanmayv25yinggeh
andauthored
Multiple Model Configurations (#348) (#357)
Add `--mode-config-name` option when starting Triton server. Allow users to create multiple configurations and select a custom configuration other than the default `model/config.pbtxt`. Co-authored-by: Yingge He <157551214+yinggeh@users.noreply.github.com>
1 parent 817aaf4 commit 4d093f2

File tree

9 files changed

+140
-45
lines changed

9 files changed

+140
-45
lines changed

include/triton/core/tritonserver.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ struct TRITONSERVER_MetricFamily;
9191
/// }
9292
///
9393
#define TRITONSERVER_API_VERSION_MAJOR 1
94-
#define TRITONSERVER_API_VERSION_MINOR 30
94+
#define TRITONSERVER_API_VERSION_MINOR 31
9595

9696
/// Get the TRITONBACKEND API version supported by the Triton shared
9797
/// library. This value can be compared against the
@@ -1828,6 +1828,16 @@ TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
18281828
TRITONSERVER_ServerOptionsSetStrictModelConfig(
18291829
struct TRITONSERVER_ServerOptions* options, bool strict);
18301830

1831+
/// Set the custom model configuration name to load for all models.
1832+
/// Fall back to default config file if empty.
1833+
///
1834+
/// \param options The server options object.
1835+
/// \param config_name The name of the config file to load for all models.
1836+
/// \return a TRITONSERVER_Error indicating success or failure.
1837+
TRITONSERVER_DECLSPEC struct TRITONSERVER_Error*
1838+
TRITONSERVER_ServerOptionsSetModelConfigName(
1839+
struct TRITONSERVER_ServerOptions* options, const char* model_config_name);
1840+
18311841
/// Set the rate limit mode in a server options.
18321842
///
18331843
/// TRITONSERVER_RATE_LIMIT_EXEC_COUNT: The rate limiting prioritizes the

src/constants.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ constexpr char kAutoMixedPrecisionExecutionAccelerator[] =
7171
"auto_mixed_precision";
7272

7373
constexpr char kModelConfigPbTxt[] = "config.pbtxt";
74+
constexpr char kPbTxtExtension[] = ".pbtxt";
75+
constexpr char kModelConfigFolder[] = "configs";
7476

7577
constexpr char kMetricsLabelModelNamespace[] = "namespace";
7678
constexpr char kMetricsLabelModelName[] = "model";

src/model_repository_manager/model_lifecycle.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,10 @@ VersionsToLoad(
9292
RETURN_IF_ERROR(GetDirectorySubdirs(model_path, &subdirs));
9393
std::set<int64_t, std::greater<int64_t>> existing_versions;
9494
for (const auto& subdir : subdirs) {
95-
if (subdir == kWarmupDataFolder || subdir == kInitialStateFolder) {
95+
static const std::vector skip_dirs{
96+
kWarmupDataFolder, kInitialStateFolder, kModelConfigFolder};
97+
if (std::find(skip_dirs.begin(), skip_dirs.end(), subdir) !=
98+
skip_dirs.end()) {
9699
continue;
97100
}
98101
if ((subdir.length() > 1) && (subdir.front() == '0')) {

src/model_repository_manager/model_repository_manager.cc

Lines changed: 67 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -182,10 +182,39 @@ class LocalizeRepoAgent : public TritonRepoAgent {
182182
}
183183
};
184184

185+
/// Get the model config path to load for the model.
186+
const std::string
187+
GetModelConfigFullPath(
188+
const std::string& model_dir_path, const std::string& custom_config_name)
189+
{
190+
// "--model-config-name" is set. Select custom config from
191+
// "<model_dir_path>/configs" folder if config file exists.
192+
if (!custom_config_name.empty()) {
193+
bool custom_config_exists = false;
194+
const std::string custom_config_path = JoinPath(
195+
{model_dir_path, kModelConfigFolder,
196+
custom_config_name + kPbTxtExtension});
197+
198+
Status status = FileExists(custom_config_path, &custom_config_exists);
199+
if (!status.IsOk()) {
200+
LOG_ERROR << "Failed to get model configuration full path for '"
201+
<< model_dir_path << "': " << status.AsString();
202+
return "";
203+
}
204+
205+
if (custom_config_exists) {
206+
return custom_config_path;
207+
}
208+
}
209+
// "--model-config-name" is not set or custom config file does not exist.
210+
return JoinPath({model_dir_path, kModelConfigPbTxt});
211+
}
212+
185213
Status
186214
CreateAgentModelListWithLoadAction(
187215
const inference::ModelConfig& original_model_config,
188216
const std::string& original_model_path,
217+
const std::string& model_config_name,
189218
std::shared_ptr<TritonRepoAgentModelList>* agent_model_list)
190219
{
191220
if (original_model_config.has_model_repository_agents()) {
@@ -218,7 +247,8 @@ CreateAgentModelListWithLoadAction(
218247
std::unique_ptr<TritonRepoAgentModel> agent_model;
219248
if (lagent_model_list->Size() != 0) {
220249
lagent_model_list->Back()->Location(&artifact_type, &location);
221-
const auto config_path = JoinPath({location, kModelConfigPbTxt});
250+
const auto config_path =
251+
GetModelConfigFullPath(location, model_config_name);
222252
if (!ReadTextProto(config_path, &model_config).IsOk()) {
223253
model_config.Clear();
224254
}
@@ -283,10 +313,12 @@ GetModifiedTime(const std::string& path)
283313
}
284314
// Return the latest modification time in ns for '<config.pbtxt, model files>'
285315
// in a model directory path. The time for "config.pbtxt" will be 0 if not
286-
// found at "[model_dir_path]/config.pbtxt". The time for "model files" includes
287-
// the time for 'model_dir_path'.
316+
// found at "[model_dir_path]/config.pbtxt" or "[model_dir_path]/configs/
317+
// <custom-config-name>.pbtxt" if "--model-config-name" is set. The time for
318+
// "model files" includes the time for 'model_dir_path'.
288319
std::pair<int64_t, int64_t>
289-
GetDetailedModifiedTime(const std::string& model_dir_path)
320+
GetDetailedModifiedTime(
321+
const std::string& model_dir_path, const std::string& model_config_path)
290322
{
291323
// Check if 'model_dir_path' is a directory.
292324
bool is_dir;
@@ -322,12 +354,10 @@ GetDetailedModifiedTime(const std::string& model_dir_path)
322354
}
323355
// Get latest modification time for each files/folders, and place it at the
324356
// correct category.
325-
const std::string model_config_full_path(
326-
JoinPath({model_dir_path, kModelConfigPbTxt}));
327357
for (const auto& child : contents) {
328358
const auto full_path = JoinPath({model_dir_path, child});
329-
if (full_path == model_config_full_path) {
330-
// config.pbtxt
359+
if (full_path == model_config_path) {
360+
// config.pbtxt or customized config file in configs folder
331361
mtime.first = GetModifiedTime(full_path);
332362
} else {
333363
// model files
@@ -343,9 +373,10 @@ GetDetailedModifiedTime(const std::string& model_dir_path)
343373
// modified time.
344374
bool
345375
IsModified(
346-
const std::string& model_dir_path, std::pair<int64_t, int64_t>* last_ns)
376+
const std::string& model_dir_path, const std::string& model_config_path,
377+
std::pair<int64_t, int64_t>* last_ns)
347378
{
348-
auto new_ns = GetDetailedModifiedTime(model_dir_path);
379+
auto new_ns = GetDetailedModifiedTime(model_dir_path, model_config_path);
349380
bool modified = std::max(new_ns.first, new_ns.second) >
350381
std::max(last_ns->first, last_ns->second);
351382
last_ns->swap(new_ns);
@@ -356,10 +387,12 @@ IsModified(
356387

357388
ModelRepositoryManager::ModelRepositoryManager(
358389
const std::set<std::string>& repository_paths, const bool autofill,
359-
const bool polling_enabled, const bool model_control_enabled,
360-
const double min_compute_capability, const bool enable_model_namespacing,
390+
const std::string& model_config_name, const bool polling_enabled,
391+
const bool model_control_enabled, const double min_compute_capability,
392+
const bool enable_model_namespacing,
361393
std::unique_ptr<ModelLifeCycle> life_cycle)
362-
: autofill_(autofill), polling_enabled_(polling_enabled),
394+
: autofill_(autofill), model_config_name_(model_config_name),
395+
polling_enabled_(polling_enabled),
363396
model_control_enabled_(model_control_enabled),
364397
min_compute_capability_(min_compute_capability),
365398
dependency_graph_(&global_map_),
@@ -385,7 +418,8 @@ ModelRepositoryManager::Create(
385418
InferenceServer* server, const std::string& server_version,
386419
const std::set<std::string>& repository_paths,
387420
const std::set<std::string>& startup_models, const bool strict_model_config,
388-
const bool polling_enabled, const bool model_control_enabled,
421+
const std::string& model_config_name, const bool polling_enabled,
422+
const bool model_control_enabled,
389423
const ModelLifeCycleOptions& life_cycle_options,
390424
const bool enable_model_namespacing,
391425
std::unique_ptr<ModelRepositoryManager>* model_repository_manager)
@@ -414,9 +448,10 @@ ModelRepositoryManager::Create(
414448
// Not setting the smart pointer directly to simplify clean up
415449
std::unique_ptr<ModelRepositoryManager> local_manager(
416450
new ModelRepositoryManager(
417-
repository_paths, !strict_model_config, polling_enabled,
418-
model_control_enabled, life_cycle_options.min_compute_capability,
419-
enable_model_namespacing, std::move(life_cycle)));
451+
repository_paths, !strict_model_config, model_config_name,
452+
polling_enabled, model_control_enabled,
453+
life_cycle_options.min_compute_capability, enable_model_namespacing,
454+
std::move(life_cycle)));
420455
*model_repository_manager = std::move(local_manager);
421456

422457
// Support loading all models on startup in explicit model control mode with
@@ -549,7 +584,7 @@ ModelRepositoryManager::LoadModelByDependency(
549584
// encapsulate the interaction:
550585
// Each iteration:
551586
// - Check dependency graph for nodes that are ready for lifecycle changes:
552-
// - load if all dependencies are satisfied and the node is 'heathy'
587+
// - load if all dependencies are satisfied and the node is 'healthy'
553588
// - unload otherwise (should revisit this, logically will only happen in
554589
// ensemble, the ensemble is requested to be re-loaded, at this point
555590
// it is too late to revert model changes so the ensemble will not be
@@ -1298,10 +1333,11 @@ ModelRepositoryManager::Poll(
12981333
// its state will fallback to the state before the polling.
12991334
for (const auto& pair : model_to_path) {
13001335
std::unique_ptr<ModelInfo> model_info;
1336+
const auto& model_name = pair.first.name_;
13011337
// Load with parameters will be appiled to all models with the same
13021338
// name (namespace can be different), unless namespace is specified
13031339
// in the future.
1304-
const auto& mit = models.find(pair.first.name_);
1340+
const auto& mit = models.find(model_name);
13051341
static std::vector<const InferenceParameter*> empty_params;
13061342
auto status = InitializeModelInfo(
13071343
pair.first, pair.second,
@@ -1401,17 +1437,22 @@ ModelRepositoryManager::InitializeModelInfo(
14011437
// the override while the local files may still be unchanged.
14021438
linfo->mtime_nsec_ = std::make_pair(0, 0);
14031439
linfo->model_path_ = location;
1440+
linfo->model_config_path_ = JoinPath({location, kModelConfigPbTxt});
14041441
linfo->agent_model_list_.reset(new TritonRepoAgentModelList());
14051442
linfo->agent_model_list_->AddAgentModel(std::move(localize_agent_model));
14061443
} else {
1444+
linfo->model_config_path_ =
1445+
GetModelConfigFullPath(linfo->model_path_, model_config_name_);
1446+
// Model is not loaded.
14071447
if (iitr == infos_.end()) {
1408-
linfo->mtime_nsec_ = GetDetailedModifiedTime(linfo->model_path_);
1448+
linfo->mtime_nsec_ = GetDetailedModifiedTime(
1449+
linfo->model_path_, linfo->model_config_path_);
14091450
} else {
14101451
// Check the current timestamps to determine if model actually has been
14111452
// modified
14121453
linfo->mtime_nsec_ = linfo->prev_mtime_ns_;
1413-
unmodified =
1414-
!IsModified(std::string(linfo->model_path_), &linfo->mtime_nsec_);
1454+
unmodified = !IsModified(
1455+
linfo->model_path_, linfo->model_config_path_, &linfo->mtime_nsec_);
14151456
}
14161457
}
14171458

@@ -1461,7 +1502,7 @@ ModelRepositoryManager::InitializeModelInfo(
14611502
// this must be done before normalizing model config as agents might
14621503
// redirect to use the model config at a different location
14631504
if (!parsed_config) {
1464-
const auto config_path = JoinPath({linfo->model_path_, kModelConfigPbTxt});
1505+
const auto config_path = linfo->model_config_path_;
14651506
bool model_config_exists = false;
14661507
RETURN_IF_ERROR(FileExists(config_path, &model_config_exists));
14671508
// model config can be missing if auto fill is set
@@ -1474,7 +1515,8 @@ ModelRepositoryManager::InitializeModelInfo(
14741515
}
14751516
if (parsed_config) {
14761517
RETURN_IF_ERROR(CreateAgentModelListWithLoadAction(
1477-
linfo->model_config_, linfo->model_path_, &linfo->agent_model_list_));
1518+
linfo->model_config_, linfo->model_path_, model_config_name_,
1519+
&linfo->agent_model_list_));
14781520
if (linfo->agent_model_list_ != nullptr) {
14791521
// Get the latest repository path
14801522
const char* location;

src/model_repository_manager/model_repository_manager.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -79,10 +79,10 @@ class ModelRepositoryManager {
7979
ModelInfo(
8080
const std::pair<int64_t, int64_t>& mtime_nsec,
8181
const std::pair<int64_t, int64_t>& prev_mtime_ns,
82-
const std::string& model_path)
82+
const std::string& model_path, const std::string& model_config_path)
8383
: mtime_nsec_(mtime_nsec), prev_mtime_ns_(prev_mtime_ns),
8484
explicitly_load_(true), model_path_(model_path),
85-
is_config_provided_(false)
85+
model_config_path_(model_config_path), is_config_provided_(false)
8686
{
8787
}
8888
ModelInfo()
@@ -97,6 +97,7 @@ class ModelRepositoryManager {
9797
bool explicitly_load_;
9898
inference::ModelConfig model_config_;
9999
std::string model_path_;
100+
std::string model_config_path_;
100101
// Temporary location to hold agent model list before creating the model
101102
// the ownership must transfer to ModelLifeCycle to ensure
102103
// the agent model life cycle is handled properly.
@@ -391,15 +392,17 @@ class ModelRepositoryManager {
391392
/// and the models in the model repository will not be loaded at startup.
392393
/// Otherwise, LoadUnloadModel() is not allowed and the models will be loaded.
393394
/// Cannot be set to true if polling_enabled is true.
395+
/// \param model_config_name Custom model config name to load for all models.
396+
/// Fall back to default config file if empty.
394397
/// \param life_cycle_options The options to configure ModelLifeCycle.
395398
/// \param model_repository_manager Return the model repository manager.
396399
/// \return The error status.
397400
static Status Create(
398401
InferenceServer* server, const std::string& server_version,
399402
const std::set<std::string>& repository_paths,
400403
const std::set<std::string>& startup_models,
401-
const bool strict_model_config, const bool polling_enabled,
402-
const bool model_control_enabled,
404+
const bool strict_model_config, const std::string& model_config_name,
405+
const bool polling_enabled, const bool model_control_enabled,
403406
const ModelLifeCycleOptions& life_cycle_options,
404407
const bool enable_model_namespacing,
405408
std::unique_ptr<ModelRepositoryManager>* model_repository_manager);
@@ -410,7 +413,8 @@ class ModelRepositoryManager {
410413
Status PollAndUpdate();
411414

412415
/// Load or unload a specified model.
413-
/// \param models The models and the parameters to be loaded or unloaded
416+
/// \param models The models and the parameters to be loaded or unloaded.
417+
/// Expect the number of models to be exactly one.
414418
/// \param type The type action to be performed. If the action is LOAD and
415419
/// the model has been loaded, the model will be re-loaded.
416420
/// \return error status. Return "NOT_FOUND" if it tries to load
@@ -509,8 +513,9 @@ class ModelRepositoryManager {
509513

510514
ModelRepositoryManager(
511515
const std::set<std::string>& repository_paths, const bool autofill,
512-
const bool polling_enabled, const bool model_control_enabled,
513-
const double min_compute_capability, const bool enable_model_namespacing,
516+
const std::string& model_config_name, const bool polling_enabled,
517+
const bool model_control_enabled, const double min_compute_capability,
518+
const bool enable_model_namespacing,
514519
std::unique_ptr<ModelLifeCycle> life_cycle);
515520

516521
/// The internal function that are called in Create() and PollAndUpdate().
@@ -614,6 +619,7 @@ class ModelRepositoryManager {
614619
const std::vector<const InferenceParameter*>& model_params);
615620

616621
const bool autofill_;
622+
const std::string model_config_name_;
617623
const bool polling_enabled_;
618624
const bool model_control_enabled_;
619625
const double min_compute_capability_;

src/server.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -261,8 +261,8 @@ InferenceServer::Init()
261261
host_policy_map_, model_load_thread_count_, model_load_retry_count_);
262262
status = ModelRepositoryManager::Create(
263263
this, version_, model_repository_paths_, startup_models_,
264-
strict_model_config_, polling_enabled, model_control_enabled,
265-
life_cycle_options, enable_model_namespacing_,
264+
strict_model_config_, model_config_name_, polling_enabled,
265+
model_control_enabled, life_cycle_options, enable_model_namespacing_,
266266
&model_repository_manager_);
267267
if (!status.IsOk()) {
268268
if (model_repository_manager_ == nullptr) {

src/server.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -177,6 +177,13 @@ class InferenceServer {
177177
bool StrictModelConfigEnabled() const { return strict_model_config_; }
178178
void SetStrictModelConfigEnabled(bool e) { strict_model_config_ = e; }
179179

180+
// Get / set custom model configuration file name.
181+
std::string ModelConfigName() const { return model_config_name_; }
182+
void SetModelConfigName(const std::string& name)
183+
{
184+
model_config_name_ = name;
185+
}
186+
180187
// Get / set rate limiter mode.
181188
RateLimitMode RateLimiterMode() const { return rate_limit_mode_; }
182189
void SetRateLimiterMode(RateLimitMode m) { rate_limit_mode_ = m; }
@@ -333,6 +340,7 @@ class InferenceServer {
333340
ModelControlMode model_control_mode_;
334341
bool strict_model_config_;
335342
bool strict_readiness_;
343+
std::string model_config_name_;
336344
uint32_t exit_timeout_secs_;
337345
uint32_t buffer_manager_thread_count_;
338346
uint32_t model_load_thread_count_;

0 commit comments

Comments
 (0)