Skip to content

Commit f062ac5

Browse files
committed
fix: enable TTS GPU acceleration on ROCm (GPU use was disabled because only torch_cuda availability was checked, not torch_hip; also fix torch-hip log line printing the torch_cuda flag)
1 parent 24a3ac3 commit f062ac5

File tree

4 files changed

+11
-6
lines changed

4 files changed

+11
-6
lines changed

src/coqui_engine.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,10 @@ void coqui_engine::create_model() {
186186
model_file.find("fairseq") != std::string::npos ||
187187
model_file.find("xtts") != std::string::npos;
188188

189-
auto use_cuda = m_config.use_gpu &&
190-
py_executor::instance()->libs_availability->torch_cuda;
189+
auto use_cuda =
190+
m_config.use_gpu &&
191+
(py_executor::instance()->libs_availability->torch_cuda ||
192+
py_executor::instance()->libs_availability->torch_hip);
191193

192194
LOGD("using device: " << (use_cuda ? "cuda" : "cpu") << " "
193195
<< m_config.gpu_device.id);

src/punctuator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ using namespace pybind11::literals;
1818

1919
punctuator::punctuator(const std::string& model_path, int device) {
2020
auto task = py_executor::instance()->execute(
21-
[&, dev = py_executor::instance()->libs_availability->torch_cuda
21+
[&, dev = (py_executor::instance()->libs_availability->torch_cuda ||
22+
py_executor::instance()->libs_availability->torch_hip)
2223
? device
2324
: -1]() {
2425
try {

src/py_tools.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ std::ostream& operator<<(std::ostream& os,
4343
<< ", gruut_sw=" << availability.gruut_sw
4444
<< ", mecab=" << availability.mecab
4545
<< ", torch-cuda=" << availability.torch_cuda
46-
<< ", torch-hip=" << availability.torch_cuda;
46+
<< ", torch-hip=" << availability.torch_hip;
4747

4848
return os;
4949
}

src/whisperspeech_engine.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,10 @@ void whisperspeech_engine::create_model() {
117117

118118
LOGD("model files: " << s2a_file << " " << t2s_file);
119119

120-
auto use_cuda = m_config.use_gpu &&
121-
py_executor::instance()->libs_availability->torch_cuda;
120+
auto use_cuda =
121+
m_config.use_gpu &&
122+
(py_executor::instance()->libs_availability->torch_cuda ||
123+
py_executor::instance()->libs_availability->torch_hip);
122124

123125
LOGD("using device: " << (use_cuda ? "cuda" : "cpu") << " "
124126
<< m_config.gpu_device.id);

0 commit comments

Comments
 (0)