Skip to content

Commit d58417c

Browse files
authored
Paraphrase prompt fix & cross encoder (#414)
* added cross-encoder * Skip rerank when 1 or fewer results * updated diagram * add requirement * removed use of SERVICE_IDENTITY * added check for n_top:
1 parent e2ac83b commit d58417c

File tree

7 files changed

+97
-23
lines changed

7 files changed

+97
-23
lines changed

core_backend/app/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from fastapi.middleware.cors import CORSMiddleware
88
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
99
from redis import asyncio as aioredis
10+
from sentence_transformers import CrossEncoder
1011

1112
from . import (
1213
admin,
@@ -20,7 +21,7 @@
2021
urgency_rules,
2122
user_tools,
2223
)
23-
from .config import DOMAIN, LANGFUSE, REDIS_HOST
24+
from .config import CROSS_ENCODER_MODEL, DOMAIN, LANGFUSE, REDIS_HOST, USE_CROSS_ENCODER
2425
from .prometheus_middleware import PrometheusMiddleware
2526
from .utils import setup_logger
2627

@@ -92,7 +93,13 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
9293

9394
logger.info("Application started")
9495
app.state.redis = await aioredis.from_url(REDIS_HOST)
96+
if USE_CROSS_ENCODER == "True":
97+
app.state.crossencoder = CrossEncoder(
98+
CROSS_ENCODER_MODEL,
99+
)
100+
95101
yield
102+
96103
await app.state.redis.close()
97104
logger.info("Application finished")
98105

core_backend/app/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@
6363
SERVICE_IDENTITY = os.environ.get(
6464
"SERVICE_IDENTITY", "air pollution and air quality chatbot"
6565
)
66+
# Cross-encoder
67+
USE_CROSS_ENCODER = os.environ.get("USE_CROSS_ENCODER", "True")
68+
CROSS_ENCODER_MODEL = os.environ.get(
69+
"CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2"
70+
)
71+
6672
# Rate limit variables
6773
CHECK_CONTENT_LIMIT = os.environ.get("CHECK_CONTENT_LIMIT", True)
6874
DEFAULT_CONTENT_QUOTA = int(os.environ.get("DEFAULT_CONTENT_QUOTA", 50))

core_backend/app/llm_call/llm_prompts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,11 @@ def get_prompt(cls) -> str:
151151
},
152152
]
153153
PARAPHRASE_PROMPT = f"""You are a high-performing paraphrasing bot. \
154-
The user has sent a message.
154+
The user has sent a message for a question-answering service.
155155
156156
If the message is a question, do not answer it, \
157-
just paraphrase it to remove unecessary information and focus on the question. \
158-
Remove any irrelevant or offensive words.
157+
just paraphrase it to focus on the question and include any relevant information.\
158+
Remove any irrelevant or offensive words
159159
160160
If the input message is not a question, respond with the same message but \
161161
remove any irrelevant or offensive words.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
22

33
# Functionality variables
4+
N_TOP_CONTENT_TO_CROSSENCODER = os.environ.get("N_TOP_CONTENT_TO_CROSSENCODER", "10")
45
N_TOP_CONTENT = os.environ.get("N_TOP_CONTENT", "4")

core_backend/app/question_answer/routers.py

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
from typing import Tuple
88

99
from fastapi import APIRouter, Depends, status
10+
from fastapi.requests import Request
1011
from fastapi.responses import JSONResponse
1112
from sqlalchemy.exc import IntegrityError
1213
from sqlalchemy.ext.asyncio import AsyncSession
1314

1415
from ..auth.dependencies import authenticate_key, rate_limiter
15-
from ..config import CUSTOM_SPEECH_ENDPOINT, GCS_SPEECH_BUCKET
16+
from ..config import CUSTOM_SPEECH_ENDPOINT, GCS_SPEECH_BUCKET, USE_CROSS_ENCODER
1617
from ..contents.models import (
1718
get_similar_content_async,
1819
increment_query_count,
@@ -30,6 +31,7 @@
3031
generate_llm_query_response,
3132
generate_tts__after,
3233
)
34+
from ..schemas import QuerySearchResult
3335
from ..users.models import UserDB
3436
from ..utils import (
3537
create_langfuse_metadata,
@@ -39,7 +41,7 @@
3941
setup_logger,
4042
upload_file_to_gcs,
4143
)
42-
from .config import N_TOP_CONTENT
44+
from .config import N_TOP_CONTENT, N_TOP_CONTENT_TO_CROSSENCODER
4345
from .models import (
4446
QueryDB,
4547
check_secret_key_match,
@@ -88,6 +90,7 @@
8890
)
8991
async def search(
9092
user_query: QueryBase,
93+
request: Request,
9194
asession: AsyncSession = Depends(get_async_session),
9295
user_db: UserDB = Depends(authenticate_key),
9396
) -> QueryResponse | JSONResponse:
@@ -114,8 +117,10 @@ async def search(
114117
response=response_template,
115118
user_id=user_db.user_id,
116119
n_similar=int(N_TOP_CONTENT),
120+
n_to_crossencoder=int(N_TOP_CONTENT_TO_CROSSENCODER),
117121
asession=asession,
118122
exclude_archived=True,
123+
request=request,
119124
)
120125

121126
if user_query.generate_llm_response:
@@ -138,17 +143,18 @@ async def search(
138143
asession=asession,
139144
)
140145

141-
if type(response) is QueryResponse:
146+
if isinstance(response, QueryResponse):
142147
return response
143-
elif type(response) is QueryResponseError:
148+
149+
if isinstance(response, QueryResponseError):
144150
return JSONResponse(
145151
status_code=status.HTTP_400_BAD_REQUEST, content=response.model_dump()
146152
)
147-
else:
148-
return JSONResponse(
149-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
150-
content={"message": "Internal server error"},
151-
)
153+
154+
return JSONResponse(
155+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
156+
content={"message": "Internal server error"},
157+
)
152158

153159

154160
@router.post(
@@ -167,6 +173,7 @@ async def search(
167173
)
168174
async def voice_search(
169175
file_url: str,
176+
request: Request,
170177
asession: AsyncSession = Depends(get_async_session),
171178
user_db: UserDB = Depends(authenticate_key),
172179
) -> QueryAudioResponse | JSONResponse:
@@ -222,8 +229,10 @@ async def voice_search(
222229
response=response_template,
223230
user_id=user_db.user_id,
224231
n_similar=int(N_TOP_CONTENT),
232+
n_to_crossencoder=int(N_TOP_CONTENT_TO_CROSSENCODER),
225233
asession=asession,
226234
exclude_archived=True,
235+
request=request,
227236
)
228237

229238
if user_query.generate_llm_response:
@@ -250,17 +259,18 @@ async def voice_search(
250259
os.remove(file_path)
251260
file_stream.close()
252261

253-
if type(response) is QueryAudioResponse:
262+
if isinstance(response, QueryAudioResponse):
254263
return response
255-
elif type(response) is QueryResponseError:
264+
265+
if isinstance(response, QueryResponseError):
256266
return JSONResponse(
257267
status_code=status.HTTP_400_BAD_REQUEST, content=response.model_dump()
258268
)
259-
else:
260-
return JSONResponse(
261-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
262-
content={"error": "Internal server error"},
263-
)
269+
270+
return JSONResponse(
271+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
272+
content={"error": "Internal server error"},
273+
)
264274

265275
except ValueError as ve:
266276
logger.error(f"ValueError: {str(ve)}")
@@ -328,7 +338,9 @@ async def get_search_response(
328338
response: QueryResponse,
329339
user_id: int,
330340
n_similar: int,
341+
n_to_crossencoder: int,
331342
asession: AsyncSession,
343+
request: Request,
332344
exclude_archived: bool = True,
333345
) -> QueryResponse | QueryResponseError:
334346
"""Get similar content and construct the LLM answer for the user query.
@@ -347,8 +359,12 @@ async def get_search_response(
347359
The ID of the user making the query.
348360
n_similar
349361
The number of similar contents to retrieve.
362+
n_to_crossencoder
363+
The number of similar contents to send to the cross-encoder.
350364
asession
351365
`AsyncSession` object for database transactions.
366+
request
367+
The FastAPI request object.
352368
exclude_archived
353369
Specifies whether to exclude archived content.
354370
@@ -362,19 +378,56 @@ async def get_search_response(
362378
# always do the embeddings search even if some guardrails have failed
363379
metadata = create_langfuse_metadata(query_id=response.query_id, user_id=user_id)
364380

381+
if USE_CROSS_ENCODER == "True" and (n_to_crossencoder < n_similar):
382+
raise ValueError(
383+
"`n_to_crossencoder` must be less than or equal to `n_similar`."
384+
)
385+
365386
search_results = await get_similar_content_async(
366387
user_id=user_id,
367388
question=query_refined.query_text, # use latest transformed version of the text
368-
n_similar=n_similar,
389+
n_similar=n_to_crossencoder if USE_CROSS_ENCODER == "True" else n_similar,
369390
asession=asession,
370391
metadata=metadata,
371392
exclude_archived=exclude_archived,
372393
)
394+
395+
if USE_CROSS_ENCODER and (len(search_results) > 1):
396+
search_results = rerank_search_results(
397+
n_similar=n_similar,
398+
search_results=search_results,
399+
query_text=query_refined.query_text,
400+
request=request,
401+
)
402+
373403
response.search_results = search_results
374404

375405
return response
376406

377407

408+
def rerank_search_results(
    search_results: dict[int, QuerySearchResult],
    n_similar: int,
    query_text: str,
    request: Request,
) -> dict[int, QuerySearchResult]:
    """
    Rerank search results with the cross-encoder and keep the best `n_similar`.

    Parameters
    ----------
    search_results
        Candidate contents to rerank. Only the values are used; the incoming
        keys are discarded.
    n_similar
        The maximum number of reranked contents to return.
    query_text
        The query text each candidate is scored against.
    request
        The FastAPI request object. The cross-encoder is read from
        `request.app.state.crossencoder`.

    Returns
    -------
    dict[int, QuerySearchResult]
        The highest-scoring contents, keyed by their new rank starting at 0.
    """
    # Snapshot the values once so scoring and ranking see the same sequence.
    contents = list(search_results.values())
    if not contents:
        # Nothing to score; avoid calling the model with an empty batch.
        return {}

    encoder = request.app.state.crossencoder
    scores = encoder.predict(
        [(query_text, content.title + "\n" + content.text) for content in contents]
    )

    # Sort on the score only. Sorting the raw (score, content) tuples falls back
    # to comparing the content objects whenever two scores are equal, which
    # raises TypeError for objects that do not define an ordering. The sort is
    # stable, so tied candidates keep their incoming relative order.
    ranked = sorted(
        zip(scores, contents), key=lambda scored: scored[0], reverse=True
    )

    return {
        rank: content for rank, (_, content) in enumerate(ranked[:n_similar])
    }
429+
430+
378431
@generate_tts__after
379432
@check_align_score__after
380433
async def get_generation_response(
@@ -418,6 +471,8 @@ async def get_user_query_and_response(
418471
The user query database object.
419472
asession
420473
`AsyncSession` object for database transactions.
474+
generate_tts
475+
Specifies whether to generate a TTS audio response
421476
422477
Returns
423478
-------

core_backend/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ google-cloud-storage==2.18.2
2424
google-cloud-texttospeech==2.16.5
2525
google-cloud-speech==2.27.0
2626
pydub==0.25.1
27+
sentence-transformers==3.0.1

docs/components/qa-service/search.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ sequenceDiagram
1919
LLM->>AAQ: <Translated text>
2020
AAQ->>LLM: Paraphrase question
2121
LLM->>AAQ: <Paraphrased question>
22-
AAQ->>Vector DB: Request N most similar contents in DB
23-
Vector DB->>AAQ: <N contents with similarity score>
22+
AAQ->>Vector DB: Request M most similar contents in DB
23+
Vector DB->>AAQ: <M contents with similarity score>
24+
AAQ->>Cross-encoder: Re-rank to get top N contents
25+
Cross-encoder->>AAQ: <N contents with similarity score>
2426
AAQ->>User: Return JSON of N contents
2527
2628
```
@@ -37,6 +39,8 @@ sequenceDiagram
3739
LLM->>AAQ: <Safety Classification>
3840
AAQ->>Vector DB: Request M most similar contents in DB
3941
Vector DB->>AAQ: <M contents with similarity score>
42+
AAQ->>Cross-encoder: Re-rank to get top N contents
43+
Cross-encoder->>AAQ: <N contents with similarity score>
4044
AAQ->>LLM: Given contents, construct response in user's language to question
4145
LLM->>AAQ: <LLM response>
4246
AAQ->>LLM: Check if LLM response is consistent with contents

0 commit comments

Comments
 (0)