1
- // Copyright 2018-2023 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ // Copyright 2018-2024 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
//
3
3
// Redistribution and use in source and binary forms, with or without
4
4
// modification, are permitted provided that the following conditions
@@ -182,10 +182,39 @@ class LocalizeRepoAgent : public TritonRepoAgent {
182
182
}
183
183
};
184
184
185
+ // / Get the model config path to load for the model.
186
+ const std::string
187
+ GetModelConfigFullPath (
188
+ const std::string& model_dir_path, const std::string& custom_config_name)
189
+ {
190
+ // "--model-config-name" is set. Select custom config from
191
+ // "<model_dir_path>/configs" folder if config file exists.
192
+ if (!custom_config_name.empty ()) {
193
+ bool custom_config_exists = false ;
194
+ const std::string custom_config_path = JoinPath (
195
+ {model_dir_path, kModelConfigFolder ,
196
+ custom_config_name + kPbTxtExtension });
197
+
198
+ Status status = FileExists (custom_config_path, &custom_config_exists);
199
+ if (!status.IsOk ()) {
200
+ LOG_ERROR << " Failed to get model configuration full path for '"
201
+ << model_dir_path << " ': " << status.AsString ();
202
+ return " " ;
203
+ }
204
+
205
+ if (custom_config_exists) {
206
+ return custom_config_path;
207
+ }
208
+ }
209
+ // "--model-config-name" is not set or custom config file does not exist.
210
+ return JoinPath ({model_dir_path, kModelConfigPbTxt });
211
+ }
212
+
185
213
Status
186
214
CreateAgentModelListWithLoadAction (
187
215
const inference::ModelConfig& original_model_config,
188
216
const std::string& original_model_path,
217
+ const std::string& model_config_name,
189
218
std::shared_ptr<TritonRepoAgentModelList>* agent_model_list)
190
219
{
191
220
if (original_model_config.has_model_repository_agents ()) {
@@ -218,7 +247,8 @@ CreateAgentModelListWithLoadAction(
218
247
std::unique_ptr<TritonRepoAgentModel> agent_model;
219
248
if (lagent_model_list->Size () != 0 ) {
220
249
lagent_model_list->Back ()->Location (&artifact_type, &location);
221
- const auto config_path = JoinPath ({location, kModelConfigPbTxt });
250
+ const auto config_path =
251
+ GetModelConfigFullPath (location, model_config_name);
222
252
if (!ReadTextProto (config_path, &model_config).IsOk ()) {
223
253
model_config.Clear ();
224
254
}
@@ -283,10 +313,12 @@ GetModifiedTime(const std::string& path)
283
313
}
284
314
// Return the latest modification time in ns for '<config.pbtxt, model files>'
285
315
// in a model directory path. The time for "config.pbtxt" will be 0 if not
286
- // found at "[model_dir_path]/config.pbtxt". The time for "model files" includes
287
- // the time for 'model_dir_path'.
316
+ // found at "[model_dir_path]/config.pbtxt" or "[model_dir_path]/configs/
317
+ // <custom-config-name>.pbtxt" if "--model-config-name" is set. The time for
318
+ // "model files" includes the time for 'model_dir_path'.
288
319
std::pair<int64_t , int64_t >
289
- GetDetailedModifiedTime (const std::string& model_dir_path)
320
+ GetDetailedModifiedTime (
321
+ const std::string& model_dir_path, const std::string& model_config_path)
290
322
{
291
323
// Check if 'model_dir_path' is a directory.
292
324
bool is_dir;
@@ -322,12 +354,10 @@ GetDetailedModifiedTime(const std::string& model_dir_path)
322
354
}
323
355
// Get latest modification time for each files/folders, and place it at the
324
356
// correct category.
325
- const std::string model_config_full_path (
326
- JoinPath ({model_dir_path, kModelConfigPbTxt }));
327
357
for (const auto & child : contents) {
328
358
const auto full_path = JoinPath ({model_dir_path, child});
329
- if (full_path == model_config_full_path ) {
330
- // config.pbtxt
359
+ if (full_path == model_config_path ) {
360
+ // config.pbtxt or customized config file in configs folder
331
361
mtime.first = GetModifiedTime (full_path);
332
362
} else {
333
363
// model files
@@ -343,9 +373,10 @@ GetDetailedModifiedTime(const std::string& model_dir_path)
343
373
// modified time.
344
374
bool
345
375
IsModified (
346
- const std::string& model_dir_path, std::pair<int64_t , int64_t >* last_ns)
376
+ const std::string& model_dir_path, const std::string& model_config_path,
377
+ std::pair<int64_t , int64_t >* last_ns)
347
378
{
348
- auto new_ns = GetDetailedModifiedTime (model_dir_path);
379
+ auto new_ns = GetDetailedModifiedTime (model_dir_path, model_config_path );
349
380
bool modified = std::max (new_ns.first , new_ns.second ) >
350
381
std::max (last_ns->first , last_ns->second );
351
382
last_ns->swap (new_ns);
@@ -356,10 +387,12 @@ IsModified(
356
387
357
388
ModelRepositoryManager::ModelRepositoryManager (
358
389
const std::set<std::string>& repository_paths, const bool autofill,
359
- const bool polling_enabled, const bool model_control_enabled,
360
- const double min_compute_capability, const bool enable_model_namespacing,
390
+ const std::string& model_config_name, const bool polling_enabled,
391
+ const bool model_control_enabled, const double min_compute_capability,
392
+ const bool enable_model_namespacing,
361
393
std::unique_ptr<ModelLifeCycle> life_cycle)
362
- : autofill_(autofill), polling_enabled_(polling_enabled),
394
+ : autofill_(autofill), model_config_name_(model_config_name),
395
+ polling_enabled_ (polling_enabled),
363
396
model_control_enabled_(model_control_enabled),
364
397
min_compute_capability_(min_compute_capability),
365
398
dependency_graph_(&global_map_),
@@ -385,7 +418,8 @@ ModelRepositoryManager::Create(
385
418
InferenceServer* server, const std::string& server_version,
386
419
const std::set<std::string>& repository_paths,
387
420
const std::set<std::string>& startup_models, const bool strict_model_config,
388
- const bool polling_enabled, const bool model_control_enabled,
421
+ const std::string& model_config_name, const bool polling_enabled,
422
+ const bool model_control_enabled,
389
423
const ModelLifeCycleOptions& life_cycle_options,
390
424
const bool enable_model_namespacing,
391
425
std::unique_ptr<ModelRepositoryManager>* model_repository_manager)
@@ -414,9 +448,10 @@ ModelRepositoryManager::Create(
414
448
// Not setting the smart pointer directly to simplify clean up
415
449
std::unique_ptr<ModelRepositoryManager> local_manager (
416
450
new ModelRepositoryManager (
417
- repository_paths, !strict_model_config, polling_enabled,
418
- model_control_enabled, life_cycle_options.min_compute_capability ,
419
- enable_model_namespacing, std::move (life_cycle)));
451
+ repository_paths, !strict_model_config, model_config_name,
452
+ polling_enabled, model_control_enabled,
453
+ life_cycle_options.min_compute_capability , enable_model_namespacing,
454
+ std::move (life_cycle)));
420
455
*model_repository_manager = std::move (local_manager);
421
456
422
457
// Support loading all models on startup in explicit model control mode with
@@ -549,7 +584,7 @@ ModelRepositoryManager::LoadModelByDependency(
549
584
// encapsulate the interaction:
550
585
// Each iteration:
551
586
// - Check dependency graph for nodes that are ready for lifecycle changes:
552
- // - load if all dependencies are satisfied and the node is 'heathy '
587
+ // - load if all dependencies are satisfied and the node is 'healthy '
553
588
// - unload otherwise (should revisit this, logically will only happen in
554
589
// ensemble, the ensemble is requested to be re-loaded, at this point
555
590
// it is too late to revert model changes so the ensemble will not be
@@ -1298,10 +1333,11 @@ ModelRepositoryManager::Poll(
1298
1333
// its state will fallback to the state before the polling.
1299
1334
for (const auto & pair : model_to_path) {
1300
1335
std::unique_ptr<ModelInfo> model_info;
1336
+ const auto & model_name = pair.first .name_ ;
1301
1337
// Load with parameters will be applied to all models with the same
1302
1338
// name (namespace can be different), unless namespace is specified
1303
1339
// in the future.
1304
- const auto & mit = models.find (pair. first . name_ );
1340
+ const auto & mit = models.find (model_name );
1305
1341
static std::vector<const InferenceParameter*> empty_params;
1306
1342
auto status = InitializeModelInfo (
1307
1343
pair.first , pair.second ,
@@ -1401,17 +1437,22 @@ ModelRepositoryManager::InitializeModelInfo(
1401
1437
// the override while the local files may still be unchanged.
1402
1438
linfo->mtime_nsec_ = std::make_pair (0 , 0 );
1403
1439
linfo->model_path_ = location;
1440
+ linfo->model_config_path_ = JoinPath ({location, kModelConfigPbTxt });
1404
1441
linfo->agent_model_list_ .reset (new TritonRepoAgentModelList ());
1405
1442
linfo->agent_model_list_ ->AddAgentModel (std::move (localize_agent_model));
1406
1443
} else {
1444
+ linfo->model_config_path_ =
1445
+ GetModelConfigFullPath (linfo->model_path_ , model_config_name_);
1446
+ // Model is not loaded.
1407
1447
if (iitr == infos_.end ()) {
1408
- linfo->mtime_nsec_ = GetDetailedModifiedTime (linfo->model_path_ );
1448
+ linfo->mtime_nsec_ = GetDetailedModifiedTime (
1449
+ linfo->model_path_ , linfo->model_config_path_ );
1409
1450
} else {
1410
1451
// Check the current timestamps to determine if model actually has been
1411
1452
// modified
1412
1453
linfo->mtime_nsec_ = linfo->prev_mtime_ns_ ;
1413
- unmodified =
1414
- ! IsModified ( std::string ( linfo->model_path_ ) , &linfo->mtime_nsec_ );
1454
+ unmodified = ! IsModified (
1455
+ linfo->model_path_ , linfo-> model_config_path_ , &linfo->mtime_nsec_ );
1415
1456
}
1416
1457
}
1417
1458
@@ -1461,7 +1502,7 @@ ModelRepositoryManager::InitializeModelInfo(
1461
1502
// this must be done before normalizing model config as agents might
1462
1503
// redirect to use the model config at a different location
1463
1504
if (!parsed_config) {
1464
- const auto config_path = JoinPath ({ linfo->model_path_ , kModelConfigPbTxt }) ;
1505
+ const auto config_path = linfo->model_config_path_ ;
1465
1506
bool model_config_exists = false ;
1466
1507
RETURN_IF_ERROR (FileExists (config_path, &model_config_exists));
1467
1508
// model config can be missing if auto fill is set
@@ -1474,7 +1515,8 @@ ModelRepositoryManager::InitializeModelInfo(
1474
1515
}
1475
1516
if (parsed_config) {
1476
1517
RETURN_IF_ERROR (CreateAgentModelListWithLoadAction (
1477
- linfo->model_config_ , linfo->model_path_ , &linfo->agent_model_list_ ));
1518
+ linfo->model_config_ , linfo->model_path_ , model_config_name_,
1519
+ &linfo->agent_model_list_ ));
1478
1520
if (linfo->agent_model_list_ != nullptr ) {
1479
1521
// Get the latest repository path
1480
1522
const char * location;
0 commit comments