Skip to content

Commit 48f7c2e

Browse files
authored
Merge pull request #2 from atrifat/feat-basic-cache
feat: basic cache
2 parents 694231f + 105ccba commit 48f7c2e

File tree

3 files changed

+32
-4
lines changed

3 files changed

+32
-4
lines changed

.env.example

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,10 @@ APP_ENV=production
44
LISTEN_HOST=0.0.0.0
55
LISTEN_PORT=7860
66
# (Required, Default: unbiased-small) Set this with detoxify model. Check list of model in detoxify github repository.
7-
DETOXIFY_MODEL=unbiased-small
7+
DETOXIFY_MODEL=unbiased-small
8+
9+
# (Optional. Default: false. Options: true or false) Check whether the cache will be enabled or not
10+
ENABLE_CACHE=false
11+
12+
# (Optional. Default: 60. Options: integer value) Duration of cache in seconds. Parameter will be used when the cache is enabled
13+
CACHE_DURATION_SECONDS=60

app.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import datetime
77
from detoxify import Detoxify
88
import logging
9+
from flask_caching import Cache
910

1011
load_dotenv()
1112

@@ -15,6 +16,9 @@
1516
LISTEN_HOST = os.getenv("LISTEN_HOST", "0.0.0.0")
1617
LISTEN_PORT = os.getenv("LISTEN_PORT", "7860")
1718
DETOXIFY_MODEL = os.getenv("DETOXIFY_MODEL", "unbiased-small")
19+
CACHE_DURATION_SECONDS = int(os.getenv("CACHE_DURATION_SECONDS", 60))
20+
ENABLE_CACHE = os.getenv("ENABLE_CACHE", "false") == "true"
21+
1822
APP_VERSION = "0.1.0"
1923

2024
# Setup logging configuration
@@ -40,6 +44,14 @@
4044

4145
app = Flask(__name__)
4246

47+
cache_config = {
48+
"DEBUG": True if APP_ENV != "production" else False,
49+
"CACHE_TYPE": "SimpleCache" if ENABLE_CACHE else "NullCache",
50+
"CACHE_DEFAULT_TIMEOUT": CACHE_DURATION_SECONDS, # Cache duration in seconds
51+
}
52+
cache = Cache(config=cache_config)
53+
cache.init_app(app)
54+
4355

4456
def is_valid_api_key(api_key):
4557
if api_key == API_TOKEN:
@@ -67,6 +79,16 @@ def decorator(*args, **kwargs):
6779
return decorator
6880

6981

82+
def make_key_fn():
83+
"""A function which is called to derive the key for a computed value.
84+
The key in this case is the concat value of all the json request
85+
parameters. Other strategy could to use any hashing function.
86+
:returns: unique string for which the value should be cached.
87+
"""
88+
user_data = request.get_json()
89+
return ",".join([f"{key}={value}" for key, value in user_data.items()])
90+
91+
7092
def perform_hate_speech_analysis(query):
7193
result = {}
7294
df = pd.DataFrame(model.predict(query), index=[0])
@@ -86,6 +108,7 @@ def handle_exception(error):
86108

87109
@app.route("/predict", methods=["POST"])
88110
@api_required
111+
@cache.cached(make_cache_key=make_key_fn)
89112
def predict():
90113
data = request.json
91114
q = data["q"]
@@ -99,9 +122,7 @@ def predict():
99122

100123
@app.route("/", methods=["GET"])
101124
def index():
102-
response = {
103-
"message": "Use /predict route to get prediction result"
104-
}
125+
response = {"message": "Use /predict route to get prediction result"}
105126
return jsonify(response)
106127

107128

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
detoxify==0.5.1
22
Flask==3.0.0
3+
Flask-Caching==2.3.0
34
gunicorn==21.2.0
45
pandas==2.1.1
56
python-dotenv==1.0.0

0 commit comments

Comments
 (0)