Skip to content

Commit 1c3d63b

Browse files
authored
feat: gunicorn (#591)
1 parent 61e8640 commit 1c3d63b

File tree

11 files changed

+155
-70
lines changed

11 files changed

+155
-70
lines changed

docker/Dockerfile.api

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ FROM base as final
2222
ENV PATH="/venv/bin:${PATH}"
2323
ENV VIRTUAL_ENV="/venv"
2424
COPY --from=builder /venv /venv
25-
ENTRYPOINT ["keep", "--json", "api"]
25+
ENTRYPOINT ["gunicorn", "keep.api.api:get_app", "--bind" , "0.0.0.0:8080" , "--workers", "4" , "-k" , "uvicorn.workers.UvicornWorker" ]

keep-ui/app/settings/smtp-settings.tsx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { Card, Button, Title, Subtitle, TextInput } from "@tremor/react";
33
import useSWR from "swr";
44
import { getApiURL } from "utils/apiUrl";
55
import { fetcher } from "utils/fetcher";
6+
import Loading from "app/loading";
67

78
interface SMTPSettings {
89
host: string;
@@ -91,8 +92,9 @@ export default function SMTPSettingsForm({ accessToken, selectedTab }: Props) {
9192
);
9293

9394
// Show loading state or error messages if needed
94-
if (isLoading) return <div>Loading...</div>; // Loading state
95-
if (error) return <div>Error: {error.message}</div>;
95+
if (smtpSettings === undefined || isLoading) {
96+
return <Loading />;
97+
}
9698

9799
// if no errors and we have data, set the settings
98100
if (smtpSettings) {

keep/api/api.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from dotenv import find_dotenv, load_dotenv
1010
from fastapi import FastAPI, HTTPException, Request, Response
1111
from fastapi.responses import JSONResponse
12+
from opentelemetry import trace
1213
from starlette.middleware.base import BaseHTTPMiddleware
1314
from starlette.middleware.cors import CORSMiddleware
1415
from starlette_context import plugins
@@ -61,6 +62,7 @@ class EventCaptureMiddleware(BaseHTTPMiddleware):
6162
def __init__(self, app: FastAPI):
6263
super().__init__(app)
6364
self.posthog_client = get_posthog_client()
65+
self.tracer = trace.get_tracer(__name__)
6466

6567
def _extract_identity(self, request: Request) -> str:
6668
try:
@@ -70,42 +72,45 @@ def _extract_identity(self, request: Request) -> str:
7072
except Exception:
7173
return "anonymous"
7274

73-
def capture_request(self, request: Request) -> None:
75+
async def capture_request(self, request: Request) -> None:
7476
identity = self._extract_identity(request)
75-
self.posthog_client.capture(
76-
identity,
77-
"request-started",
78-
{"path": request.url.path, "method": request.method},
79-
)
77+
with self.tracer.start_as_current_span("capture_request"):
78+
self.posthog_client.capture(
79+
identity,
80+
"request-started",
81+
{"path": request.url.path, "method": request.method},
82+
)
8083

81-
def capture_response(self, request: Request, response: Response) -> None:
84+
async def capture_response(self, request: Request, response: Response) -> None:
8285
identity = self._extract_identity(request)
83-
self.posthog_client.capture(
84-
identity,
85-
"request-finished",
86-
{
87-
"path": request.url.path,
88-
"method": request.method,
89-
"status_code": response.status_code,
90-
},
91-
)
86+
with self.tracer.start_as_current_span("capture_response"):
87+
self.posthog_client.capture(
88+
identity,
89+
"request-finished",
90+
{
91+
"path": request.url.path,
92+
"method": request.method,
93+
"status_code": response.status_code,
94+
},
95+
)
9296

93-
def flush(self):
94-
logger.info("Flushing Posthog events")
95-
self.posthog_client.flush()
96-
logger.info("Posthog events flushed")
97+
async def flush(self):
98+
with self.tracer.start_as_current_span("flush_posthog_events"):
99+
logger.info("Flushing Posthog events")
100+
self.posthog_client.flush()
101+
logger.info("Posthog events flushed")
97102

98103
async def dispatch(self, request: Request, call_next):
99104
# Skip OPTIONS requests
100105
if request.method == "OPTIONS":
101106
return await call_next(request)
102107
# Capture event before request
103-
self.capture_request(request)
108+
await self.capture_request(request)
104109

105110
response = await call_next(request)
106111

107112
# Capture event after request
108-
self.capture_response(request, response)
113+
await self.capture_response(request, response)
109114

110115
# Perform async tasks or flush events after the request is handled
111116
self.flush()
@@ -219,6 +224,13 @@ async def on_startup():
219224
] = verify_token_or_key_single_tenant
220225
try_create_single_tenant(SINGLE_TENANT_UUID)
221226

227+
# load all providers into cache
228+
from keep.providers.providers_factory import ProvidersFactory
229+
230+
logger.info("Loading providers into cache")
231+
ProvidersFactory.get_all_providers()
232+
logger.info("Providers loaded successfully")
233+
222234
@app.exception_handler(Exception)
223235
async def catch_exception(request: Request, exc: Exception):
224236
logging.error(
@@ -233,6 +245,13 @@ async def catch_exception(request: Request, exc: Exception):
233245
},
234246
)
235247

248+
@app.middleware("http")
249+
async def log_middeware(request: Request, call_next):
250+
logger.info(f"Request started: {request.method} {request.url.path}")
251+
response = await call_next(request)
252+
logger.info(f"Request finished: {request.method} {request.url.path}")
253+
return response
254+
236255
keep.api.observability.setup(app)
237256

238257
if os.environ.get("USE_NGROK", "false") == "true":

keep/api/core/db.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,12 @@ def get_session() -> Session:
119119
Yields:
120120
Session: A database session
121121
"""
122-
with Session(engine) as session:
123-
yield session
122+
from opentelemetry import trace
123+
124+
tracer = trace.get_tracer(__name__)
125+
with tracer.start_as_current_span("get_session"):
126+
with Session(engine) as session:
127+
yield session
124128

125129

126130
def try_create_single_tenant(tenant_id: str) -> None:

keep/api/core/dependencies.py

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,22 @@
3333

3434

3535
def get_user_email(request: Request) -> str | None:
36-
token = request.headers.get("Authorization")
37-
if token:
38-
token = token.split(" ")[1]
39-
decoded_token = jwt.decode(token, options={"verify_signature": False})
40-
return decoded_token.get("email")
41-
elif "x-api-key" in request.headers:
42-
username = get_user_by_api_key(request.headers["x-api-key"])
43-
return username
44-
else:
45-
raise HTTPException(
46-
status_code=401, detail="Invalid authentication credentials"
47-
)
36+
from opentelemetry import trace
37+
38+
tracer = trace.get_tracer(__name__)
39+
with tracer.start_as_current_span("get_user_email"):
40+
token = request.headers.get("Authorization")
41+
if token:
42+
token = token.split(" ")[1]
43+
decoded_token = jwt.decode(token, options={"verify_signature": False})
44+
return decoded_token.get("email")
45+
elif "x-api-key" in request.headers:
46+
username = get_user_by_api_key(request.headers["x-api-key"])
47+
return username
48+
else:
49+
raise HTTPException(
50+
status_code=401, detail="Invalid authentication credentials"
51+
)
4852

4953

5054
def __extract_api_key(
@@ -133,34 +137,42 @@ def verify_api_key(
133137
return tenant_api_key.tenant_id
134138

135139

140+
# init once so the cache will actually work
141+
auth_domain = os.environ.get("AUTH0_DOMAIN")
142+
if auth_domain:
143+
jwks_uri = f"https://{auth_domain}/.well-known/jwks.json"
144+
jwks_client = jwt.PyJWKClient(jwks_uri, cache_keys=True)
145+
146+
136147
def verify_bearer_token(token: str = Depends(oauth2_scheme)) -> str:
137148
# Took the implementation from here:
138149
# https://github.com/auth0-developer-hub/api_fastapi_python_hello-world/blob/main/application/json_web_token.py
139-
if not token:
140-
raise HTTPException(status_code=401, detail="No token provided 👈")
141-
try:
142-
auth_domain = os.environ.get("AUTH0_DOMAIN")
143-
auth_audience = os.environ.get("AUTH0_AUDIENCE")
144-
jwks_uri = f"https://{auth_domain}/.well-known/jwks.json"
145-
issuer = f"https://{auth_domain}/"
146-
jwks_client = jwt.PyJWKClient(jwks_uri)
147-
jwt_signing_key = jwks_client.get_signing_key_from_jwt(token).key
148-
payload = jwt.decode(
149-
token,
150-
jwt_signing_key,
151-
algorithms="RS256",
152-
audience=auth_audience,
153-
issuer=issuer,
154-
leeway=60,
155-
)
156-
tenant_id = payload.get("keep_tenant_id")
157-
return tenant_id
158-
except jwt.exceptions.DecodeError:
159-
logger.exception("Failed to decode token")
160-
raise HTTPException(status_code=401, detail="Token is not a valid JWT")
161-
except Exception as e:
162-
logger.exception("Failed to validate token")
163-
raise HTTPException(status_code=401, detail=str(e))
150+
from opentelemetry import trace
151+
152+
tracer = trace.get_tracer(__name__)
153+
with tracer.start_as_current_span("verify_bearer_token"):
154+
if not token:
155+
raise HTTPException(status_code=401, detail="No token provided 👈")
156+
try:
157+
auth_audience = os.environ.get("AUTH0_AUDIENCE")
158+
issuer = f"https://{auth_domain}/"
159+
jwt_signing_key = jwks_client.get_signing_key_from_jwt(token).key
160+
payload = jwt.decode(
161+
token,
162+
jwt_signing_key,
163+
algorithms="RS256",
164+
audience=auth_audience,
165+
issuer=issuer,
166+
leeway=60,
167+
)
168+
tenant_id = payload.get("keep_tenant_id")
169+
return tenant_id
170+
except jwt.exceptions.DecodeError:
171+
logger.exception("Failed to decode token")
172+
raise HTTPException(status_code=401, detail="Token is not a valid JWT")
173+
except Exception as e:
174+
logger.exception("Failed to validate token")
175+
raise HTTPException(status_code=401, detail=str(e))
164176

165177

166178
def get_user_email_single_tenant(request: Request) -> str:

keep/api/logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def dump(self):
6767
"disable_existing_loggers": False,
6868
"formatters": {
6969
"json": {
70-
"format": "%(asctime)s %(message)s %(levelname)s %(name)s %(filename)s %(otelTraceID)s %(otelSpanID)s %(otelServiceName)s",
70+
"format": "%(asctime)s %(message)s %(levelname)s %(name)s %(filename)s %(otelTraceID)s %(otelSpanID)s %(otelServiceName)s %(threadName)s %(process)s",
7171
"class": "pythonjsonlogger.jsonlogger.JsonFormatter",
7272
}
7373
},

keep/api/routes/settings.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io
22
import json
3+
import logging
34
import os
45
import secrets
56
import smtplib
@@ -27,6 +28,8 @@
2728

2829
router = APIRouter()
2930

31+
logger = logging.getLogger(__name__)
32+
3033

3134
@router.get(
3235
"/webhook",
@@ -36,6 +39,7 @@ def webhook_settings(
3639
tenant_id: str = Depends(verify_bearer_token),
3740
session: Session = Depends(get_session),
3841
) -> WebhookSettings:
42+
logger.info("Getting webhook settings")
3943
api_url = config("KEEP_API_URL")
4044
keep_webhook_api_url = f"{api_url}/alerts/event"
4145
webhook_api_key = get_or_create_api_key(
@@ -45,6 +49,7 @@ def webhook_settings(
4549
unique_api_key_id="webhook",
4650
system_description="Webhooks API key",
4751
)
52+
logger.info("Webhook settings retrieved successfully")
4853
return WebhookSettings(
4954
webhookApi=keep_webhook_api_url,
5055
apiKey=webhook_api_key,
@@ -180,12 +185,14 @@ async def get_smtp_settings(
180185
tenant_id: str = Depends(verify_bearer_token),
181186
session: Session = Depends(get_session),
182187
):
188+
logger.info("Getting SMTP settings")
183189
context_manager = ContextManager(tenant_id=tenant_id)
184190
secret_manager = SecretManagerFactory.get_secret_manager(context_manager)
185191
# Read the SMTP settings from the secret manager
186192
try:
187193
smtp_settings = secret_manager.read_secret(secret_name="smtp")
188194
smtp_settings = json.loads(smtp_settings)
195+
logger.info("SMTP settings retrieved successfully")
189196
return JSONResponse(status_code=200, content=smtp_settings)
190197
except Exception:
191198
# everything ok but no smtp settings
@@ -278,6 +285,7 @@ def get_api_key(
278285
session: Session = Depends(get_session),
279286
user_name: str = Depends(get_user_email),
280287
):
288+
logger.info("Getting API key")
281289
# get the api key for the CLI
282290
api_key = get_or_create_api_key(
283291
session=session,
@@ -286,4 +294,5 @@ def get_api_key(
286294
unique_api_key_id="cli",
287295
system_description="API key",
288296
)
297+
logger.info("API key retrieved successfully")
289298
return {"apiKey": api_key}

keep/contextmanager/contextmanager.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class ContextManager:
1515
STATE_FILE = "keepstate.json"
1616

1717
def __init__(
18-
self, tenant_id, workflow_id=None, workflow_execution_id=None, load_state=True
18+
self, tenant_id, workflow_id=None, workflow_execution_id=None, load_state=False
1919
):
2020
self.logger = logging.getLogger(__name__)
2121
self.logger_adapter = WorkflowLoggerAdapter(
@@ -37,7 +37,7 @@ def __init__(
3737
except RuntimeError:
3838
self.click_context = {}
3939
self.aliases = {}
40-
self.state = {}
40+
self._state = {}
4141
# dependencies are used so iohandler will be able to use the output class of the providers
4242
# e.g. let's say bigquery_provider results are google.cloud.bigquery.Row
4343
# and we want to use it in iohandler, we need to import it before the eval
@@ -190,7 +190,7 @@ def set_step_context(self, step_id, results, foreach=False):
190190

191191
def __load_state(self):
192192
try:
193-
self.state = json.loads(
193+
self._state = json.loads(
194194
self.storage_manager.get_file(
195195
self.tenant_id, self.state_file, create_if_not_exist=True
196196
)
@@ -201,7 +201,13 @@ def __load_state(self):
201201
f"State storage: {self.storage_manager.__class__.__name__}"
202202
)
203203
self.logger.warning(f"Reason: {exc}")
204-
self.state = {}
204+
self._state = {}
205+
206+
@property
207+
def state(self):
208+
if not self._state:
209+
self.__load_state()
210+
return self._state
205211

206212
def get_last_workflow_run(self, workflow_id):
207213
if workflow_id in self.state:

keep/providers/providers_factory.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ class ProviderConfigurationException(Exception):
2424

2525

2626
class ProvidersFactory:
27+
_loaded_providers_cache = None
28+
2729
@staticmethod
2830
def get_provider_class(provider_type: str) -> BaseProvider:
2931
provider_type_split = provider_type.split(
@@ -157,6 +159,13 @@ def get_all_providers() -> list[Provider]:
157159
Returns:
158160
list: All the providers.
159161
"""
162+
logger = logging.getLogger(__name__)
163+
# use the cache if exists
164+
if ProvidersFactory._loaded_providers_cache:
165+
logger.info("Using cached providers")
166+
return ProvidersFactory._loaded_providers_cache
167+
168+
logger.info("Loading providers")
160169
providers = []
161170
blacklisted_providers = [
162171
"base_provider",
@@ -278,6 +287,8 @@ def get_all_providers() -> list[Provider]:
278287
except ModuleNotFoundError:
279288
logger.exception(f"Cannot import provider {provider_directory}")
280289
continue
290+
291+
ProvidersFactory._loaded_providers_cache = providers
281292
return providers
282293

283294
@staticmethod

0 commit comments

Comments
 (0)