diff --git a/src/raglite/_chatml_function_calling.py b/src/raglite/_chatml_function_calling.py
new file mode 100644
index 0000000..e028744
--- /dev/null
+++ b/src/raglite/_chatml_function_calling.py
@@ -0,0 +1,521 @@
+"""Upgrade of llama-cpp-python's chatml-function-calling chat handler.
+
+Changes:
+1. General:
+ a. ✨ If no system message is supplied, add an empty system message to hold the tool metadata.
+ b. ✨ Add function descriptions to the system message so that tool use is better informed (fixes https://github.com/abetlen/llama-cpp-python/issues/1869).
+ c. ✨ Replace `print` statements relating to JSON grammars with `RuntimeWarning` warnings.
+ d. ✅ Add tests with fairly broad coverage of the different scenarios.
+2. Case "Tool choice by user":
+ a. ✨ Add support for more than one function call by making this a special case of "Automatic tool choice" with a single tool (subsumes https://github.com/abetlen/llama-cpp-python/pull/1503).
+3. Case "Automatic tool choice -> respond with a message":
+ a. ✨ Use user-defined `stop` and `max_tokens`.
+ b. 🐛 Replace incorrect use of follow-up grammar with user-defined grammar.
+4. Case "Automatic tool choice -> one or more function calls":
+ a. ✨ Add support for streaming the function calls (fixes https://github.com/abetlen/llama-cpp-python/issues/1883).
+    b. ✨ Make tool calling more robust by giving the LLM an explicit way to terminate the tool calls by wrapping them in a `<function_calls>` block.
+    c. 🐛 Add the missing ":" stop token used to decide whether to continue with another tool call; its absence prevented parallel function calling (fixes https://github.com/abetlen/llama-cpp-python/issues/1756).
+ d. ✨ Set temperature=0 to determine whether to continue with another tool call, similar to the initial decision on whether to call a tool.
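+
+Usage sketch (illustrative; the model repo and tool schema below mirror the tests in
+`tests/test_chatml_function_calling.py` and are not required by this module):
+
+    from llama_cpp import Llama
+
+    from raglite._chatml_function_calling import chatml_function_calling_with_streaming
+
+    llm = Llama.from_pretrained(
+        repo_id="bartowski/Llama-3.2-3B-Instruct-GGUF",  # any ChatML-style instruct model
+        filename="*Q4_K_M.gguf",
+        n_ctx=4096,
+        verbose=False,
+        chat_handler=chatml_function_calling_with_streaming,
+    )
+    response = llm.create_chat_completion(
+        messages=[{"role": "user", "content": "What's the weather like in Paris today?"}],
+        tools=[
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get the weather for a location.",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {"location": {"type": "string", "description": "A city name."}},
+                    },
+                },
+            }
+        ],
+        tool_choice="auto",
+        stream=True,  # tool calls can now be streamed as chat completion chunks
+    )
+    for chunk in response:
+        delta = chunk["choices"][0]["delta"]  # contains either streamed content or tool call deltas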
+"""
+# This file uses old-style type hints and ignores certain ruff rules to minimise changes w.r.t. the original implementation:
+# ruff: noqa: C901, PLR0913, PLR0912, PLR0915, UP006, UP007, FBT001, FBT002, B006, TRY003, EM102, BLE001, PT018, W505
+
+import json
+import warnings
+from typing import ( # noqa: UP035
+ Any,
+ Iterator,
+ List,
+ Optional,
+ Union,
+ cast,
+)
+
+import jinja2
+from jinja2.sandbox import ImmutableSandboxedEnvironment
+from llama_cpp import llama, llama_grammar, llama_types
+from llama_cpp.llama_chat_format import (
+ _convert_completion_to_chat,
+ _convert_completion_to_chat_function,
+ _grammar_for_response_format,
+)
+
+
+def _accumulate_chunks(
+ chunks_iterator: Iterator[llama_types.CreateCompletionStreamResponse],
+ chunks_list: List[llama_types.CreateCompletionStreamResponse],
+) -> Iterator[llama_types.CreateCompletionStreamResponse]:
+ for chunk in chunks_iterator:
+ chunks_list.append(chunk)
+ yield chunk
+
+
+def _convert_chunks_to_completion(
+ chunks: List[llama_types.CreateCompletionStreamResponse],
+) -> llama_types.CreateCompletionResponse:
+ """Convert a list of completion chunks to a completion."""
+ # Accumulate completion response values
+ text: str = ""
+ finish_reason: Optional[str] = None
+ logprobs: Optional[llama_types.CompletionLogprobs] = None
+ prompt_tokens = 0
+ completion_tokens = 0
+ total_tokens = 0
+ completion_id: Optional[str] = None
+ completion_model: Optional[str] = None
+ completion_created: Optional[int] = None
+ for chunk in chunks:
+ # Extract the id, model, and created values from the first chunk
+ if completion_id is None:
+ completion_id = chunk["id"]
+ completion_model = chunk["model"]
+ completion_created = chunk["created"]
+ # Extract the usage if present in the chunk
+ usage = chunk.get("usage")
+ if usage:
+ prompt_tokens += usage.get("prompt_tokens", 0)
+ completion_tokens += usage.get("completion_tokens", 0)
+ total_tokens += usage.get("total_tokens", 0)
+ # Accumulate the chunk text
+ choice = chunk["choices"][0]
+ text += choice.get("text", "")
+ # Extract the finish_reason and logprobs if present in the chunk
+ if choice.get("finish_reason"):
+ finish_reason = choice["finish_reason"]
+ if choice.get("logprobs"):
+ logprobs = choice["logprobs"]
+ # Create the completion response
+ completion: llama_types.CreateCompletionResponse = {
+ "id": completion_id or "unknown_id",
+ "object": "text_completion",
+ "created": completion_created or 0,
+ "model": completion_model or "unknown_model",
+ "choices": [
+ {
+ "text": text,
+ "index": 0,
+ "logprobs": logprobs, # TODO: Improve accumulation of logprobs
+ "finish_reason": finish_reason, # type: ignore[typeddict-item]
+ }
+ ],
+ }
+ # Add usage section if present in the chunks
+ if (prompt_tokens + completion_tokens + total_tokens) > 0:
+ completion["usage"] = {
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": completion_tokens,
+ "total_tokens": total_tokens,
+ }
+ return completion
+
+
+def _stream_tool_calls(
+ llama: llama.Llama,
+ prompt: str,
+ tools: List[llama_types.ChatCompletionTool],
+ tool_name: str,
+ completion_kwargs: dict[str, Any],
+ follow_up_gbnf_tool_grammar: str,
+) -> Iterator[llama_types.CreateChatCompletionStreamResponse]:
+    # Generate the tool call completions one at a time
+ tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
+ completions: List[llama_types.CreateCompletionResponse] = []
+ completions_tool_name: List[str] = []
+ finish_reason_chat_chunk = None
+ while tool is not None and len(completions) <= 16: # noqa: PLR2004
+ # Generate the parameter values for the selected tool
+ prompt += f"functions.{tool_name}:\n"
+ try:
+ grammar = llama_grammar.LlamaGrammar.from_json_schema(
+ json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
+ )
+ except Exception as e:
+ warnings.warn(
+ f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}",
+ category=RuntimeWarning,
+ stacklevel=2,
+ )
+ grammar = llama_grammar.LlamaGrammar.from_string(
+ llama_grammar.JSON_GBNF, verbose=llama.verbose
+ )
+ completion_or_chunks = llama.create_completion(
+ prompt=prompt,
+ **{
+ **completion_kwargs,
+ "max_tokens": None,
+ "grammar": grammar,
+ },
+ )
+ chunks: List[llama_types.CreateCompletionResponse] = []
+ chat_chunks = _convert_completion_to_chat_function(
+ tool_name,
+ _accumulate_chunks(completion_or_chunks, chunks), # type: ignore[arg-type]
+ stream=True,
+ )
+ for chat_chunk in chat_chunks:
+            # Hold back the finish_reason chunk until all tool calls have been streamed
+ if chat_chunk["choices"] and chat_chunk["choices"][0].get("finish_reason"):
+ finish_reason_chat_chunk = chat_chunk
+ break
+ # Update this tool call's index
+ if chat_chunk["choices"] and chat_chunk["choices"][0]["delta"].get("tool_calls"):
+ chat_chunk["choices"][0]["delta"]["tool_calls"][0]["index"] = len(completions)
+ yield chat_chunk
+ completion = _convert_chunks_to_completion(chunks)
+ completions.append(completion)
+ completions_tool_name.append(tool_name)
+ prompt += completion["choices"][0]["text"]
+ prompt += "\n"
+ # Determine whether to call another tool or stop
+ response = cast(
+ llama_types.CreateCompletionResponse,
+ llama.create_completion(
+ prompt=prompt,
+ **{
+ **completion_kwargs,
+ "temperature": 0,
+ "stream": False,
+                    "stop": [*completion_kwargs["stop"], ":", "</function_calls>"],
+ "max_tokens": None,
+ "grammar": llama_grammar.LlamaGrammar.from_string(
+ follow_up_gbnf_tool_grammar, verbose=llama.verbose
+ ),
+ },
+ ),
+ )
+ tool_name = response["choices"][0]["text"][len("functions.") :]
+ tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
+ # Yield the finish_reason chunk
+ if finish_reason_chat_chunk is not None:
+ yield finish_reason_chat_chunk
+
+
+def chatml_function_calling_with_streaming(
+ llama: llama.Llama,
+ messages: List[llama_types.ChatCompletionRequestMessage],
+ functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
+ function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+ tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+ tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+ temperature: float = 0.2,
+ top_p: float = 0.95,
+ top_k: int = 40,
+ min_p: float = 0.05,
+ typical_p: float = 1.0,
+ stream: bool = False,
+ stop: Optional[Union[str, List[str]]] = [],
+ response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
+ max_tokens: Optional[int] = None,
+ presence_penalty: float = 0.0,
+ frequency_penalty: float = 0.0,
+ repeat_penalty: float = 1.1,
+ tfs_z: float = 1.0,
+ mirostat_mode: int = 0,
+ mirostat_tau: float = 5.0,
+ mirostat_eta: float = 0.1,
+ model: Optional[str] = None,
+ logits_processor: Optional[llama.LogitsProcessorList] = None,
+ grammar: Optional[llama.LlamaGrammar] = None, # type: ignore[name-defined]
+ logprobs: Optional[bool] = None,
+ top_logprobs: Optional[int] = None,
+ **kwargs: Any,
+) -> Union[
+ llama_types.CreateChatCompletionResponse,
+ Iterator[llama_types.CreateChatCompletionStreamResponse],
+]:
+ function_calling_template = (
+ "{% for message in messages %}"
+ "<|im_start|>{{ message.role }}\n"
+ # System message
+ "{% if message.role == 'system' %}"
+ "{{ message.content }}"
+ "{% if tool_calls %}"
+ "\n\nYou have access to the following functions:\n"
+ "{% for tool in tools %}"
+ '\n{% if tool.function.get("description") %}/* {{ tool.function.description | trim }} */{% endif %}'
+ "\nfunctions.{{ tool.function.name }}:\n"
+ "{{ tool.function.parameters | tojson }}"
+ "\n{% endfor %}"
+ "\nYou must respond to user messages with either a single message or with one or more function calls."
+ "\n\nTo respond with a message use the following format:"
+ "\n\nmessage:"
+        "\n<message>"
+ "\n\nTo respond with one or more function calls use the following format:"
+        "\n\n<function_calls>"
+        "\nfunctions.<function_name>:"
+        '\n{ "arg1": "value1", "arg2": "value2" }'
+        "\nfunctions.<function_name>:"
+        '\n{ "arg1": "value1", "arg2": "value2" }'
+        "\n</function_calls>"
+ "{% endif %}"
+ "<|im_end|>\n"
+ "{% endif %}"
+ # User message
+ "{% if message.role == 'user' %}"
+ "{{ message.content }}"
+ "<|im_end|>\n"
+ "{% endif %}"
+ # Assistant message
+ "{% if message.role == 'assistant' %}"
+ ## Regular message
+ "{% if message.content and message.content | length > 0 %}"
+ "{% if tool_calls %}"
+ "message:\n"
+ "{% endif %}"
+ "{{ message.content }}"
+ "<|im_end|>\n"
+ "{% endif %}"
+ ## Function calls
+ "{% if 'tool_calls' in message %}"
+ "{% for tool_call in message.tool_calls %}"
+ "functions.{{ tool_call.function.name }}:\n"
+ "{{ tool_call.function.arguments }}"
+ "{% endfor %}"
+ "<|im_end|>\n"
+ "{% endif %}"
+ "{% endif %}"
+ "{% endfor %}"
+ "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
+ )
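+    # For reference, with a single `get_weather` tool the template above renders the system
+    # message roughly as follows (illustrative, abbreviated):
+    #
+    #   <|im_start|>system
+    #   You have access to the following functions:
+    #
+    #   /* Get the weather for a location. */
+    #   functions.get_weather:
+    #   {"type": "object", "properties": {"location": {"type": "string"}}}
+    #
+    #   You must respond to user messages with either a single message or with one or more function calls.
+    #   [...] To respond with one or more function calls use the following format:
+    #
+    #   <function_calls>
+    #   functions.<function_name>:
+    #   { "arg1": "value1", "arg2": "value2" }
+    #   </function_calls>
+    #   <|im_end|>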
+ template_renderer = ImmutableSandboxedEnvironment(
+ autoescape=jinja2.select_autoescape(["html", "xml"]),
+ undefined=jinja2.StrictUndefined,
+ ).from_string(function_calling_template)
+
+ # Convert legacy functions to tools
+ if functions is not None:
+ tools = [{"type": "function", "function": function} for function in functions]
+
+ # Convert legacy function_call to tool_choice
+ if function_call is not None:
+ if isinstance(function_call, str) and (function_call in ("none", "auto")):
+ tool_choice = function_call
+ if isinstance(function_call, dict) and "name" in function_call:
+ tool_choice = {"type": "function", "function": {"name": function_call["name"]}}
+
+ # Collect the llama.create_completion keyword arguments so we don't have to repeat these with
+ # each completion call
+ stop = (
+ [stop, "<|im_end|>", "|im_end|>"]
+ if isinstance(stop, str)
+ else [*stop, "<|im_end|>", "|im_end|>"]
+ if stop
+ else ["<|im_end|>", "|im_end|>"]
+ )
+ grammar = ( # It is assumed the grammar applies to messages only, not tool calls
+ grammar
+ if grammar is not None
+ else (
+ _grammar_for_response_format(response_format)
+ if response_format is not None and response_format["type"] == "json_object"
+ else None
+ )
+ )
+ completion_kwargs = {
+ "temperature": temperature,
+ "top_p": top_p,
+ "top_k": top_k,
+ "min_p": min_p,
+ "typical_p": typical_p,
+ "stream": stream,
+ "stop": stop,
+ "max_tokens": max_tokens,
+ "presence_penalty": presence_penalty,
+ "frequency_penalty": frequency_penalty,
+ "repeat_penalty": repeat_penalty,
+ "tfs_z": tfs_z,
+ "mirostat_mode": mirostat_mode,
+ "mirostat_tau": mirostat_tau,
+ "mirostat_eta": mirostat_eta,
+ "model": model,
+ "logits_processor": logits_processor,
+ "grammar": grammar,
+ }
+
+ # Case 1: No tool use
+ if (
+ tool_choice is None
+ or (isinstance(tool_choice, str) and tool_choice == "none")
+ or tools is None
+ or len(tools) == 0
+ ):
+ prompt = template_renderer.render(
+ messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
+ )
+ return _convert_completion_to_chat(
+ llama.create_completion(
+ prompt=prompt,
+ **completion_kwargs, # type: ignore[arg-type]
+ logprobs=top_logprobs if logprobs else None,
+ ),
+ stream=stream,
+ )
+
+ # Ensure there is a system prompt to attach the tool metadata to
+ if not any(message["role"] == "system" for message in messages):
+ messages = [*messages, {"role": "system", "content": ""}]
+
+ # Case 2: Automatic or fixed tool choice
+ # Case 2 step 1: Determine whether to respond with a message or a tool call
+ assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict)
+ if isinstance(tool_choice, dict):
+ tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]]
+ assert tools
+ function_names = " | ".join([f'''"functions.{t['function']['name']}:"''' for t in tools])
+ prompt = template_renderer.render(
+ messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True
+ )
+ initial_gbnf_tool_grammar = (
+ (
+            'root ::= "<function_calls>" "\\n" functions | "message:"\n'
+ f"functions ::= {function_names}\n"
+ )
+ if tool_choice == "auto"
+        else f'root ::= "<function_calls>" "\\n" functions\nfunctions ::= {function_names}\n'
+ )
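+    # With tool_choice="auto", the grammar above constrains the next completion to either the
+    # literal "message:" (respond with a plain assistant message) or "<function_calls>" followed by
+    # "functions.<tool name>:" for one of the available tools; with a fixed tool choice, only the
+    # function branch remains. Stopping at ":" keeps this decision down to a handful of tokens.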
+ completion = cast(
+ llama_types.CreateCompletionResponse,
+ llama.create_completion(
+ prompt=prompt,
+ **{ # type: ignore[arg-type]
+ **completion_kwargs,
+ "temperature": 0,
+ "stream": False,
+ "stop": [":"],
+ "max_tokens": None,
+ "grammar": llama_grammar.LlamaGrammar.from_string(
+ initial_gbnf_tool_grammar, verbose=llama.verbose
+ ),
+ },
+ ),
+ )
+ text = completion["choices"][0]["text"]
+ tool_name = None if text.startswith("message") else text.split("\n")[-1][len("functions.") :]
+
+ # Case 2 step 2A: Respond with a message
+ if tool_name is None:
+ prompt = template_renderer.render(
+ messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
+ )
+ return _convert_completion_to_chat(
+ llama.create_completion(
+ prompt=prompt,
+ **completion_kwargs, # type: ignore[arg-type]
+ logprobs=top_logprobs if logprobs else None,
+ ),
+ stream=stream,
+ )
+
+ # Case 2 step 2B: One or more function calls
+ follow_up_gbnf_tool_grammar = (
+        'root ::= functions | "</function_calls>" | "<|im_end|>"\n'
+ f"functions ::= {function_names}\n"
+ )
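+    # After each tool call, the follow-up grammar lets the LLM either name the next tool to call
+    # ("functions.<tool name>:") or terminate the tool calls with "</function_calls>" or "<|im_end|>".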
+    prompt += "<function_calls>\n"
+ if stream:
+ return _stream_tool_calls(
+ llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar
+ )
+ tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
+ completions: List[llama_types.CreateCompletionResponse] = []
+ completions_tool_name: List[str] = []
+ while tool is not None and len(completions) <= 16: # noqa: PLR2004
+ # Generate the parameter values for the selected tool
+ prompt += f"functions.{tool_name}:\n"
+ try:
+ grammar = llama_grammar.LlamaGrammar.from_json_schema(
+ json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
+ )
+ except Exception as e:
+ warnings.warn(
+ f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}",
+ category=RuntimeWarning,
+ stacklevel=2,
+ )
+ grammar = llama_grammar.LlamaGrammar.from_string(
+ llama_grammar.JSON_GBNF, verbose=llama.verbose
+ )
+ completion_or_chunks = llama.create_completion(
+ prompt=prompt,
+ **{ # type: ignore[arg-type]
+ **completion_kwargs,
+ "max_tokens": None,
+ "grammar": grammar,
+ },
+ )
+ completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
+ completions.append(completion)
+ completions_tool_name.append(tool_name)
+ prompt += completion["choices"][0]["text"]
+ prompt += "\n"
+ # Determine whether to call another tool or stop
+ response = cast(
+ llama_types.CreateCompletionResponse,
+ llama.create_completion(
+ prompt=prompt,
+ **{ # type: ignore[arg-type]
+ **completion_kwargs,
+ "temperature": 0,
+ "stream": False,
+                    "stop": [*completion_kwargs["stop"], ":", "</function_calls>"],  # type: ignore[misc]
+ "max_tokens": None,
+ "grammar": llama_grammar.LlamaGrammar.from_string(
+ follow_up_gbnf_tool_grammar, verbose=llama.verbose
+ ),
+ },
+ ),
+ )
+ tool_name = response["choices"][0]["text"][len("functions.") :]
+ tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
+ # Merge the completions into a single chat completion
+ chat_completion: llama_types.CreateChatCompletionResponse = {
+ "id": "chat" + completion["id"],
+ "object": "chat.completion",
+ "created": completion["created"],
+ "model": completion["model"],
+ "choices": [
+ {
+ "finish_reason": "tool_calls",
+ "index": 0,
+ "logprobs": completion["choices"][0]["logprobs"],
+ "message": {
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [
+ {
+ "id": "call_" + f"_{i}_" + tool_name + "_" + completion["id"],
+ "type": "function",
+ "function": {
+ "name": tool_name,
+ "arguments": completion["choices"][0]["text"],
+ },
+ }
+ for i, (tool_name, completion) in enumerate(
+ zip(completions_tool_name, completions, strict=True)
+ )
+ ],
+ },
+ }
+ ],
+ "usage": {
+ "completion_tokens": sum(
+ (completion["usage"]["completion_tokens"] if "usage" in completion else 0)
+ for completion in completions
+ ),
+ "prompt_tokens": sum(
+ completion["usage"]["prompt_tokens"] if "usage" in completion else 0
+ for completion in completions
+ ),
+ "total_tokens": sum(
+ completion["usage"]["total_tokens"] if "usage" in completion else 0
+ for completion in completions
+ ),
+ },
+ }
+ if len(completions) == 1:
+ single_function_call: llama_types.ChatCompletionResponseFunctionCall = {
+ "name": tool_name,
+ "arguments": completions[0]["choices"][0]["text"],
+ }
+ chat_completion["choices"][0]["message"]["function_call"] = single_function_call
+ return chat_completion
diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py
index e0bab3b..973cdb4 100644
--- a/src/raglite/_litellm.py
+++ b/src/raglite/_litellm.py
@@ -31,6 +31,7 @@
LlamaRAMCache,
)
+from raglite._chatml_function_calling import chatml_function_calling_with_streaming
from raglite._config import RAGLiteConfig
# Reduce the logging level for LiteLLM, flashrank, and httpx.
@@ -116,8 +117,8 @@ def llm(model: str, **kwargs: Any) -> Llama:
n_ctx=n_ctx,
n_gpu_layers=-1,
verbose=False,
- # Enable function calling.
- chat_format="chatml-function-calling",
+ # Enable function calling with streaming.
+ chat_handler=chatml_function_calling_with_streaming,
# Workaround to enable long context embedding models [1].
# [1] https://github.com/abetlen/llama-cpp-python/issues/1762
n_batch=n_ctx if n_ctx > 0 else 1024,
diff --git a/src/raglite/_mcp.py b/src/raglite/_mcp.py
index 094d69a..8d72a70 100644
--- a/src/raglite/_mcp.py
+++ b/src/raglite/_mcp.py
@@ -13,9 +13,8 @@
Field(
description=(
"The `query` string to search the knowledge base with.\n"
- "The `query` string MUST satisfy ALL of the following criteria:\n"
- "- The `query` string MUST be a precise question in the user's language.\n"
- "- The `query` string MUST resolve all pronouns to explicit nouns from the conversation history."
+ "The `query` string MUST be a precise single-faceted question in the user's language.\n"
+ "The `query` string MUST resolve all pronouns to explicit nouns from the conversation history."
)
),
]
diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index 689fc1a..acb7423 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -98,41 +98,27 @@ def _get_tools(
if not messages_contain_rag_context and not llm_supports_function_calling:
error_message = "You must either explicitly provide RAG context in the last message, or use an LLM that supports function calling."
raise ValueError(error_message)
- # Add a tool to search the knowledge base if no RAG context is provided in the messages. Because
- # llama-cpp-python cannot stream tool_use='auto' yet, we use a workaround that forces the LLM
- # to use a tool, but allows it to skip the search.
- auto_tool_use_workaround = (
- {
- "expert": {
- "type": "boolean",
- "description": "The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.",
- }
- }
- if config.llm.startswith("llama-cpp-python")
- else {}
- )
+ # Return a single tool to search the knowledge base if no RAG context is provided.
tools: list[dict[str, Any]] | None = (
[
{
"type": "function",
"function": {
"name": "search_knowledge_base",
- "description": "Search the knowledge base. IMPORTANT: Only use this tool if a well-rounded non-expert would need to look up information to answer the question.",
+ "description": "Search the knowledge base. IMPORTANT: You MAY NOT use this function if the query can be answered with common knowledge or straightforward reasoning.",
"parameters": {
"type": "object",
"properties": {
- **auto_tool_use_workaround,
"query": {
"type": "string",
"description": (
"The `query` string to search the knowledge base with.\n"
- "The `query` string MUST satisfy ALL of the following criteria:\n"
- "- The `query` string MUST be a precise question in the user's language.\n"
- "- The `query` string MUST resolve all pronouns to explicit nouns from the conversation history."
+ "The `query` string MUST be a precise single-faceted question in the user's language.\n"
+ "The `query` string MUST resolve all pronouns to explicit nouns from the conversation history."
),
},
},
- "required": [*list(auto_tool_use_workaround), "query"],
+ "required": ["query"],
"additionalProperties": False,
},
},
@@ -141,15 +127,7 @@ def _get_tools(
if not messages_contain_rag_context
else None
)
- tool_choice: dict[str, Any] | str | None = (
- (
- {"type": "function", "function": {"name": "search_knowledge_base"}}
- if auto_tool_use_workaround
- else "auto"
- )
- if tools
- else None
- )
+ tool_choice: dict[str, Any] | str | None = "auto" if tools else None
return tools, tool_choice
@@ -164,19 +142,16 @@ def _run_tools(
if tool_call.function.name == "search_knowledge_base":
kwargs = json.loads(tool_call.function.arguments)
kwargs["config"] = config
- skip = not kwargs.pop("expert", True)
- chunk_spans = retrieve_rag_context(**kwargs) if not skip and kwargs["query"] else None
+ chunk_spans = retrieve_rag_context(**kwargs)
tool_messages.append(
{
"role": "tool",
"content": '{{"documents": [{elements}]}}'.format(
elements=", ".join(
chunk_span.to_json(index=i + 1)
- for i, chunk_span in enumerate(chunk_spans) # type: ignore[arg-type]
+ for i, chunk_span in enumerate(chunk_spans)
)
- )
- if not skip and kwargs["query"]
- else "{}",
+ ),
"tool_call_id": tool_call.id,
}
)
@@ -198,23 +173,14 @@ def rag(
max_tokens = get_context_size(config)
tools, tool_choice = _get_tools(messages, config)
# Stream the LLM response, which is either a tool call request or an assistant response.
- chunks = []
- clipped_messages = _clip(messages, max_tokens)
- if tools and config.llm.startswith("llama-cpp-python"):
- # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call.
- clipped_messages[-1]["content"] += (
- "\n\n\n"
- f"Available tools:\n```\n{json.dumps(tools)}\n```\n"
- "IMPORTANT: The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.\n"
- ""
- )
stream = completion(
model=config.llm,
- messages=clipped_messages,
+ messages=_clip(messages, max_tokens),
tools=tools,
tool_choice=tool_choice,
stream=True,
)
+ chunks = []
for chunk in stream:
chunks.append(chunk)
if isinstance(token := chunk.choices[0].delta.content, str):
@@ -249,23 +215,14 @@ async def async_rag(
max_tokens = get_context_size(config)
tools, tool_choice = _get_tools(messages, config)
# Asynchronously stream the LLM response, which is either a tool call or an assistant response.
- chunks = []
- clipped_messages = _clip(messages, max_tokens)
- if tools and config.llm.startswith("llama-cpp-python"):
- # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call.
- clipped_messages[-1]["content"] += (
- "\n\n\n"
- f"Available tools:\n```\n{json.dumps(tools)}\n```\n"
- "IMPORTANT: The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.\n"
- ""
- )
async_stream = await acompletion(
model=config.llm,
- messages=clipped_messages,
+ messages=_clip(messages, max_tokens),
tools=tools,
tool_choice=tool_choice,
stream=True,
)
+ chunks = []
async for chunk in async_stream:
chunks.append(chunk)
if isinstance(token := chunk.choices[0].delta.content, str):
diff --git a/tests/conftest.py b/tests/conftest.py
index 943250c..e067355 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,7 +7,6 @@
from pathlib import Path
import pytest
-from llama_cpp import llama_supports_gpu_offload
from sqlalchemy import create_engine, text
from raglite import RAGLiteConfig, insert_document
@@ -24,11 +23,6 @@ def is_postgres_running() -> bool:
return False
-def is_accelerator_available() -> bool:
- """Check if an accelerator is available."""
- return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004
-
-
def is_openai_available() -> bool:
"""Check if an OpenAI API key is set."""
return bool(os.environ.get("OPENAI_API_KEY"))
@@ -79,7 +73,7 @@ def database(request: pytest.FixtureRequest) -> str:
"llama-cpp-python/bartowski/Llama-3.2-3B-Instruct-GGUF/*Q4_K_M.gguf@4096",
"llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance.
),
- id="llama32_3B-bge_m3",
+ id="llama_3.2_3B-bge_m3",
),
pytest.param(
("gpt-4o-mini", "text-embedding-3-small"),
diff --git a/tests/test_chatml_function_calling.py b/tests/test_chatml_function_calling.py
new file mode 100644
index 0000000..2ef4edb
--- /dev/null
+++ b/tests/test_chatml_function_calling.py
@@ -0,0 +1,129 @@
+"""Test RAGLite's upgraded chatml-function-calling llama-cpp-python chat handler."""
+
+import os
+from collections.abc import Iterator
+from typing import cast
+
+import pytest
+from llama_cpp import Llama, llama_supports_gpu_offload
+from llama_cpp.llama_types import (
+ ChatCompletionRequestMessage,
+ ChatCompletionTool,
+ ChatCompletionToolChoiceOption,
+ CreateChatCompletionResponse,
+ CreateChatCompletionStreamResponse,
+)
+from typeguard import ForwardRefPolicy, check_type
+
+from raglite._chatml_function_calling import chatml_function_calling_with_streaming
+
+
+def is_accelerator_available() -> bool:
+ """Check if an accelerator is available."""
+ return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004
+
+
+@pytest.mark.parametrize(
+ "stream",
+ [
+ pytest.param(True, id="stream=True"),
+ pytest.param(False, id="stream=False"),
+ ],
+)
+@pytest.mark.parametrize(
+ "tool_choice",
+ [
+ pytest.param("none", id="tool_choice=none"),
+ pytest.param("auto", id="tool_choice=auto"),
+ pytest.param(
+ {"type": "function", "function": {"name": "get_weather"}}, id="tool_choice=fixed"
+ ),
+ ],
+)
+@pytest.mark.parametrize(
+ "user_prompt_expected_tool_calls",
+ [
+ pytest.param(
+ ("Is 7 a prime number?", 0),
+ id="expected_tool_calls=0",
+ ),
+ pytest.param(
+ ("What's the weather like in Paris today?", 1),
+ id="expected_tool_calls=1",
+ ),
+ pytest.param(
+ ("What's the weather like in Paris today? What about New York?", 2),
+ id="expected_tool_calls=2",
+ ),
+ ],
+)
+@pytest.mark.parametrize(
+ "llm_repo_id",
+ [
+ pytest.param("bartowski/Llama-3.2-3B-Instruct-GGUF", id="llama_3.2_3B"),
+ pytest.param(
+ "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
+ id="llama_3.1_8B",
+ marks=pytest.mark.skipif(
+ not is_accelerator_available(), reason="Accelerator not available"
+ ),
+ ),
+ ],
+)
+def test_llama_cpp_python_tool_use(
+ llm_repo_id: str,
+ user_prompt_expected_tool_calls: tuple[str, int],
+ tool_choice: ChatCompletionToolChoiceOption,
+ stream: bool, # noqa: FBT001
+) -> None:
+ """Test the upgraded chatml-function-calling llama-cpp-python chat handler."""
+ user_prompt, expected_tool_calls = user_prompt_expected_tool_calls
+ if isinstance(tool_choice, dict) and expected_tool_calls == 0:
+ pytest.skip("Nonsensical")
+ llm = Llama.from_pretrained(
+ repo_id=llm_repo_id,
+ filename="*Q4_K_M.gguf",
+ n_ctx=4096,
+ n_gpu_layers=-1,
+ verbose=False,
+ chat_handler=chatml_function_calling_with_streaming,
+ )
+ messages: list[ChatCompletionRequestMessage] = [{"role": "user", "content": user_prompt}]
+ tools: list[ChatCompletionTool] = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get the weather for a location.",
+ "parameters": {
+ "type": "object",
+ "properties": {"location": {"type": "string", "description": "A city name."}},
+ },
+ },
+ }
+ ]
+ response = llm.create_chat_completion(
+ messages=messages, tools=tools, tool_choice=tool_choice, stream=stream
+ )
+ if stream:
+ response = cast(Iterator[CreateChatCompletionStreamResponse], response)
+ num_tool_calls = 0
+ for chunk in response:
+ check_type(chunk, CreateChatCompletionStreamResponse)
+ tool_calls = chunk["choices"][0]["delta"].get("tool_calls")
+ if isinstance(tool_calls, list):
+ num_tool_calls = max(tool_call["index"] for tool_call in tool_calls) + 1
+ assert num_tool_calls == (expected_tool_calls if tool_choice != "none" else 0)
+ else:
+ response = cast(CreateChatCompletionResponse, response)
+ check_type(
+ response, CreateChatCompletionResponse, forward_ref_policy=ForwardRefPolicy.IGNORE
+ )
+ if expected_tool_calls == 0 or tool_choice == "none":
+ assert response["choices"][0]["message"].get("tool_calls") is None
+ else:
+ assert len(response["choices"][0]["message"]["tool_calls"]) == expected_tool_calls
+ assert all(
+ tool_call["function"]["name"] == tools[0]["function"]["name"]
+ for tool_call in response["choices"][0]["message"]["tool_calls"]
+ )
diff --git a/tests/test_rag.py b/tests/test_rag.py
index 2fb8a1c..b7d6280 100644
--- a/tests/test_rag.py
+++ b/tests/test_rag.py
@@ -1,10 +1,13 @@
"""Test RAGLite's RAG functionality."""
+import json
+
from raglite import (
RAGLiteConfig,
create_rag_instruction,
retrieve_rag_context,
)
+from raglite._database import ChunkSpan
from raglite._rag import rag
@@ -22,3 +25,38 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None:
assert "event" in answer.lower()
# Verify that no RAG context was retrieved through tool use.
assert [message["role"] for message in messages] == ["user", "assistant"]
+
+
+def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None:
+ """Test Retrieval-Augmented Generation with automatic retrieval."""
+ # Answer a question that requires RAG.
+ user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?"
+ messages = [{"role": "user", "content": user_prompt}]
+ chunk_spans = []
+ stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config)
+ answer = ""
+ for update in stream:
+ assert isinstance(update, str)
+ answer += update
+ assert "event" in answer.lower()
+ # Verify that RAG context was retrieved automatically.
+ assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"]
+ assert json.loads(messages[-2]["content"])
+ assert chunk_spans
+ assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
+
+
+def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None:
+    """Test Retrieval-Augmented Generation without automatic retrieval."""
+ # Answer a question that does not require RAG.
+ user_prompt = "Is 7 a prime number?"
+ messages = [{"role": "user", "content": user_prompt}]
+ chunk_spans = []
+ stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config)
+ answer = ""
+ for update in stream:
+ assert isinstance(update, str)
+ answer += update
+ # Verify that no RAG context was retrieved.
+ assert [message["role"] for message in messages] == ["user", "assistant"]
+ assert not chunk_spans