27
27
└──────────────────────────────────────────────────────────────────────────────┘
28
28
"""
29
29
30
+ import uuid
30
31
from fastapi import (
31
32
APIRouter ,
32
33
Depends ,
33
34
HTTPException ,
34
35
status ,
35
36
WebSocket ,
36
37
WebSocketDisconnect ,
38
+ Header ,
37
39
)
38
40
from sqlalchemy .orm import Session
39
41
from src .config .database import get_db
57
59
from datetime import datetime
58
60
import logging
59
61
import json
62
+ from typing import Optional , Dict
60
63
61
64
logger = logging .getLogger (__name__ )
62
65
67
70
)
68
71
69
72
73
+ async def get_agent_by_api_key (
74
+ agent_id : str ,
75
+ api_key : Optional [str ] = Header (None , alias = "x-api-key" ),
76
+ authorization : Optional [str ] = Header (None ),
77
+ db : Session = Depends (get_db ),
78
+ ):
79
+ """Flexible authentication for chat routes, allowing JWT or API key"""
80
+ if authorization :
81
+ # Try to authenticate with JWT token first
82
+ try :
83
+ # Extract token from Authorization header if needed
84
+ token = (
85
+ authorization .replace ("Bearer " , "" )
86
+ if authorization .startswith ("Bearer " )
87
+ else authorization
88
+ )
89
+ payload = await get_jwt_token (token )
90
+ agent = agent_service .get_agent (db , agent_id )
91
+ if not agent :
92
+ raise HTTPException (
93
+ status_code = status .HTTP_404_NOT_FOUND ,
94
+ detail = "Agent not found" ,
95
+ )
96
+
97
+ # Verify if the user has access to the agent's client
98
+ await verify_user_client (payload , db , agent .client_id )
99
+ return agent
100
+ except Exception as e :
101
+ logger .warning (f"JWT authentication failed: { str (e )} " )
102
+ # If JWT fails, continue to try with API key
103
+
104
+ # Try to authenticate with API key
105
+ if not api_key :
106
+ raise HTTPException (
107
+ status_code = status .HTTP_401_UNAUTHORIZED ,
108
+ detail = "Authentication required (JWT or API key)" ,
109
+ )
110
+
111
+ agent = agent_service .get_agent (db , agent_id )
112
+ if not agent or not agent .config :
113
+ raise HTTPException (
114
+ status_code = status .HTTP_404_NOT_FOUND , detail = "Agent not found"
115
+ )
116
+
117
+ # Verify if the API key matches
118
+ if not agent .config .get ("api_key" ) or agent .config .get ("api_key" ) != api_key :
119
+ raise HTTPException (
120
+ status_code = status .HTTP_401_UNAUTHORIZED , detail = "Invalid API key"
121
+ )
122
+
123
+ return agent
124
+
125
+
70
126
@router .websocket ("/ws/{agent_id}/{external_id}" )
71
127
async def websocket_chat (
72
128
websocket : WebSocket ,
@@ -82,32 +138,49 @@ async def websocket_chat(
82
138
# Wait for authentication message
83
139
try :
84
140
auth_data = await websocket .receive_json ()
85
- logger .info (f"Received authentication data: { auth_data } " )
141
+ logger .info (f"Authentication data received : { auth_data } " )
86
142
87
143
if not (
88
- auth_data .get ("type" ) == "authorization" and auth_data .get ("token" )
144
+ auth_data .get ("type" ) == "authorization"
145
+ and (auth_data .get ("token" ) or auth_data .get ("api_key" ))
89
146
):
90
147
logger .warning ("Invalid authentication message" )
91
148
await websocket .close (code = status .WS_1008_POLICY_VIOLATION )
92
149
return
93
150
94
- token = auth_data ["token" ]
95
- # Verify the token
96
- payload = await get_jwt_token_ws (token )
97
- if not payload :
98
- logger .warning ("Invalid token" )
99
- await websocket .close (code = status .WS_1008_POLICY_VIOLATION )
100
- return
101
-
102
- # Verify if the agent belongs to the user's client
151
+ # Verify if the agent exists
103
152
agent = agent_service .get_agent (db , agent_id )
104
153
if not agent :
105
154
logger .warning (f"Agent { agent_id } not found" )
106
155
await websocket .close (code = status .WS_1008_POLICY_VIOLATION )
107
156
return
108
157
109
- # Verify if the user has access to the agent (via client)
110
- await verify_user_client (payload , db , agent .client_id )
158
+ # Verify authentication
159
+ is_authenticated = False
160
+
161
+ # Try with JWT token
162
+ if auth_data .get ("token" ):
163
+ try :
164
+ payload = await get_jwt_token_ws (auth_data ["token" ])
165
+ if payload :
166
+ # Verify if the user has access to the agent
167
+ await verify_user_client (payload , db , agent .client_id )
168
+ is_authenticated = True
169
+ except Exception as e :
170
+ logger .warning (f"JWT authentication failed: { str (e )} " )
171
+
172
+ # If JWT fails, try with API key
173
+ if not is_authenticated and auth_data .get ("api_key" ):
174
+ if agent .config and agent .config .get ("api_key" ) == auth_data .get (
175
+ "api_key"
176
+ ):
177
+ is_authenticated = True
178
+ else :
179
+ logger .warning ("Invalid API key" )
180
+
181
+ if not is_authenticated :
182
+ await websocket .close (code = status .WS_1008_POLICY_VIOLATION )
183
+ return
111
184
112
185
logger .info (
113
186
f"WebSocket connection established for agent { agent_id } and external_id { external_id } "
@@ -174,19 +247,9 @@ async def websocket_chat(
174
247
)
175
248
async def chat (
176
249
request : ChatRequest ,
250
+ _ = Depends (get_agent_by_api_key ),
177
251
db : Session = Depends (get_db ),
178
- payload : dict = Depends (get_jwt_token ),
179
252
):
180
- # Verify if the agent belongs to the user's client
181
- agent = agent_service .get_agent (db , request .agent_id )
182
- if not agent :
183
- raise HTTPException (
184
- status_code = status .HTTP_404_NOT_FOUND , detail = "Agent not found"
185
- )
186
-
187
- # Verify if the user has access to the agent (via client)
188
- await verify_user_client (payload , db , agent .client_id )
189
-
190
253
try :
191
254
final_response = await run_agent (
192
255
request .agent_id ,
0 commit comments