diff --git a/src/raglite/_chatml_function_calling.py b/src/raglite/_chatml_function_calling.py
new file mode 100644
index 0000000..e028744
--- /dev/null
+++ b/src/raglite/_chatml_function_calling.py
@@ -0,0 +1,521 @@
+"""Upgrade of llama-cpp-python's chatml-function-calling chat handler.
+
+Changes:
+1. General:
+    a. ✨ If no system message is supplied, add an empty system message to hold the tool metadata.
+    b. ✨ Add function descriptions to the system message so that tool use is better informed (fixes https://github.com/abetlen/llama-cpp-python/issues/1869).
+    c. ✨ Replace `print` statements relating to JSON grammars with `RuntimeWarning` warnings.
+    d. ✅ Add tests with fairly broad coverage of the different scenarios.
+2. Case "Tool choice by user":
+    a. ✨ Add support for more than one function call by making this a special case of "Automatic tool choice" with a single tool (subsumes https://github.com/abetlen/llama-cpp-python/pull/1503).
+3. Case "Automatic tool choice -> respond with a message":
+    a. ✨ Use user-defined `stop` and `max_tokens`.
+    b. 🐛 Replace incorrect use of follow-up grammar with user-defined grammar.
+4. Case "Automatic tool choice -> one or more function calls":
+    a. ✨ Add support for streaming the function calls (fixes https://github.com/abetlen/llama-cpp-python/issues/1883).
+    b. ✨ Make tool calling more robust by giving the LLM an explicit way to terminate the tool calls by wrapping them in a `<function_calls>` block.
+    c. 🐛 Add missing ":" stop token to determine whether to continue with another tool call, which prevented parallel function calling (fixes https://github.com/abetlen/llama-cpp-python/issues/1756).
+    d. ✨ Set temperature=0 to determine whether to continue with another tool call, similar to the initial decision on whether to call a tool.
+"""
+# This file uses old-style type hints and ignores certain ruff rules to minimise changes w.r.t.
the original implementation: +# ruff: noqa: C901, PLR0913, PLR0912, PLR0915, UP006, UP007, FBT001, FBT002, B006, TRY003, EM102, BLE001, PT018, W505 + +import json +import warnings +from typing import ( # noqa: UP035 + Any, + Iterator, + List, + Optional, + Union, + cast, +) + +import jinja2 +from jinja2.sandbox import ImmutableSandboxedEnvironment +from llama_cpp import llama, llama_grammar, llama_types +from llama_cpp.llama_chat_format import ( + _convert_completion_to_chat, + _convert_completion_to_chat_function, + _grammar_for_response_format, +) + + +def _accumulate_chunks( + chunks_iterator: Iterator[llama_types.CreateCompletionStreamResponse], + chunks_list: List[llama_types.CreateCompletionStreamResponse], +) -> Iterator[llama_types.CreateCompletionStreamResponse]: + for chunk in chunks_iterator: + chunks_list.append(chunk) + yield chunk + + +def _convert_chunks_to_completion( + chunks: List[llama_types.CreateCompletionStreamResponse], +) -> llama_types.CreateCompletionResponse: + """Convert a list of completion chunks to a completion.""" + # Accumulate completion response values + text: str = "" + finish_reason: Optional[str] = None + logprobs: Optional[llama_types.CompletionLogprobs] = None + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + completion_id: Optional[str] = None + completion_model: Optional[str] = None + completion_created: Optional[int] = None + for chunk in chunks: + # Extract the id, model, and created values from the first chunk + if completion_id is None: + completion_id = chunk["id"] + completion_model = chunk["model"] + completion_created = chunk["created"] + # Extract the usage if present in the chunk + usage = chunk.get("usage") + if usage: + prompt_tokens += usage.get("prompt_tokens", 0) + completion_tokens += usage.get("completion_tokens", 0) + total_tokens += usage.get("total_tokens", 0) + # Accumulate the chunk text + choice = chunk["choices"][0] + text += choice.get("text", "") + # Extract the finish_reason and logprobs if present in the chunk + if choice.get("finish_reason"): + finish_reason = choice["finish_reason"] + if choice.get("logprobs"): + logprobs = choice["logprobs"] + # Create the completion response + completion: llama_types.CreateCompletionResponse = { + "id": completion_id or "unknown_id", + "object": "text_completion", + "created": completion_created or 0, + "model": completion_model or "unknown_model", + "choices": [ + { + "text": text, + "index": 0, + "logprobs": logprobs, # TODO: Improve accumulation of logprobs + "finish_reason": finish_reason, # type: ignore[typeddict-item] + } + ], + } + # Add usage section if present in the chunks + if (prompt_tokens + completion_tokens + total_tokens) > 0: + completion["usage"] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + return completion + + +def _stream_tool_calls( + llama: llama.Llama, + prompt: str, + tools: List[llama_types.ChatCompletionTool], + tool_name: str, + completion_kwargs: dict[str, Any], + follow_up_gbnf_tool_grammar: str, +) -> Iterator[llama_types.CreateChatCompletionStreamResponse]: + # Generate a tool call completions + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + completions: List[llama_types.CreateCompletionResponse] = [] + completions_tool_name: List[str] = [] + finish_reason_chat_chunk = None + while tool is not None and len(completions) <= 16: # noqa: PLR2004 + # Generate the parameter values for the selected tool + prompt += 
f"functions.{tool_name}:\n" + try: + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(tool["function"]["parameters"]), verbose=llama.verbose + ) + except Exception as e: + warnings.warn( + f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}", + category=RuntimeWarning, + stacklevel=2, + ) + grammar = llama_grammar.LlamaGrammar.from_string( + llama_grammar.JSON_GBNF, verbose=llama.verbose + ) + completion_or_chunks = llama.create_completion( + prompt=prompt, + **{ + **completion_kwargs, + "max_tokens": None, + "grammar": grammar, + }, + ) + chunks: List[llama_types.CreateCompletionResponse] = [] + chat_chunks = _convert_completion_to_chat_function( + tool_name, + _accumulate_chunks(completion_or_chunks, chunks), # type: ignore[arg-type] + stream=True, + ) + for chat_chunk in chat_chunks: + # Don't return the finish_reason chunk + if chat_chunk["choices"] and chat_chunk["choices"][0].get("finish_reason"): + finish_reason_chat_chunk = chat_chunk + break + # Update this tool call's index + if chat_chunk["choices"] and chat_chunk["choices"][0]["delta"].get("tool_calls"): + chat_chunk["choices"][0]["delta"]["tool_calls"][0]["index"] = len(completions) + yield chat_chunk + completion = _convert_chunks_to_completion(chunks) + completions.append(completion) + completions_tool_name.append(tool_name) + prompt += completion["choices"][0]["text"] + prompt += "\n" + # Determine whether to call another tool or stop + response = cast( + llama_types.CreateCompletionResponse, + llama.create_completion( + prompt=prompt, + **{ + **completion_kwargs, + "temperature": 0, + "stream": False, + "stop": [*completion_kwargs["stop"], ":", ""], + "max_tokens": None, + "grammar": llama_grammar.LlamaGrammar.from_string( + follow_up_gbnf_tool_grammar, verbose=llama.verbose + ), + }, + ), + ) + tool_name = response["choices"][0]["text"][len("functions.") :] + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + # Yield the finish_reason chunk + if finish_reason_chat_chunk is not None: + yield finish_reason_chat_chunk + + +def chatml_function_calling_with_streaming( + llama: llama.Llama, + messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, + temperature: float = 0.2, + top_p: float = 0.95, + top_k: int = 40, + min_p: float = 0.05, + typical_p: float = 1.0, + stream: bool = False, + stop: Optional[Union[str, List[str]]] = [], + response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, + max_tokens: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repeat_penalty: float = 1.1, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + logits_processor: Optional[llama.LogitsProcessorList] = None, + grammar: Optional[llama.LlamaGrammar] = None, # type: ignore[name-defined] + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + **kwargs: Any, +) -> Union[ + llama_types.CreateChatCompletionResponse, + Iterator[llama_types.CreateChatCompletionStreamResponse], +]: + function_calling_template = ( + "{% for message in messages %}" + "<|im_start|>{{ message.role }}\n" + # System message + 
"{% if message.role == 'system' %}" + "{{ message.content }}" + "{% if tool_calls %}" + "\n\nYou have access to the following functions:\n" + "{% for tool in tools %}" + '\n{% if tool.function.get("description") %}/* {{ tool.function.description | trim }} */{% endif %}' + "\nfunctions.{{ tool.function.name }}:\n" + "{{ tool.function.parameters | tojson }}" + "\n{% endfor %}" + "\nYou must respond to user messages with either a single message or with one or more function calls." + "\n\nTo respond with a message use the following format:" + "\n\nmessage:" + "\n" + "\n\nTo respond with one or more function calls use the following format:" + "\n\n" + "\nfunctions.:" + '\n{ "arg1": "value1", "arg2": "value2" }' + "\nfunctions.:" + '\n{ "arg1": "value1", "arg2": "value2" }' + "\n" + "{% endif %}" + "<|im_end|>\n" + "{% endif %}" + # User message + "{% if message.role == 'user' %}" + "{{ message.content }}" + "<|im_end|>\n" + "{% endif %}" + # Assistant message + "{% if message.role == 'assistant' %}" + ## Regular message + "{% if message.content and message.content | length > 0 %}" + "{% if tool_calls %}" + "message:\n" + "{% endif %}" + "{{ message.content }}" + "<|im_end|>\n" + "{% endif %}" + ## Function calls + "{% if 'tool_calls' in message %}" + "{% for tool_call in message.tool_calls %}" + "functions.{{ tool_call.function.name }}:\n" + "{{ tool_call.function.arguments }}" + "{% endfor %}" + "<|im_end|>\n" + "{% endif %}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + ) + template_renderer = ImmutableSandboxedEnvironment( + autoescape=jinja2.select_autoescape(["html", "xml"]), + undefined=jinja2.StrictUndefined, + ).from_string(function_calling_template) + + # Convert legacy functions to tools + if functions is not None: + tools = [{"type": "function", "function": function} for function in functions] + + # Convert legacy function_call to tool_choice + if function_call is not None: + if isinstance(function_call, str) and (function_call in ("none", "auto")): + tool_choice = function_call + if isinstance(function_call, dict) and "name" in function_call: + tool_choice = {"type": "function", "function": {"name": function_call["name"]}} + + # Collect the llama.create_completion keyword arguments so we don't have to repeat these with + # each completion call + stop = ( + [stop, "<|im_end|>", "|im_end|>"] + if isinstance(stop, str) + else [*stop, "<|im_end|>", "|im_end|>"] + if stop + else ["<|im_end|>", "|im_end|>"] + ) + grammar = ( # It is assumed the grammar applies to messages only, not tool calls + grammar + if grammar is not None + else ( + _grammar_for_response_format(response_format) + if response_format is not None and response_format["type"] == "json_object" + else None + ) + ) + completion_kwargs = { + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "typical_p": typical_p, + "stream": stream, + "stop": stop, + "max_tokens": max_tokens, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "repeat_penalty": repeat_penalty, + "tfs_z": tfs_z, + "mirostat_mode": mirostat_mode, + "mirostat_tau": mirostat_tau, + "mirostat_eta": mirostat_eta, + "model": model, + "logits_processor": logits_processor, + "grammar": grammar, + } + + # Case 1: No tool use + if ( + tool_choice is None + or (isinstance(tool_choice, str) and tool_choice == "none") + or tools is None + or len(tools) == 0 + ): + prompt = template_renderer.render( + messages=messages, tools=[], tool_calls=None, 
add_generation_prompt=True + ) + return _convert_completion_to_chat( + llama.create_completion( + prompt=prompt, + **completion_kwargs, # type: ignore[arg-type] + logprobs=top_logprobs if logprobs else None, + ), + stream=stream, + ) + + # Ensure there is a system prompt to attach the tool metadata to + if not any(message["role"] == "system" for message in messages): + messages = [*messages, {"role": "system", "content": ""}] + + # Case 2: Automatic or fixed tool choice + # Case 2 step 1: Determine whether to respond with a message or a tool call + assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict) + if isinstance(tool_choice, dict): + tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]] + assert tools + function_names = " | ".join([f'''"functions.{t['function']['name']}:"''' for t in tools]) + prompt = template_renderer.render( + messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True + ) + initial_gbnf_tool_grammar = ( + ( + 'root ::= "" "\\n" functions | "message:"\n' + f"functions ::= {function_names}\n" + ) + if tool_choice == "auto" + else f'root ::= "" "\\n" functions\nfunctions ::= {function_names}\n' + ) + completion = cast( + llama_types.CreateCompletionResponse, + llama.create_completion( + prompt=prompt, + **{ # type: ignore[arg-type] + **completion_kwargs, + "temperature": 0, + "stream": False, + "stop": [":"], + "max_tokens": None, + "grammar": llama_grammar.LlamaGrammar.from_string( + initial_gbnf_tool_grammar, verbose=llama.verbose + ), + }, + ), + ) + text = completion["choices"][0]["text"] + tool_name = None if text.startswith("message") else text.split("\n")[-1][len("functions.") :] + + # Case 2 step 2A: Respond with a message + if tool_name is None: + prompt = template_renderer.render( + messages=messages, tools=[], tool_calls=None, add_generation_prompt=True + ) + return _convert_completion_to_chat( + llama.create_completion( + prompt=prompt, + **completion_kwargs, # type: ignore[arg-type] + logprobs=top_logprobs if logprobs else None, + ), + stream=stream, + ) + + # Case 2 step 2B: One or more function calls + follow_up_gbnf_tool_grammar = ( + 'root ::= functions | "" | "<|im_end|>"\n' + f"functions ::= {function_names}\n" + ) + prompt += "\n" + if stream: + return _stream_tool_calls( + llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar + ) + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + completions: List[llama_types.CreateCompletionResponse] = [] + completions_tool_name: List[str] = [] + while tool is not None and len(completions) <= 16: # noqa: PLR2004 + # Generate the parameter values for the selected tool + prompt += f"functions.{tool_name}:\n" + try: + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(tool["function"]["parameters"]), verbose=llama.verbose + ) + except Exception as e: + warnings.warn( + f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}", + category=RuntimeWarning, + stacklevel=2, + ) + grammar = llama_grammar.LlamaGrammar.from_string( + llama_grammar.JSON_GBNF, verbose=llama.verbose + ) + completion_or_chunks = llama.create_completion( + prompt=prompt, + **{ # type: ignore[arg-type] + **completion_kwargs, + "max_tokens": None, + "grammar": grammar, + }, + ) + completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks) + completions.append(completion) + completions_tool_name.append(tool_name) + prompt += 
completion["choices"][0]["text"] + prompt += "\n" + # Determine whether to call another tool or stop + response = cast( + llama_types.CreateCompletionResponse, + llama.create_completion( + prompt=prompt, + **{ # type: ignore[arg-type] + **completion_kwargs, + "temperature": 0, + "stream": False, + "stop": [*completion_kwargs["stop"], ":", ""], # type: ignore[misc] + "max_tokens": None, + "grammar": llama_grammar.LlamaGrammar.from_string( + follow_up_gbnf_tool_grammar, verbose=llama.verbose + ), + }, + ), + ) + tool_name = response["choices"][0]["text"][len("functions.") :] + tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None) + # Merge the completions into a single chat completion + chat_completion: llama_types.CreateChatCompletionResponse = { + "id": "chat" + completion["id"], + "object": "chat.completion", + "created": completion["created"], + "model": completion["model"], + "choices": [ + { + "finish_reason": "tool_calls", + "index": 0, + "logprobs": completion["choices"][0]["logprobs"], + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_" + f"_{i}_" + tool_name + "_" + completion["id"], + "type": "function", + "function": { + "name": tool_name, + "arguments": completion["choices"][0]["text"], + }, + } + for i, (tool_name, completion) in enumerate( + zip(completions_tool_name, completions, strict=True) + ) + ], + }, + } + ], + "usage": { + "completion_tokens": sum( + (completion["usage"]["completion_tokens"] if "usage" in completion else 0) + for completion in completions + ), + "prompt_tokens": sum( + completion["usage"]["prompt_tokens"] if "usage" in completion else 0 + for completion in completions + ), + "total_tokens": sum( + completion["usage"]["total_tokens"] if "usage" in completion else 0 + for completion in completions + ), + }, + } + if len(completions) == 1: + single_function_call: llama_types.ChatCompletionResponseFunctionCall = { + "name": tool_name, + "arguments": completions[0]["choices"][0]["text"], + } + chat_completion["choices"][0]["message"]["function_call"] = single_function_call + return chat_completion diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py index e0bab3b..973cdb4 100644 --- a/src/raglite/_litellm.py +++ b/src/raglite/_litellm.py @@ -31,6 +31,7 @@ LlamaRAMCache, ) +from raglite._chatml_function_calling import chatml_function_calling_with_streaming from raglite._config import RAGLiteConfig # Reduce the logging level for LiteLLM, flashrank, and httpx. @@ -116,8 +117,8 @@ def llm(model: str, **kwargs: Any) -> Llama: n_ctx=n_ctx, n_gpu_layers=-1, verbose=False, - # Enable function calling. - chat_format="chatml-function-calling", + # Enable function calling with streaming. + chat_handler=chatml_function_calling_with_streaming, # Workaround to enable long context embedding models [1]. # [1] https://github.com/abetlen/llama-cpp-python/issues/1762 n_batch=n_ctx if n_ctx > 0 else 1024, diff --git a/src/raglite/_mcp.py b/src/raglite/_mcp.py index 094d69a..8d72a70 100644 --- a/src/raglite/_mcp.py +++ b/src/raglite/_mcp.py @@ -13,9 +13,8 @@ Field( description=( "The `query` string to search the knowledge base with.\n" - "The `query` string MUST satisfy ALL of the following criteria:\n" - "- The `query` string MUST be a precise question in the user's language.\n" - "- The `query` string MUST resolve all pronouns to explicit nouns from the conversation history." 
+ "The `query` string MUST be a precise single-faceted question in the user's language.\n" + "The `query` string MUST resolve all pronouns to explicit nouns from the conversation history." ) ), ] diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 689fc1a..acb7423 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -98,41 +98,27 @@ def _get_tools( if not messages_contain_rag_context and not llm_supports_function_calling: error_message = "You must either explicitly provide RAG context in the last message, or use an LLM that supports function calling." raise ValueError(error_message) - # Add a tool to search the knowledge base if no RAG context is provided in the messages. Because - # llama-cpp-python cannot stream tool_use='auto' yet, we use a workaround that forces the LLM - # to use a tool, but allows it to skip the search. - auto_tool_use_workaround = ( - { - "expert": { - "type": "boolean", - "description": "The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.", - } - } - if config.llm.startswith("llama-cpp-python") - else {} - ) + # Return a single tool to search the knowledge base if no RAG context is provided. tools: list[dict[str, Any]] | None = ( [ { "type": "function", "function": { "name": "search_knowledge_base", - "description": "Search the knowledge base. IMPORTANT: Only use this tool if a well-rounded non-expert would need to look up information to answer the question.", + "description": "Search the knowledge base. IMPORTANT: You MAY NOT use this function if the query can be answered with common knowledge or straightforward reasoning.", "parameters": { "type": "object", "properties": { - **auto_tool_use_workaround, "query": { "type": "string", "description": ( "The `query` string to search the knowledge base with.\n" - "The `query` string MUST satisfy ALL of the following criteria:\n" - "- The `query` string MUST be a precise question in the user's language.\n" - "- The `query` string MUST resolve all pronouns to explicit nouns from the conversation history." + "The `query` string MUST be a precise single-faceted question in the user's language.\n" + "The `query` string MUST resolve all pronouns to explicit nouns from the conversation history." 
), }, }, - "required": [*list(auto_tool_use_workaround), "query"], + "required": ["query"], "additionalProperties": False, }, }, @@ -141,15 +127,7 @@ def _get_tools( if not messages_contain_rag_context else None ) - tool_choice: dict[str, Any] | str | None = ( - ( - {"type": "function", "function": {"name": "search_knowledge_base"}} - if auto_tool_use_workaround - else "auto" - ) - if tools - else None - ) + tool_choice: dict[str, Any] | str | None = "auto" if tools else None return tools, tool_choice @@ -164,19 +142,16 @@ def _run_tools( if tool_call.function.name == "search_knowledge_base": kwargs = json.loads(tool_call.function.arguments) kwargs["config"] = config - skip = not kwargs.pop("expert", True) - chunk_spans = retrieve_rag_context(**kwargs) if not skip and kwargs["query"] else None + chunk_spans = retrieve_rag_context(**kwargs) tool_messages.append( { "role": "tool", "content": '{{"documents": [{elements}]}}'.format( elements=", ".join( chunk_span.to_json(index=i + 1) - for i, chunk_span in enumerate(chunk_spans) # type: ignore[arg-type] + for i, chunk_span in enumerate(chunk_spans) ) - ) - if not skip and kwargs["query"] - else "{}", + ), "tool_call_id": tool_call.id, } ) @@ -198,23 +173,14 @@ def rag( max_tokens = get_context_size(config) tools, tool_choice = _get_tools(messages, config) # Stream the LLM response, which is either a tool call request or an assistant response. - chunks = [] - clipped_messages = _clip(messages, max_tokens) - if tools and config.llm.startswith("llama-cpp-python"): - # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call. - clipped_messages[-1]["content"] += ( - "\n\n\n" - f"Available tools:\n```\n{json.dumps(tools)}\n```\n" - "IMPORTANT: The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.\n" - "" - ) stream = completion( model=config.llm, - messages=clipped_messages, + messages=_clip(messages, max_tokens), tools=tools, tool_choice=tool_choice, stream=True, ) + chunks = [] for chunk in stream: chunks.append(chunk) if isinstance(token := chunk.choices[0].delta.content, str): @@ -249,23 +215,14 @@ async def async_rag( max_tokens = get_context_size(config) tools, tool_choice = _get_tools(messages, config) # Asynchronously stream the LLM response, which is either a tool call or an assistant response. - chunks = [] - clipped_messages = _clip(messages, max_tokens) - if tools and config.llm.startswith("llama-cpp-python"): - # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call. 
- clipped_messages[-1]["content"] += ( - "\n\n\n" - f"Available tools:\n```\n{json.dumps(tools)}\n```\n" - "IMPORTANT: The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.\n" - "" - ) async_stream = await acompletion( model=config.llm, - messages=clipped_messages, + messages=_clip(messages, max_tokens), tools=tools, tool_choice=tool_choice, stream=True, ) + chunks = [] async for chunk in async_stream: chunks.append(chunk) if isinstance(token := chunk.choices[0].delta.content, str): diff --git a/tests/conftest.py b/tests/conftest.py index 943250c..e067355 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,6 @@ from pathlib import Path import pytest -from llama_cpp import llama_supports_gpu_offload from sqlalchemy import create_engine, text from raglite import RAGLiteConfig, insert_document @@ -24,11 +23,6 @@ def is_postgres_running() -> bool: return False -def is_accelerator_available() -> bool: - """Check if an accelerator is available.""" - return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004 - - def is_openai_available() -> bool: """Check if an OpenAI API key is set.""" return bool(os.environ.get("OPENAI_API_KEY")) @@ -79,7 +73,7 @@ def database(request: pytest.FixtureRequest) -> str: "llama-cpp-python/bartowski/Llama-3.2-3B-Instruct-GGUF/*Q4_K_M.gguf@4096", "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance. ), - id="llama32_3B-bge_m3", + id="llama_3.2_3B-bge_m3", ), pytest.param( ("gpt-4o-mini", "text-embedding-3-small"), diff --git a/tests/test_chatml_function_calling.py b/tests/test_chatml_function_calling.py new file mode 100644 index 0000000..2ef4edb --- /dev/null +++ b/tests/test_chatml_function_calling.py @@ -0,0 +1,129 @@ +"""Test RAGLite's upgraded chatml-function-calling llama-cpp-python chat handler.""" + +import os +from collections.abc import Iterator +from typing import cast + +import pytest +from llama_cpp import Llama, llama_supports_gpu_offload +from llama_cpp.llama_types import ( + ChatCompletionRequestMessage, + ChatCompletionTool, + ChatCompletionToolChoiceOption, + CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, +) +from typeguard import ForwardRefPolicy, check_type + +from raglite._chatml_function_calling import chatml_function_calling_with_streaming + + +def is_accelerator_available() -> bool: + """Check if an accelerator is available.""" + return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004 + + +@pytest.mark.parametrize( + "stream", + [ + pytest.param(True, id="stream=True"), + pytest.param(False, id="stream=False"), + ], +) +@pytest.mark.parametrize( + "tool_choice", + [ + pytest.param("none", id="tool_choice=none"), + pytest.param("auto", id="tool_choice=auto"), + pytest.param( + {"type": "function", "function": {"name": "get_weather"}}, id="tool_choice=fixed" + ), + ], +) +@pytest.mark.parametrize( + "user_prompt_expected_tool_calls", + [ + pytest.param( + ("Is 7 a prime number?", 0), + id="expected_tool_calls=0", + ), + pytest.param( + ("What's the weather like in Paris today?", 1), + id="expected_tool_calls=1", + ), + pytest.param( + ("What's the weather like in Paris today? 
What about New York?", 2), + id="expected_tool_calls=2", + ), + ], +) +@pytest.mark.parametrize( + "llm_repo_id", + [ + pytest.param("bartowski/Llama-3.2-3B-Instruct-GGUF", id="llama_3.2_3B"), + pytest.param( + "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", + id="llama_3.1_8B", + marks=pytest.mark.skipif( + not is_accelerator_available(), reason="Accelerator not available" + ), + ), + ], +) +def test_llama_cpp_python_tool_use( + llm_repo_id: str, + user_prompt_expected_tool_calls: tuple[str, int], + tool_choice: ChatCompletionToolChoiceOption, + stream: bool, # noqa: FBT001 +) -> None: + """Test the upgraded chatml-function-calling llama-cpp-python chat handler.""" + user_prompt, expected_tool_calls = user_prompt_expected_tool_calls + if isinstance(tool_choice, dict) and expected_tool_calls == 0: + pytest.skip("Nonsensical") + llm = Llama.from_pretrained( + repo_id=llm_repo_id, + filename="*Q4_K_M.gguf", + n_ctx=4096, + n_gpu_layers=-1, + verbose=False, + chat_handler=chatml_function_calling_with_streaming, + ) + messages: list[ChatCompletionRequestMessage] = [{"role": "user", "content": user_prompt}] + tools: list[ChatCompletionTool] = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a location.", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string", "description": "A city name."}}, + }, + }, + } + ] + response = llm.create_chat_completion( + messages=messages, tools=tools, tool_choice=tool_choice, stream=stream + ) + if stream: + response = cast(Iterator[CreateChatCompletionStreamResponse], response) + num_tool_calls = 0 + for chunk in response: + check_type(chunk, CreateChatCompletionStreamResponse) + tool_calls = chunk["choices"][0]["delta"].get("tool_calls") + if isinstance(tool_calls, list): + num_tool_calls = max(tool_call["index"] for tool_call in tool_calls) + 1 + assert num_tool_calls == (expected_tool_calls if tool_choice != "none" else 0) + else: + response = cast(CreateChatCompletionResponse, response) + check_type( + response, CreateChatCompletionResponse, forward_ref_policy=ForwardRefPolicy.IGNORE + ) + if expected_tool_calls == 0 or tool_choice == "none": + assert response["choices"][0]["message"].get("tool_calls") is None + else: + assert len(response["choices"][0]["message"]["tool_calls"]) == expected_tool_calls + assert all( + tool_call["function"]["name"] == tools[0]["function"]["name"] + for tool_call in response["choices"][0]["message"]["tool_calls"] + ) diff --git a/tests/test_rag.py b/tests/test_rag.py index 2fb8a1c..b7d6280 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -1,10 +1,13 @@ """Test RAGLite's RAG functionality.""" +import json + from raglite import ( RAGLiteConfig, create_rag_instruction, retrieve_rag_context, ) +from raglite._database import ChunkSpan from raglite._rag import rag @@ -22,3 +25,38 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None: assert "event" in answer.lower() # Verify that no RAG context was retrieved through tool use. assert [message["role"] for message in messages] == ["user", "assistant"] + + +def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: + """Test Retrieval-Augmented Generation with automatic retrieval.""" + # Answer a question that requires RAG. + user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?" 
+ messages = [{"role": "user", "content": user_prompt}] + chunk_spans = [] + stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config) + answer = "" + for update in stream: + assert isinstance(update, str) + answer += update + assert "event" in answer.lower() + # Verify that RAG context was retrieved automatically. + assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"] + assert json.loads(messages[-2]["content"]) + assert chunk_spans + assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans) + + +def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None: + """Test Retrieval-Augmented Generation with automatic retrieval.""" + # Answer a question that does not require RAG. + user_prompt = "Is 7 a prime number?" + messages = [{"role": "user", "content": user_prompt}] + chunk_spans = [] + stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config) + answer = "" + for update in stream: + assert isinstance(update, str) + answer += update + # Verify that no RAG context was retrieved. + assert [message["role"] for message in messages] == ["user", "assistant"] + assert not chunk_spans