Skip to content

Commit 6a37b5b

Browse files
committed
feat: basic cache
1 parent 694231f commit 6a37b5b

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

app.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import datetime
77
from detoxify import Detoxify
88
import logging
9+
from flask_caching import Cache
910

1011
load_dotenv()
1112

@@ -15,6 +16,9 @@
1516
LISTEN_HOST = os.getenv("LISTEN_HOST", "0.0.0.0")
1617
LISTEN_PORT = os.getenv("LISTEN_PORT", "7860")
1718
DETOXIFY_MODEL = os.getenv("DETOXIFY_MODEL", "unbiased-small")
19+
CACHE_DURATION_SECONDS = int(os.getenv("CACHE_DURATION_SECONDS", 60))
20+
ENABLE_CACHE = os.getenv("ENABLE_CACHE", "false") == "true"
21+
1822
APP_VERSION = "0.1.0"
1923

2024
# Setup logging configuration
@@ -40,6 +44,14 @@
4044

4145
app = Flask(__name__)
4246

47+
cache_config = {
48+
"DEBUG": True if APP_ENV != "production" else False,
49+
"CACHE_TYPE": "SimpleCache" if ENABLE_CACHE else "NullCache",
50+
"CACHE_DEFAULT_TIMEOUT": CACHE_DURATION_SECONDS, # Cache duration in seconds
51+
}
52+
cache = Cache(config=cache_config)
53+
cache.init_app(app)
54+
4355

4456
def is_valid_api_key(api_key):
4557
if api_key == API_TOKEN:
@@ -67,6 +79,16 @@ def decorator(*args, **kwargs):
6779
return decorator
6880

6981

82+
def make_key_fn():
83+
"""A function which is called to derive the key for a computed value.
84+
The key in this case is the concat value of all the json request
85+
parameters. Other strategy could to use any hashing function.
86+
:returns: unique string for which the value should be cached.
87+
"""
88+
user_data = request.get_json()
89+
return ",".join([f"{key}={value}" for key, value in user_data.items()])
90+
91+
7092
def perform_hate_speech_analysis(query):
7193
result = {}
7294
df = pd.DataFrame(model.predict(query), index=[0])
@@ -86,6 +108,7 @@ def handle_exception(error):
86108

87109
@app.route("/predict", methods=["POST"])
88110
@api_required
111+
@cache.cached(make_cache_key=make_key_fn)
89112
def predict():
90113
data = request.json
91114
q = data["q"]
@@ -99,9 +122,7 @@ def predict():
99122

100123
@app.route("/", methods=["GET"])
101124
def index():
102-
response = {
103-
"message": "Use /predict route to get prediction result"
104-
}
125+
response = {"message": "Use /predict route to get prediction result"}
105126
return jsonify(response)
106127

107128

0 commit comments

Comments
 (0)