Skip to content

Commit 0b9f81f

Browse files
committed
feat: enable cuda device support
1 parent be210ea commit 0b9f81f

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

app.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from flask import Flask, request, jsonify
55
import functools
66
import datetime
7+
import torch
78
from detoxify import Detoxify
89
import logging
910
from flask_caching import Cache
@@ -18,6 +19,7 @@
1819
DETOXIFY_MODEL = os.getenv("DETOXIFY_MODEL", "unbiased-small")
1920
CACHE_DURATION_SECONDS = int(os.getenv("CACHE_DURATION_SECONDS", 60))
2021
ENABLE_CACHE = os.getenv("ENABLE_CACHE", "false") == "true"
22+
TORCH_DEVICE = os.getenv("TORCH_DEVICE", "auto")
2123

2224
APP_VERSION = "0.1.0"
2325

@@ -40,7 +42,12 @@
4042
if ENABLE_API_TOKEN and API_TOKEN == "":
4143
raise Exception("API_TOKEN is required if ENABLE_API_TOKEN is enabled")
4244

43-
model = Detoxify(DETOXIFY_MODEL)
45+
if TORCH_DEVICE == "auto":
46+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
47+
else:
48+
torch_device = TORCH_DEVICE
49+
50+
model = Detoxify(DETOXIFY_MODEL, device=torch_device)
4451

4552
app = Flask(__name__)
4653

0 commit comments

Comments
 (0)