Commit 01d2038
onnx backends
1 parent 28f15c4

2 files changed: +9 -24 lines

ai_models/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -5,4 +5,4 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
-__version__ = "0.4.1"
+__version__ = "0.4.2"

ai_models/model.py

Lines changed: 8 additions & 23 deletions

@@ -193,32 +193,13 @@ def torch_deterministic_mode(self):
 
     @cached_property
     def providers(self):
-        import platform
-
-        import GPUtil
         import onnxruntime as ort
 
+        available_providers = ort.get_available_providers()
         providers = []
-
-        try:
-            if GPUtil.getAvailable():
-                providers += [
-                    "CUDAExecutionProvider",  # CUDA
-                ]
-        except Exception:
-            pass
-
-        if sys.platform == "darwin":
-            if platform.machine() == "arm64":
-                # This one is not working with error: CoreML does not support input dim > 16384
-                # providers += ["CoreMLExecutionProvider"]
-                pass
-
-        providers += [
-            "CPUExecutionProvider",  # CPU
-        ]
-
-        LOG.info("ONNXRuntime providers: %s", providers)
+        for n in ["CUDAExecutionProvider", "CPUExecutionProvider"]:
+            if n in available_providers:
+                providers.append(n)
 
         LOG.info(
            "Using device '%s'. The speed of inference depends greatly on the device.",
@@ -230,6 +211,10 @@ def providers(self):
             if ort.get_device() == "CPU":
                 raise RuntimeError("GPU is not available")
 
+            providers = ["CUDAExecutionProvider"]
+
+        LOG.info("ONNXRuntime providers: %s", providers)
+
         return providers
 
     def timer(self, title):
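The refactor replaces GPUtil-based GPU probing with ONNX Runtime's own report of available execution providers. As a rough illustration of how such a providers list is typically consumed, here is a minimal sketch using standard onnxruntime calls (get_available_providers, get_device, InferenceSession); the model path "model.onnx" is a hypothetical placeholder, not a file from this repository:

import onnxruntime as ort

# Ask the runtime which execution providers this build actually supports,
# instead of probing the hardware with GPUtil.
available = ort.get_available_providers()

# Keep only the providers of interest, in priority order: CUDA first, CPU fallback.
providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available]

# get_device() returns "GPU" or "CPU", so a GPU-only run can fail fast.
if ort.get_device() == "CPU":
    print("No GPU detected; inference will run on the CPU and be much slower.")

# The session tries the providers in the order given ("model.onnx" is hypothetical).
session = ort.InferenceSession("model.onnx", providers=providers)

Passing providers explicitly also matches newer onnxruntime behaviour, where the GPU build expects callers to state the provider list rather than relying on an implicit default.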
