Skip to content

Commit 4ff93d5

Browse files
authored
Merge pull request #4 from atrifat/feat-enable-cuda-support
feat: enable cuda support
2 parents be210ea + b8026ba commit 4ff93d5

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

.env.example

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,7 @@ DETOXIFY_MODEL=unbiased-small
1010
ENABLE_CACHE=false
1111

1212
# (Optional. Default: 60. Options: integer value) Duration of cache in seconds. Parameter will be used when the cache is enabled
13-
CACHE_DURATION_SECONDS=60
13+
CACHE_DURATION_SECONDS=60
14+
15+
# (Optional. Default: auto. Options: auto,cpu,cuda) Set torch default device for detoxify library. Automatically detect if cuda/gpu device is available
16+
TORCH_DEVICE=auto

app.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from flask import Flask, request, jsonify
55
import functools
66
import datetime
7+
import torch
78
from detoxify import Detoxify
89
import logging
910
from flask_caching import Cache
@@ -18,6 +19,7 @@
1819
DETOXIFY_MODEL = os.getenv("DETOXIFY_MODEL", "unbiased-small")
1920
CACHE_DURATION_SECONDS = int(os.getenv("CACHE_DURATION_SECONDS", 60))
2021
ENABLE_CACHE = os.getenv("ENABLE_CACHE", "false") == "true"
22+
TORCH_DEVICE = os.getenv("TORCH_DEVICE", "auto")
2123

2224
APP_VERSION = "0.1.0"
2325

@@ -40,7 +42,12 @@
4042
if ENABLE_API_TOKEN and API_TOKEN == "":
4143
raise Exception("API_TOKEN is required if ENABLE_API_TOKEN is enabled")
4244

43-
model = Detoxify(DETOXIFY_MODEL)
45+
if TORCH_DEVICE == "auto":
46+
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
47+
else:
48+
torch_device = TORCH_DEVICE
49+
50+
model = Detoxify(DETOXIFY_MODEL, device=torch_device)
4451

4552
app = Flask(__name__)
4653

0 commit comments

Comments
 (0)