Commit 574e407

feat: let LLM choose whether to retrieve context (#62)
1 parent b02c5a0 commit 574e407

12 files changed: +451 / -155 lines

README.md

Lines changed: 33 additions & 9 deletions
````diff
@@ -159,9 +159,35 @@ insert_document(Path("Special Relativity.pdf"), config=my_config)
 
 ### 3. Searching and Retrieval-Augmented Generation (RAG)
 
-#### 3.1 Simple RAG pipeline
+#### 3.1 Dynamically routed RAG
 
-Now you can run a simple but powerful RAG pipeline that consists of retrieving the most relevant chunk spans (each of which is a list of consecutive chunks) with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response:
+Now you can run a dynamically routed RAG pipeline that consists of adding the user prompt to the message history and streaming the LLM response. Depending on the user prompt, the LLM may choose to retrieve context using RAGLite by invoking a retrieval tool. If retrieval is necessary, the LLM determines the search query and RAGLite applies hybrid search with reranking to retrieve the most relevant chunk spans (each of which is a list of consecutive chunks). The retrieval results are sent to the `on_retrieval` callback and are also appended to the message history as a tool output. Finally, the LLM response given the RAG context is streamed and the message history is updated with the assistant response:
+
+```python
+from raglite import rag
+
+# Create a user message:
+messages = []  # Or start with an existing message history.
+messages.append({
+    "role": "user",
+    "content": "How is intelligence measured?"
+})
+
+# Let the LLM decide whether to search the database by providing a retrieval tool to the LLM.
+# If requested, RAGLite then uses hybrid search and reranking to append RAG context to the message history.
+# Finally, assistant response is streamed and appended to the message history.
+chunk_spans = []
+stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=my_config)
+for update in stream:
+    print(update, end="")
+
+# Access the documents referenced in the RAG context:
+documents = [chunk_span.document for chunk_span in chunk_spans]
+```
+
+#### 3.2 Programmable RAG
+
+If you need manual control over the RAG pipeline, you can run a basic but powerful pipeline that consists of retrieving the most relevant chunk spans with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response:
 
 ```python
 from raglite import create_rag_instruction, rag, retrieve_rag_context
@@ -174,21 +200,19 @@ chunk_spans = retrieve_rag_context(query=user_prompt, num_chunks=5, config=my_co
 messages = []  # Or start with an existing message history.
 messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
 
-# Stream the RAG response:
+# Stream the RAG response and append it to the message history:
 stream = rag(messages, config=my_config)
 for update in stream:
     print(update, end="")
 
-# Access the documents cited in the RAG response:
+# Access the documents referenced in the RAG context:
 documents = [chunk_span.document for chunk_span in chunk_spans]
 ```
 
-#### 3.2 Advanced RAG pipeline
-
 > [!TIP]
 > 🥇 Reranking can significantly improve the output quality of a RAG application. To add reranking to your application: first search for a larger set of 20 relevant chunks, then rerank them with a [rerankers](https://github.com/AnswerDotAI/rerankers) reranker, and finally keep the top 5 chunks.
 
-In addition to the simple RAG pipeline, RAGLite also offers more advanced control over the individual steps of the pipeline. A full pipeline consists of several steps:
+RAGLite also offers more advanced control over the individual steps of a full RAG pipeline:
 
 1. Searching for relevant chunks with keyword, vector, or hybrid search
 2. Retrieving the chunks from the database
@@ -229,14 +253,14 @@ from raglite import create_rag_instruction
 messages = []  # Or start with an existing message history.
 messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
 
-# Stream the RAG response:
+# Stream the RAG response and append it to the message history:
 from raglite import rag
 
 stream = rag(messages, config=my_config)
 for update in stream:
     print(update, end="")
 
-# Access the documents cited in the RAG response:
+# Access the documents referenced in the RAG context:
 documents = [chunk_span.document for chunk_span in chunk_spans]
 ```
 
````
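
The tip's search-then-rerank flow and the numbered steps above can be sketched end to end with functions that appear elsewhere in this commit (for example in the pre-change `_chainlit.py`). The keyword arguments below are inferred from that code and are assumptions rather than the definitive RAGLite API:

```python
# Rough sketch of the programmable pipeline, step by step. Signatures are assumed
# from the pre-change _chainlit.py shown further down in this commit.
from raglite import (
    RAGLiteConfig,
    create_rag_instruction,
    hybrid_search,
    rag,
    rerank_chunks,
    retrieve_chunk_spans,
    retrieve_chunks,
)

my_config = RAGLiteConfig()
user_prompt = "How is intelligence measured?"

# 1. Search for a larger set of relevant chunks with hybrid search (20, per the tip above).
chunk_ids, _ = hybrid_search(query=user_prompt, num_results=20, config=my_config)
# 2. Retrieve the chunks from the database.
chunks = retrieve_chunks(chunk_ids=chunk_ids, config=my_config)
# 3. Rerank the chunks, keep the top 5, and group them into chunk spans.
chunks = rerank_chunks(query=user_prompt, chunk_ids=chunks, config=my_config)
chunk_spans = retrieve_chunk_spans(chunks[:5], config=my_config)
# 4. Convert the user prompt to a RAG instruction and stream the RAG response.
messages = [create_rag_instruction(user_prompt=user_prompt, context=chunk_spans)]
for update in rag(messages, config=my_config):
    print(update, end="")
```
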
poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -32,7 +32,7 @@ scipy = ">=1.5.0"
 spacy = ">=3.7.0,<3.8.0"
 # Large Language Models:
 huggingface-hub = ">=0.22.0"
-litellm = ">=1.47.1"
+litellm = ">=1.48.4"
 llama-cpp-python = ">=0.3.2"
 pydantic = ">=2.7.0"
 # Approximate Nearest Neighbors:
```

src/raglite/_chainlit.py

Lines changed: 20 additions & 35 deletions
```diff
@@ -6,22 +6,11 @@
 import chainlit as cl
 from chainlit.input_widget import Switch, TextInput
 
-from raglite import (
-    RAGLiteConfig,
-    async_rag,
-    create_rag_instruction,
-    hybrid_search,
-    insert_document,
-    rerank_chunks,
-    retrieve_chunk_spans,
-    retrieve_chunks,
-)
+from raglite import RAGLiteConfig, async_rag, hybrid_search, insert_document, rerank_chunks
 from raglite._markdown import document_to_markdown
 
 async_insert_document = cl.make_async(insert_document)
 async_hybrid_search = cl.make_async(hybrid_search)
-async_retrieve_chunks = cl.make_async(retrieve_chunks)
-async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans)
 async_rerank_chunks = cl.make_async(rerank_chunks)
 
 
@@ -93,31 +82,27 @@ async def handle_message(user_message: cl.Message) -> None:
             for i, attachment in enumerate(inline_attachments)
         )
         + f"\n\n{user_message.content}"
-    )
-    # Search for relevant contexts for RAG.
-    async with cl.Step(name="search", type="retrieval") as step:
-        step.input = user_message.content
-        chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
-        chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
-        step.output = chunks
-        step.elements = [  # Show the top chunks inline.
-            cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5]
-        ]
-        await step.update()  # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
-    # Rerank the chunks and group them into chunk spans.
-    async with cl.Step(name="rerank", type="rerank") as step:
-        step.input = chunks
-        chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
-        chunk_spans = await async_retrieve_chunk_spans(chunks[:5], config=config)
-        step.output = chunk_spans
-        step.elements = [  # Show the top chunk spans inline.
-            cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans
-        ]
-        await step.update()  # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
+    ).strip()
     # Stream the LLM response.
     assistant_message = cl.Message(content="")
+    chunk_spans = []
     messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1]  # type: ignore[no-untyped-call]
-    messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
-    async for token in async_rag(messages, config=config):
+    messages.append({"role": "user", "content": user_prompt})
+    async for token in async_rag(
+        messages, on_retrieval=lambda x: chunk_spans.extend(x), config=config
+    ):
         await assistant_message.stream_token(token)
+    # Append RAG sources, if any.
+    if chunk_spans:
+        rag_sources: dict[str, list[str]] = {}
+        for chunk_span in chunk_spans:
+            rag_sources.setdefault(chunk_span.document.id, [])
+            rag_sources[chunk_span.document.id].append(str(chunk_span))
+        assistant_message.content += "\n\nSources: " + ", ".join(  # Rendered as hyperlinks.
+            f"[{i + 1}]" for i in range(len(rag_sources))
+        )
+        assistant_message.elements = [  # Markdown content is rendered in sidebar.
+            cl.Text(name=f"[{i + 1}]", content="\n\n---\n\n".join(content), display="side")  # type: ignore[misc]
+            for i, (_, content) in enumerate(rag_sources.items())
+        ]
     await assistant_message.update()  # type: ignore[no-untyped-call]
```
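
For reference, a minimal standalone sketch of the callback-based flow the handler now uses, outside Chainlit. It relies only on the `async_rag` and `RAGLiteConfig` exports from the new import line; the prompt is a placeholder and the source grouping mirrors the snippet above:

```python
import asyncio

from raglite import RAGLiteConfig, async_rag


async def answer(user_prompt: str) -> None:
    config = RAGLiteConfig()  # Assumes a default config; point it at your own database/LLM.
    chunk_spans = []  # Filled by the on_retrieval callback if the LLM decides to search.
    messages = [{"role": "user", "content": user_prompt}]
    async for token in async_rag(
        messages, on_retrieval=lambda x: chunk_spans.extend(x), config=config
    ):
        print(token, end="")
    # Group the retrieved chunk spans by source document, as the Chainlit handler does.
    sources: dict[str, list[str]] = {}
    for chunk_span in chunk_spans:
        sources.setdefault(chunk_span.document.id, []).append(str(chunk_span))
    print(f"\n\nRetrieved context from {len(sources)} source document(s).")


asyncio.run(answer("How is intelligence measured?"))
```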

src/raglite/_database.py

Lines changed: 27 additions & 4 deletions
```diff
@@ -169,18 +169,41 @@ def to_xml(self, index: int | None = None) -> str:
         if not self.chunks:
             return ""
         index_attribute = f' index="{index}"' if index is not None else ""
-        xml = "\n".join(
+        xml_document = "\n".join(
             [
                 f'<document{index_attribute} id="{self.document.id}">',
                 f"<source>{self.document.url if self.document.url else self.document.filename}</source>",
-                f'<span from_chunk_id="{self.chunks[0].id}" to_chunk_id="{self.chunks[0].id}">',
-                f"<heading>\n{escape(self.chunks[0].headings.strip())}\n</heading>",
+                f'<span from_chunk_id="{self.chunks[0].id}" to_chunk_id="{self.chunks[-1].id}">',
+                f"<headings>\n{escape(self.chunks[0].headings.strip())}\n</headings>",
                 f"<content>\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n</content>",
                 "</span>",
                 "</document>",
             ]
         )
-        return xml
+        return xml_document
+
+    def to_json(self, index: int | None = None) -> str:
+        """Convert this chunk span to a JSON representation.
+
+        The JSON representation follows Anthropic's best practices [1].
+
+        [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
+        """
+        if not self.chunks:
+            return "{}"
+        index_attribute = {"index": index} if index is not None else {}
+        json_document = {
+            **index_attribute,
+            "id": self.document.id,
+            "source": self.document.url if self.document.url else self.document.filename,
+            "span": {
+                "from_chunk_id": self.chunks[0].id,
+                "to_chunk_id": self.chunks[-1].id,
+                "headings": self.chunks[0].headings.strip(),
+                "content": "".join(chunk.body for chunk in self.chunks).strip(),
+            },
+        }
+        return json.dumps(json_document)
 
     @property
     def content(self) -> str:
```
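
To make the new serialization concrete, here is a hypothetical example of the dictionary that `to_json` builds before passing it to `json.dumps`. All ids and text are invented; only the key structure follows the code above:

```python
# Hypothetical output structure of ChunkSpan.to_json(index=1); values are made up.
example_json_document = {
    "index": 1,  # Only present when an index is passed in.
    "id": "doc-123",
    "source": "Special Relativity.pdf",
    "span": {
        "from_chunk_id": "chunk-007",  # First chunk in the span.
        "to_chunk_id": "chunk-009",  # Last chunk in the span (the old XML repeated the first chunk id here).
        "headings": "# Special Relativity\n## Introduction",
        "content": "The laws of physics are invariant in all inertial frames of reference...",
    },
}
```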

src/raglite/_extract.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -34,13 +34,12 @@ class MyNameResponse(BaseModel):
     # Load the default config if not provided.
     config = config or RAGLiteConfig()
     # Check if the LLM supports the response format.
-    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
     llm_supports_response_format = "response_format" in (
-        get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or []
+        get_supported_openai_params(model=config.llm) or []
     )
     # Update the system prompt with the JSON schema of the return type to help the LLM.
     system_prompt = getattr(return_type, "system_prompt", "").strip()
-    if not llm_supports_response_format or llm_provider == "llama-cpp-python":
+    if not llm_supports_response_format or config.llm.startswith("llama-cpp-python"):
         system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}"
     # Constrain the reponse format to the JSON schema if it's supported by the LLM [1]. Strict mode
     # is disabled by default because it only supports a subset of JSON schema features [2].
```
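
For context, the simplified capability check now boils down to a single litellm call; a minimal sketch, with a placeholder model name:

```python
from litellm import get_supported_openai_params

# get_supported_openai_params may return None for unknown models, hence the `or []`.
llm = "gpt-4o"  # Placeholder model identifier.
llm_supports_response_format = "response_format" in (
    get_supported_openai_params(model=llm) or []
)
print(llm_supports_response_format)
```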
