Skip to content

initial tensorrt ep commit #921

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 6, 2024
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Provider StringToProvider(std::string s) {
return Provider::kXnnpack;
} else if (s == "nnapi") {
return Provider::kNNAPI;
} else if (s == "trt") {
return Provider::kTRT;
} else {
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
return Provider::kCPU;
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ enum class Provider {
kCoreML = 2, // CoreMLExecutionProvider
kXnnpack = 3, // XnnpackExecutionProvider
kNNAPI = 4, // NnapiExecutionProvider
kTRT = 5, // TensorRTExecutionProvider
};

/**
Expand Down
62 changes: 61 additions & 1 deletion sherpa-onnx/csrc/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@

namespace sherpa_onnx {


// Log the error carried by a failed ORT C-API call made while enabling
// the TensorRT execution provider, then release the status object.
//
// @param status Non-null OrtStatus returned by an onnxruntime C-API call.
//               It is released here; the caller must not use it afterwards.
// @param s      Human-readable list of the available providers, included in
//               the log message to help diagnose why TensorRT failed.
static void OrtStatusFailure(OrtStatus *status, const char *s) {
  const auto &api = Ort::GetApi();
  const char *msg = api.GetErrorMessage(status);
  // The two adjacent literals are concatenated at compile time; the trailing
  // space after "%s. " keeps the sentences from running together in the log.
  SHERPA_ONNX_LOGE(
      "Failed to enable TensorRT : %s. "
      "Available providers: %s. Fallback to cuda", msg, s);
  // OrtStatus objects returned by the C API are owned by the caller and must
  // be released explicitly to avoid a leak.
  api.ReleaseStatus(status);
}

static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
std::string provider_str) {
Provider p = StringToProvider(std::move(provider_str));
Expand Down Expand Up @@ -53,6 +63,57 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
}
break;
}
case Provider::kTRT: {
struct TrtPairs {
const char* op_keys;
const char* op_values;
};

std::vector<TrtPairs> trt_options = {
{"device_id", "0"},
{"trt_max_workspace_size", "2147483648"},
{"trt_max_partition_iterations", "10"},
{"trt_min_subgraph_size", "5"},
{"trt_fp16_enable", "0"},
{"trt_detailed_build_log", "0"},
{"trt_engine_cache_enable", "1"},
{"trt_engine_cache_path", "."},
{"trt_timing_cache_enable", "1"},
{"trt_timing_cache_path", "."}
};
// ToDo : Trt configs
// "trt_int8_enable"
// "trt_int8_use_native_calibration_table"
// "trt_dump_subgraphs"

std::vector<const char*> option_keys, option_values;
for (const TrtPairs& pair : trt_options) {
option_keys.emplace_back(pair.op_keys);
option_values.emplace_back(pair.op_values);
}

std::vector<std::string> available_providers =
Ort::GetAvailableProviders();
if (std::find(available_providers.begin(), available_providers.end(),
"TensorrtExecutionProvider") != available_providers.end()) {
const auto& api = Ort::GetApi();

OrtTensorRTProviderOptionsV2* tensorrt_options;
OrtStatus *statusC = api.CreateTensorRTProviderOptions(
&tensorrt_options);
OrtStatus *statusU = api.UpdateTensorRTProviderOptions(
tensorrt_options, option_keys.data(), option_values.data(),
option_keys.size());
sess_opts.AppendExecutionProvider_TensorRT_V2(*tensorrt_options);

if (statusC) { OrtStatusFailure(statusC, os.str().c_str()); }
if (statusU) { OrtStatusFailure(statusU, os.str().c_str()); }

api.ReleaseTensorRTProviderOptions(tensorrt_options);
}
// break; is omitted here intentionally so that
// if TRT not available, CUDA will be used
}
case Provider::kCUDA: {
if (std::find(available_providers.begin(), available_providers.end(),
"CUDAExecutionProvider") != available_providers.end()) {
Expand Down Expand Up @@ -116,7 +177,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
break;
}
}

return sess_opts;
}

Expand Down
Loading