|
1 | 1 | #!/usr/bin/env python3
|
2 | 2 |
|
3 |
| -# Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +# Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
4 | 4 | #
|
5 | 5 | # Redistribution and use in source and binary forms, with or without
|
6 | 6 | # modification, are permitted provided that the following conditions
|
|
27 | 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
28 | 28 |
|
29 | 29 | import argparse
|
| 30 | +import ctypes |
30 | 31 | import os
|
31 | 32 |
|
32 | 33 | import numpy as np
|
|
38 | 39 | TRT_LOGGER = trt.Logger()
|
39 | 40 |
|
40 | 41 | trt.init_libnvinfer_plugins(TRT_LOGGER, "")
|
41 |
| -PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list |
42 | 42 |
|
43 | 43 |
|
44 | 44 | def get_trt_plugin(plugin_name):
|
45 | 45 | plugin = None
|
46 | 46 | field_collection = None
|
47 |
| - for plugin_creator in PLUGIN_CREATORS: |
| 47 | + plugin_creators = trt.get_plugin_registry().plugin_creator_list |
| 48 | + for plugin_creator in plugin_creators: |
48 | 49 | if (plugin_creator.name == "CustomHardmax") and (
|
49 | 50 | plugin_name == "CustomHardmax"
|
50 | 51 | ):
|
@@ -272,13 +273,37 @@ def create_plugin_models(models_dir):
|
272 | 273 | )
|
273 | 274 |
|
274 | 275 |
|
| 276 | +def windows_load_plugin_lib(win_plugin_dll): |
| 277 | + if os.path.isfile(win_plugin_dll): |
| 278 | + try: |
| 279 | + ctypes.CDLL(win_plugin_dll, winmode=0) |
| 280 | + except TypeError: |
| 281 | + # winmode only introduced in python 3.8 |
| 282 | + ctypes.CDLL(win_plugin_dll) |
| 283 | + return |
| 284 | + |
| 285 | + raise IOError('Failed to load library: "{}".'.format(win_plugin_dll)) |
| 286 | + |
| 287 | + |
275 | 288 | if __name__ == "__main__":
|
276 | 289 | parser = argparse.ArgumentParser()
|
277 | 290 | parser.add_argument(
|
278 | 291 | "--models_dir", type=str, required=True, help="Top-level model directory"
|
279 | 292 | )
|
| 293 | + parser.add_argument( |
| 294 | + "--win_plugin_dll", |
| 295 | + type=str, |
| 296 | + required=False, |
| 297 | + default="", |
| 298 | + help="Path to Windows plugin .dll", |
| 299 | + ) |
280 | 300 | FLAGS, unparsed = parser.parse_known_args()
|
281 | 301 |
|
282 | 302 | import test_util as tu
|
283 | 303 |
|
| 304 | + # Linux can leverage LD_PRELOAD. We must load the Windows plugin manually |
| 305 | + # in order for it to be discovered in the registry. |
| 306 | + if os.name == "nt": |
| 307 | + windows_load_plugin_lib(FLAGS.win_plugin_dll) |
| 308 | + |
284 | 309 | create_plugin_models(FLAGS.models_dir)
|
0 commit comments