Skip to content

Commit e189837

Browse files
authored
Add MultiprocessingConcurrencyLimiter to gateway (#399)
1 parent 7a956e3 commit e189837

File tree

6 files changed

+77
-83
lines changed

6 files changed

+77
-83
lines changed

charts/model-engine/templates/gateway_deployment.yaml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,6 @@ spec:
4949
port: 5000
5050
periodSeconds: 2
5151
failureThreshold: 30
52-
livenessProbe:
53-
httpGet:
54-
path: /healthz
55-
port: 5000
56-
initialDelaySeconds: 5
57-
periodSeconds: 2
58-
failureThreshold: 10
5952
command:
6053
- dumb-init
6154
- --

model-engine/model_engine_server/api/app.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pathlib import Path
66

77
import pytz
8-
from fastapi import FastAPI, Request, Response
8+
from fastapi import FastAPI, HTTPException, Request, Response
99
from fastapi.responses import JSONResponse
1010
from fastapi.staticfiles import StaticFiles
1111
from model_engine_server.api.batch_jobs_v1 import batch_job_router_v1
@@ -21,6 +21,7 @@
2121
from model_engine_server.api.model_endpoints_v1 import model_endpoint_router_v1
2222
from model_engine_server.api.tasks_v1 import inference_task_router_v1
2323
from model_engine_server.api.triggers_v1 import trigger_router_v1
24+
from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
2425
from model_engine_server.core.loggers import (
2526
LoggerTagKey,
2627
LoggerTagManager,
@@ -32,12 +33,34 @@
3233

3334
logger = make_logger(logger_name())
3435

36+
# Allows us to make the Uvicorn worker concurrency in model_engine_server/api/worker.py very high
37+
MAX_CONCURRENCY = 500
38+
39+
concurrency_limiter = MultiprocessingConcurrencyLimiter(
40+
concurrency=MAX_CONCURRENCY, fail_on_concurrency_limit=True
41+
)
42+
43+
healthcheck_routes = ["/healthcheck", "/healthz", "/readyz"]
44+
3545

3646
class CustomMiddleware(BaseHTTPMiddleware):
3747
async def dispatch(self, request: Request, call_next):
3848
try:
3949
LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4()))
40-
return await call_next(request)
50+
# we intentionally exclude healthcheck routes from the concurrency limiter
51+
if request.url.path in healthcheck_routes:
52+
return await call_next(request)
53+
with concurrency_limiter:
54+
return await call_next(request)
55+
except HTTPException as e:
56+
timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
57+
return JSONResponse(
58+
status_code=e.status_code,
59+
content={
60+
"error": e.detail,
61+
"timestamp": timestamp,
62+
},
63+
)
4164
except Exception as e:
4265
tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
4366
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
@@ -49,14 +72,12 @@ async def dispatch(self, request: Request, call_next):
4972
}
5073
logger.error("Unhandled exception: %s", structured_log)
5174
return JSONResponse(
52-
{
53-
"status_code": 500,
54-
"content": {
55-
"error": "Internal error occurred. Our team has been notified.",
56-
"timestamp": timestamp,
57-
"request_id": request_id,
58-
},
59-
}
75+
status_code=500,
76+
content={
77+
"error": "Internal error occurred. Our team has been notified.",
78+
"timestamp": timestamp,
79+
"request_id": request_id,
80+
},
6081
)
6182

6283

@@ -91,9 +112,10 @@ def load_redis():
91112
get_or_create_aioredis_pool()
92113

93114

94-
@app.get("/healthcheck")
95-
@app.get("/healthz")
96-
@app.get("/readyz")
97115
def healthcheck() -> Response:
98116
"""Returns 200 if the app is healthy."""
99117
return Response(status_code=200)
118+
119+
120+
for endpoint in healthcheck_routes:
121+
app.get(endpoint)(healthcheck)

model-engine/model_engine_server/api/worker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from uvicorn.workers import UvicornWorker
22

3-
# Gunicorn returns 503 instead of 429 when concurrency exceeds the limit, before adding rate limiting just increase the concurrency
3+
# Gunicorn returns 503 instead of 429 when concurrency exceeds the limit
44
# We'll autoscale at target concurrency of a much lower number (around 50), and this just makes sure we don't 503 with bursty traffic
5-
CONCURRENCY_LIMIT = 1000
5+
# We set this very high since model_engine_server/api/app.py sets a lower per-pod concurrency at which we start returning 429s
6+
CONCURRENCY_LIMIT = 10000
67

78

