7
7
from typing import Tuple
8
8
9
9
from fastapi import APIRouter , Depends , status
10
+ from fastapi .requests import Request
10
11
from fastapi .responses import JSONResponse
11
12
from sqlalchemy .exc import IntegrityError
12
13
from sqlalchemy .ext .asyncio import AsyncSession
13
14
14
15
from ..auth .dependencies import authenticate_key , rate_limiter
15
- from ..config import CUSTOM_SPEECH_ENDPOINT , GCS_SPEECH_BUCKET
16
+ from ..config import CUSTOM_SPEECH_ENDPOINT , GCS_SPEECH_BUCKET , USE_CROSS_ENCODER
16
17
from ..contents .models import (
17
18
get_similar_content_async ,
18
19
increment_query_count ,
30
31
generate_llm_query_response ,
31
32
generate_tts__after ,
32
33
)
34
+ from ..schemas import QuerySearchResult
33
35
from ..users .models import UserDB
34
36
from ..utils import (
35
37
create_langfuse_metadata ,
39
41
setup_logger ,
40
42
upload_file_to_gcs ,
41
43
)
42
- from .config import N_TOP_CONTENT
44
+ from .config import N_TOP_CONTENT , N_TOP_CONTENT_TO_CROSSENCODER
43
45
from .models import (
44
46
QueryDB ,
45
47
check_secret_key_match ,
88
90
)
89
91
async def search (
90
92
user_query : QueryBase ,
93
+ request : Request ,
91
94
asession : AsyncSession = Depends (get_async_session ),
92
95
user_db : UserDB = Depends (authenticate_key ),
93
96
) -> QueryResponse | JSONResponse :
@@ -114,8 +117,10 @@ async def search(
114
117
response = response_template ,
115
118
user_id = user_db .user_id ,
116
119
n_similar = int (N_TOP_CONTENT ),
120
+ n_to_crossencoder = int (N_TOP_CONTENT_TO_CROSSENCODER ),
117
121
asession = asession ,
118
122
exclude_archived = True ,
123
+ request = request ,
119
124
)
120
125
121
126
if user_query .generate_llm_response :
@@ -138,17 +143,18 @@ async def search(
138
143
asession = asession ,
139
144
)
140
145
141
- if type (response ) is QueryResponse :
146
+ if isinstance (response , QueryResponse ) :
142
147
return response
143
- elif type (response ) is QueryResponseError :
148
+
149
+ if isinstance (response , QueryResponseError ):
144
150
return JSONResponse (
145
151
status_code = status .HTTP_400_BAD_REQUEST , content = response .model_dump ()
146
152
)
147
- else :
148
- return JSONResponse (
149
- status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
150
- content = {"message" : "Internal server error" },
151
- )
153
+
154
+ return JSONResponse (
155
+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
156
+ content = {"message" : "Internal server error" },
157
+ )
152
158
153
159
154
160
@router .post (
@@ -167,6 +173,7 @@ async def search(
167
173
)
168
174
async def voice_search (
169
175
file_url : str ,
176
+ request : Request ,
170
177
asession : AsyncSession = Depends (get_async_session ),
171
178
user_db : UserDB = Depends (authenticate_key ),
172
179
) -> QueryAudioResponse | JSONResponse :
@@ -222,8 +229,10 @@ async def voice_search(
222
229
response = response_template ,
223
230
user_id = user_db .user_id ,
224
231
n_similar = int (N_TOP_CONTENT ),
232
+ n_to_crossencoder = int (N_TOP_CONTENT_TO_CROSSENCODER ),
225
233
asession = asession ,
226
234
exclude_archived = True ,
235
+ request = request ,
227
236
)
228
237
229
238
if user_query .generate_llm_response :
@@ -250,17 +259,18 @@ async def voice_search(
250
259
os .remove (file_path )
251
260
file_stream .close ()
252
261
253
- if type (response ) is QueryAudioResponse :
262
+ if isinstance (response , QueryAudioResponse ) :
254
263
return response
255
- elif type (response ) is QueryResponseError :
264
+
265
+ if isinstance (response , QueryResponseError ):
256
266
return JSONResponse (
257
267
status_code = status .HTTP_400_BAD_REQUEST , content = response .model_dump ()
258
268
)
259
- else :
260
- return JSONResponse (
261
- status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
262
- content = {"error" : "Internal server error" },
263
- )
269
+
270
+ return JSONResponse (
271
+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
272
+ content = {"error" : "Internal server error" },
273
+ )
264
274
265
275
except ValueError as ve :
266
276
logger .error (f"ValueError: { str (ve )} " )
@@ -328,7 +338,9 @@ async def get_search_response(
328
338
response : QueryResponse ,
329
339
user_id : int ,
330
340
n_similar : int ,
341
+ n_to_crossencoder : int ,
331
342
asession : AsyncSession ,
343
+ request : Request ,
332
344
exclude_archived : bool = True ,
333
345
) -> QueryResponse | QueryResponseError :
334
346
"""Get similar content and construct the LLM answer for the user query.
@@ -347,8 +359,12 @@ async def get_search_response(
347
359
The ID of the user making the query.
348
360
n_similar
349
361
The number of similar contents to retrieve.
362
+ n_to_crossencoder
363
+ The number of similar contents to send to the cross-encoder.
350
364
asession
351
365
`AsyncSession` object for database transactions.
366
+ request
367
+ The FastAPI request object.
352
368
exclude_archived
353
369
Specifies whether to exclude archived content.
354
370
@@ -362,19 +378,56 @@ async def get_search_response(
362
378
# always do the embeddings search even if some guardrails have failed
363
379
metadata = create_langfuse_metadata (query_id = response .query_id , user_id = user_id )
364
380
381
+ if USE_CROSS_ENCODER == "True" and (n_to_crossencoder < n_similar ):
382
+ raise ValueError (
383
+ "`n_to_crossencoder` must be less than or equal to `n_similar`."
384
+ )
385
+
365
386
search_results = await get_similar_content_async (
366
387
user_id = user_id ,
367
388
question = query_refined .query_text , # use latest transformed version of the text
368
- n_similar = n_similar ,
389
+ n_similar = n_to_crossencoder if USE_CROSS_ENCODER == "True" else n_similar ,
369
390
asession = asession ,
370
391
metadata = metadata ,
371
392
exclude_archived = exclude_archived ,
372
393
)
394
+
395
+ if USE_CROSS_ENCODER and (len (search_results ) > 1 ):
396
+ search_results = rerank_search_results (
397
+ n_similar = n_similar ,
398
+ search_results = search_results ,
399
+ query_text = query_refined .query_text ,
400
+ request = request ,
401
+ )
402
+
373
403
response .search_results = search_results
374
404
375
405
return response
376
406
377
407
408
+ def rerank_search_results (
409
+ search_results : dict [int , QuerySearchResult ],
410
+ n_similar : int ,
411
+ query_text : str ,
412
+ request : Request ,
413
+ ) -> dict [int , QuerySearchResult ]:
414
+ """
415
+ Rerank search results based on the similarity of the content to the query text
416
+ """
417
+ encoder = request .app .state .crossencoder
418
+ contents = search_results .values ()
419
+ scores = encoder .predict (
420
+ [(query_text , content .title + "\n " + content .text ) for content in contents ]
421
+ )
422
+
423
+ sorted_by_score = [v for _ , v in sorted (zip (scores , contents ), reverse = True )][
424
+ :n_similar
425
+ ]
426
+ reranked_search_results = dict (enumerate (sorted_by_score ))
427
+
428
+ return reranked_search_results
429
+
430
+
378
431
@generate_tts__after
379
432
@check_align_score__after
380
433
async def get_generation_response (
@@ -418,6 +471,8 @@ async def get_user_query_and_response(
418
471
The user query database object.
419
472
asession
420
473
`AsyncSession` object for database transactions.
474
+ generate_tts
475
+ Specifies whether to generate a TTS audio response
421
476
422
477
Returns
423
478
-------
0 commit comments