Skip to content

Commit f67a7d7

Browse files
committed
feat: create initial code
1 parent 6b2ff1e commit f67a7d7

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed

app.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import os
2+
from dotenv import load_dotenv
3+
import pandas as pd
4+
from flask import Flask, request, jsonify
5+
import functools
6+
import datetime
7+
from detoxify import Detoxify
8+
import logging
9+
10+
load_dotenv()
11+
12+
ENABLE_API_TOKEN = os.getenv("ENABLE_API_TOKEN", "false") == "true"
13+
API_TOKEN = os.getenv("API_TOKEN", "")
14+
APP_ENV = os.getenv("APP_ENV", "production")
15+
LISTEN_HOST = os.getenv("LISTEN_HOST", "0.0.0.0")
16+
LISTEN_PORT = os.getenv("LISTEN_PORT", "7860")
17+
DETOXIFY_MODEL = os.getenv("DETOXIFY_MODEL", "unbiased-small")
18+
APP_VERSION = "0.0.0"
19+
20+
# Setup logging configuration
21+
LOGGING_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
22+
LOGGING_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
23+
if APP_ENV == "production":
24+
logging.basicConfig(
25+
level=logging.INFO,
26+
datefmt=LOGGING_DATE_FORMAT,
27+
format=LOGGING_FORMAT,
28+
)
29+
else:
30+
logging.basicConfig(
31+
level=logging.DEBUG,
32+
datefmt=LOGGING_DATE_FORMAT,
33+
format=LOGGING_FORMAT,
34+
)
35+
36+
if ENABLE_API_TOKEN and API_TOKEN == "":
37+
raise Exception("API_TOKEN is required if ENABLE_API_TOKEN is enabled")
38+
39+
model = Detoxify(DETOXIFY_MODEL)
40+
41+
app = Flask(__name__)
42+
43+
44+
def is_valid_api_key(api_key):
45+
if api_key == API_TOKEN:
46+
return True
47+
else:
48+
return False
49+
50+
51+
def api_required(func):
52+
@functools.wraps(func)
53+
def decorator(*args, **kwargs):
54+
if ENABLE_API_TOKEN:
55+
if request.json:
56+
api_key = request.json.get("api_key")
57+
else:
58+
return {"message": "Please provide an API key"}, 400
59+
# Check if API key is correct and valid
60+
if request.method == "POST" and is_valid_api_key(api_key):
61+
return func(*args, **kwargs)
62+
else:
63+
return {"message": "The provided API key is not valid"}, 403
64+
else:
65+
return func(*args, **kwargs)
66+
67+
return decorator
68+
69+
70+
def perform_hate_speech_analysis(query):
71+
result = {}
72+
df = pd.DataFrame(model.predict(query), index=[0])
73+
columns = df.columns
74+
75+
for i, label in enumerate(columns):
76+
result[label] = df[label][0].round(3).astype("float")
77+
78+
return result
79+
80+
81+
@app.errorhandler(Exception)
82+
def handle_exception(error):
83+
res = {"error": str(error)}
84+
return jsonify(res)
85+
86+
87+
@app.route("/predict", methods=["POST"])
88+
@api_required
89+
def predict():
90+
data = request.json
91+
q = data["q"]
92+
start_time = datetime.datetime.now()
93+
result = perform_hate_speech_analysis(q)
94+
end_time = datetime.datetime.now()
95+
elapsed_time = end_time - start_time
96+
logging.debug("elapsed predict time: %s", str(elapsed_time))
97+
return jsonify(result)
98+
99+
100+
@app.route("/", methods=["GET"])
101+
def index():
102+
response = {
103+
"message": "Use /predict and /predict_sentiment route to get prediction result"
104+
}
105+
return jsonify(response)
106+
107+
108+
@app.route("/app_version", methods=["GET"])
109+
def app_version():
110+
response = {"message": "This app version is ".APP_VERSION}
111+
return jsonify(response)
112+
113+
114+
if __name__ == "__main__":
115+
app.run(host=LISTEN_HOST, port=LISTEN_PORT)

0 commit comments

Comments
 (0)