89
class LaunchWorker(UvicornWorker):
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from multiprocessing import BoundedSemaphore
2+
from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType
3+
from typing import Optional
4+
5+
from fastapi import HTTPException
6+
from model_engine_server.core.loggers import logger_name, make_logger
7+
8+
logger = make_logger(logger_name())
9+
10+
11+
class MultiprocessingConcurrencyLimiter:
    """Context manager bounding concurrent requests via a multiprocessing semaphore.

    Args:
        concurrency: Maximum number of simultaneous holders of the semaphore.
            ``None`` disables limiting entirely. Must be >= 1 when given.
        fail_on_concurrency_limit: If True, ``__enter__`` raises an HTTP 429
            instead of blocking once the limit is reached; if False, callers
            block (queue up) until a slot is free.

    Raises:
        ValueError: If ``concurrency`` is provided and less than 1.
    """

    def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool):
        self.concurrency = concurrency
        if concurrency is not None:
            if concurrency < 1:
                raise ValueError("Concurrency should be at least 1")
            self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency)
            self.blocking = (
                not fail_on_concurrency_limit
            )  # we want to block if we want to queue up requests
        else:
            self.semaphore = None
            self.blocking = False  # Unused

    def __enter__(self):
        """Acquire a slot, raising HTTP 429 when non-blocking and at capacity."""
        logger.debug("Entering concurrency limiter semaphore")
        if self.semaphore and not self.semaphore.acquire(block=self.blocking):
            # Lazy %-args avoid formatting the message unless the record is emitted.
            logger.warning("Too many requests (max %s), returning 429", self.concurrency)
            raise HTTPException(status_code=429, detail="Too many requests")
        # Just raises an HTTPException.
        # __exit__ should not run; otherwise the release() doesn't have an acquire()

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Params renamed from (type, value, traceback) so they no longer shadow
        # the `type` builtin and the `traceback` module name.
        logger.debug("Exiting concurrency limiter semaphore")
        if self.semaphore:
            self.semaphore.release()

model-engine/model_engine_server/inference/forwarding/http_forwarder.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33
import os
44
import subprocess
55
from functools import lru_cache
6-
from multiprocessing import BoundedSemaphore
7-
from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType
8-
from typing import Optional
96

10-
from fastapi import Depends, FastAPI, HTTPException
7+
from fastapi import Depends, FastAPI
8+
from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
119
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
1210
from model_engine_server.core.loggers import logger_name, make_logger
1311
from model_engine_server.inference.forwarding.forwarding import (
@@ -21,33 +19,6 @@
2119
app = FastAPI()
2220

2321

24-
class MultiprocessingConcurrencyLimiter:
25-
def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool):
26-
if concurrency is not None:
27-
if concurrency < 1:
28-
raise ValueError("Concurrency should be at least 1")
29-
self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency)
30-
self.blocking = (
31-
not fail_on_concurrency_limit
32-
) # we want to block if we want to queue up requests
33-
else:
34-
self.semaphore = None
35-
self.blocking = False # Unused
36-
37-
def __enter__(self):
38-
logger.debug("Entering concurrency limiter semaphore")
39-
if self.semaphore and not self.semaphore.acquire(block=self.blocking):
40-
logger.warning("Too many requests, returning 429")
41-
raise HTTPException(status_code=429, detail="Too many requests")
42-
# Just raises an HTTPException.
43-
# __exit__ should not run; otherwise the release() doesn't have an acquire()
44-
45-
def __exit__(self, type, value, traceback):
46-
logger.debug("Exiting concurrency limiter semaphore")
47-
if self.semaphore:
48-
self.semaphore.release()
49-
50-
5122
@app.get("/healthz")
5223
@app.get("/readyz")
5324
def healthcheck():

model-engine/model_engine_server/inference/sync_inference/fastapi_server.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import traceback
22
from functools import wraps
3-
from multiprocessing import BoundedSemaphore
4-
from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType
5-
from typing import Optional
63

74
from fastapi import BackgroundTasks, FastAPI, HTTPException, Response, status
5+
from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
86
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
97
from model_engine_server.core.loggers import logger_name, make_logger
108
from model_engine_server.inference.common import (
@@ -25,33 +23,6 @@
2523
logger = make_logger(logger_name())
2624

2725

28-
class MultiprocessingConcurrencyLimiter:
29-
def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool):
30-
if concurrency is not None:
31-
if concurrency < 1:
32-
raise ValueError("Concurrency should be at least 1")
33-
self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency)
34-
self.blocking = (
35-
not fail_on_concurrency_limit
36-
) # we want to block if we want to queue up requests
37-
else:
38-
self.semaphore = None
39-
self.blocking = False # Unused
40-
41-
def __enter__(self):
42-
logger.debug("Entering concurrency limiter semaphore")
43-
if self.semaphore and not self.semaphore.acquire(block=self.blocking):
44-
logger.warning("Too many requests, returning 429")
45-
raise HTTPException(status_code=429, detail="Too many requests")
46-
# Just raises an HTTPException.
47-
# __exit__ should not run; otherwise the release() doesn't have an acquire()
48-
49-
def __exit__(self, type, value, traceback):
50-
logger.debug("Exiting concurrency limiter semaphore")
51-
if self.semaphore:
52-
self.semaphore.release()
53-
54-
5526
def with_concurrency_limit(concurrency_limiter: MultiprocessingConcurrencyLimiter):
5627
def _inner(flask_func):
5728
@wraps(flask_func)

0 commit comments

Comments (0)