diff --git a/garak/generators/openai.py b/garak/generators/openai.py
index c120387c7..88dda858d 100644
--- a/garak/generators/openai.py
+++ b/garak/generators/openai.py
@@ -14,6 +14,7 @@
 import json
 import logging
 import re
+import tiktoken
 from typing import List, Union
 
 import openai
@@ -113,12 +114,24 @@
     "gpt-4o": 128000,
     "gpt-4o-2024-05-13": 128000,
     "gpt-4o-2024-08-06": 128000,
-    "gpt-4o-mini": 16384,
+    "gpt-4o-mini": 128000,
     "gpt-4o-mini-2024-07-18": 16384,
-    "o1-mini": 65536,
+    "o1": 200000,
+    "o1-mini": 128000,
     "o1-mini-2024-09-12": 65536,
     "o1-preview": 32768,
     "o1-preview-2024-09-12": 32768,
+    "o3-mini": 200000,
+}
+
+output_max = {
+    "gpt-3.5-turbo": 4096,
+    "gpt-4": 8192,
+    "gpt-4o": 16384,
+    "o3-mini": 100000,
+    "o1": 100000,
+    "o1-mini": 65536,
+    "gpt-4o-mini": 16384,
 }
 
 
@@ -171,6 +184,75 @@ def _clear_client(self):
     def _validate_config(self):
         pass
 
+    def _validate_token_args(self, create_args: dict, prompt: str) -> dict:
+        """Ensure maximum token limit compatibility with OpenAI create request"""
+        token_limit_key = "max_tokens"
+        fixed_cost = 0
+        if (
+            self.generator == self.client.chat.completions
+            and self.max_tokens is not None
+        ):
+            token_limit_key = "max_completion_tokens"
+            if not hasattr(self, "max_completion_tokens"):
+                create_args["max_completion_tokens"] = self.max_tokens
+
+            create_args.pop(
+                "max_tokens", None
+            )  # remove deprecated value, utilize `max_completion_tokens`
+            # every reply is primed with <|start|>assistant<|message|> (3 toks) plus 1 for name change
+            # see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+            # section 6 "Counting tokens for chat completions API calls"
+            fixed_cost = 7
+
+        # basic token boundary validation to ensure requests are not rejected for exceeding target context length
+        token_limit = create_args.pop(token_limit_key, None)
+        if token_limit is not None:
+            # Suppress max_tokens if greater than context_len
+            if (
+                hasattr(self, "context_len")
+                and self.context_len is not None
+                and token_limit > self.context_len
+            ):
+                logging.warning(
+                    f"Requested garak maximum tokens {token_limit} exceeds context length {self.context_len}, no limit will be applied to the request"
+                )
+                token_limit = None
+
+            if self.name in output_max and token_limit > output_max[self.name]:
+                logging.warning(
+                    f"Requested maximum tokens {token_limit} exceeds max output {output_max[self.name]}, no limit will be applied to the request"
+                )
+                token_limit = None
+
+            if self.context_len is not None and token_limit is not None:
+                # count tokens in prompt and ensure token_limit requested is <= context_len or output_max allowed
+                prompt_tokens = 0  # this should apply to messages object
+                try:
+                    encoding = tiktoken.encoding_for_model(self.name)
+                    prompt_tokens = len(encoding.encode(prompt))
+                except KeyError as e:
+                    prompt_tokens = int(
+                        len(prompt.split()) * 4 / 3
+                    )  # extra naive fallback 1 token ~= 3/4 of a word
+
+                if (prompt_tokens + fixed_cost + token_limit > self.context_len) and (
+                    prompt_tokens + fixed_cost < self.context_len
+                ):
+                    token_limit = self.context_len - prompt_tokens - fixed_cost
+                elif token_limit > prompt_tokens + fixed_cost:
+                    token_limit = token_limit - prompt_tokens - fixed_cost
+                else:
+                    raise garak.exception.GarakException(
+                        "A response of %s toks plus prompt %s toks cannot be generated; API capped at context length %s toks"
+                        % (
+                            self.max_tokens,
+                            prompt_tokens + fixed_cost,
+                            self.context_len,
+                        )
+                    )
+            create_args[token_limit_key] = token_limit
+        return create_args
+
     def __init__(self, name="", config_root=_config):
         self.name = name
         self._load_config(config_root)
@@ -216,13 +298,16 @@ def _call_model(
         create_args = {}
         if "n" not in self.suppressed_params:
             create_args["n"] = generations_this_call
-        for arg in inspect.signature(self.generator.create).parameters:
+        create_params = inspect.signature(self.generator.create).parameters
+        for arg in create_params:
             if arg == "model":
                 create_args[arg] = self.name
                 continue
             if hasattr(self, arg) and arg not in self.suppressed_params:
                 create_args[arg] = getattr(self, arg)
 
+        create_args = self._validate_token_args(create_args, prompt)
+
         if self.generator == self.client.completions:
             if not isinstance(prompt, str):
                 msg = (
diff --git a/tests/generators/test_openai_compatible.py b/tests/generators/test_openai_compatible.py
index db676da5c..342029194 100644
--- a/tests/generators/test_openai_compatible.py
+++ b/tests/generators/test_openai_compatible.py
@@ -9,14 +9,18 @@
 import inspect
 
 from collections.abc import Iterable
 
-from garak.generators.openai import OpenAICompatible
+from garak.generators.openai import OpenAICompatible, output_max, context_lengths
 
 # TODO: expand this when we have faster loading, currently to process all generator costs 30s for 3 tests
 # GENERATORS = [
 #     classname for (classname, active) in _plugins.enumerate_plugins("generators")
 # ]
-GENERATORS = ["generators.openai.OpenAIGenerator", "generators.nim.NVOpenAIChat", "generators.groq.GroqChat"]
+GENERATORS = [
+    "generators.openai.OpenAIGenerator",
+    "generators.nim.NVOpenAIChat",
+    "generators.groq.GroqChat",
+]
 
 MODEL_NAME = "gpt-3.5-turbo-instruct"
 ENV_VAR = os.path.abspath(
@@ -98,3 +102,113 @@ def test_openai_multiprocessing(openai_compat_mocks, classname):
     with Pool(parallel_attempts) as attempt_pool:
         for result in attempt_pool.imap_unordered(generate_in_subprocess, prompts):
             assert result is not None
+
+
+def test_validate_call_model_chat_token_restrictions(openai_compat_mocks):
+    import lorem
+    import json
+    import tiktoken
+    from garak.exception import GarakException
+
+    generator = build_test_instance(OpenAICompatible)
+    mock_url = getattr(generator, "uri", "https://api.openai.com/v1")
+    with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock:
+        mock_response = openai_compat_mocks["chat"]
+        respx_mock.post("chat/completions").mock(
+            return_value=httpx.Response(
+                mock_response["code"], json=mock_response["json"]
+            )
+        )
+        generator._call_model("test values")
+        req_body = json.loads(respx_mock.routes[0].calls[0].request.content)
+        assert (
+            req_body["max_completion_tokens"] <= generator.max_tokens
+        ), "request max_completion_tokens must account for prompt tokens"
+
+        test_large_context = ""
+        encoding = tiktoken.encoding_for_model(MODEL_NAME)
+        while len(encoding.encode(test_large_context)) < generator.max_tokens:
+            test_large_context += "\n".join(lorem.paragraph())
+        large_context_len = len(encoding.encode(test_large_context))
+
+        generator.context_len = large_context_len * 2
+        generator.max_tokens = generator.context_len * 2
+        generator._call_model("test values")
+        req_body = json.loads(respx_mock.routes[0].calls[1].request.content)
+        assert (
+            req_body.get("max_completion_tokens", None) is None
+            and req_body.get("max_tokens", None) is None
+        ), "request max_completion_tokens is suppressed when larger than context length"
+
+        generator.max_tokens = large_context_len - int(large_context_len / 2)
+        generator.context_len = large_context_len
+        with pytest.raises(GarakException) as exc_info:
+            generator._call_model(test_large_context)
+        assert "API capped" in str(
+            exc_info.value
+        ), "a prompt larger than max_tokens must raise exception"
+
+        max_output_model = "gpt-3.5-turbo"
+        generator.name = max_output_model
+        generator.max_tokens = output_max[max_output_model] * 2
+        generator.context_len = generator.max_tokens * 2
+        generator._call_model("test values")
+        req_body = json.loads(respx_mock.routes[0].calls[2].request.content)
+        assert (
+            req_body.get("max_completion_tokens", None) is None
+            and req_body.get("max_tokens", None) is None
+        ), "request max_completion_tokens is suppressed when larger than output_max limited known model"
+
+        generator.max_completion_tokens = int(output_max[max_output_model] / 2)
+        generator._call_model("test values")
+        req_body = json.loads(respx_mock.routes[0].calls[3].request.content)
+        assert (
+            req_body["max_completion_tokens"] < generator.max_completion_tokens
+            and req_body.get("max_tokens", None) is None
+        ), "explicit max_completion_tokens must be reduced to account for prompt tokens"
+
+
+def test_validate_call_model_completion_token_restrictions(openai_compat_mocks):
+    import lorem
+    import json
+    import tiktoken
+    from garak.exception import GarakException
+
+    generator = build_test_instance(OpenAICompatible)
+    generator._load_client()
+    generator.generator = generator.client.completions
+    mock_url = getattr(generator, "uri", "https://api.openai.com/v1")
+    with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock:
+        mock_response = openai_compat_mocks["completion"]
+        respx_mock.post("/completions").mock(
+            return_value=httpx.Response(
+                mock_response["code"], json=mock_response["json"]
+            )
+        )
+        generator._call_model("test values")
+        req_body = json.loads(respx_mock.routes[0].calls[0].request.content)
+        assert (
+            req_body["max_tokens"] <= generator.max_tokens
+        ), "request max_tokens must account for prompt tokens"
+
+        test_large_context = ""
+        encoding = tiktoken.encoding_for_model(MODEL_NAME)
+        while len(encoding.encode(test_large_context)) < generator.max_tokens:
+            test_large_context += "\n".join(lorem.paragraph())
+        large_context_len = len(encoding.encode(test_large_context))
+
+        generator.context_len = large_context_len * 2
+        generator.max_tokens = generator.context_len * 2
+        generator._call_model("test values")
+        req_body = json.loads(respx_mock.routes[0].calls[1].request.content)
+        assert (
+            req_body.get("max_tokens", None) is None
+        ), "request max_tokens is suppressed when larger than context length"
+
+        generator.max_tokens = large_context_len - int(large_context_len / 2)
+        generator.context_len = large_context_len
+        with pytest.raises(GarakException) as exc_info:
+            generator._call_model(test_large_context)
+        assert "API capped" in str(
+            exc_info.value
+        ), "a prompt larger than max_tokens must raise exception"
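
Note, not part of the patch: the sketch below restates the chat-path arithmetic of `_validate_token_args` as a standalone helper, so the intended request cap is easier to reason about. The model's context length, the per-model `output_max` ceiling, the tiktoken-counted prompt, and the fixed 7-token chat priming cost all feed into the `max_completion_tokens` value that ends up on the wire. The helper name `estimate_completion_cap` and its flat signature are hypothetical; in the patch this logic runs on the generator instance and raises `garak.exception.GarakException` instead of `ValueError`.

import tiktoken

CHAT_FIXED_COST = 7  # reply priming overhead the patch applies to chat completions


def estimate_completion_cap(model, prompt, requested, context_len, output_max):
    """Return the max_completion_tokens value the patch would send, or None if suppressed."""
    if requested > context_len:
        return None  # patch logs a warning and drops the limit entirely
    if model in output_max and requested > output_max[model]:
        return None  # same when the known per-model output ceiling is exceeded
    try:
        prompt_tokens = len(tiktoken.encoding_for_model(model).encode(prompt))
    except KeyError:
        prompt_tokens = int(len(prompt.split()) * 4 / 3)  # naive fallback, as in the patch
    overhead = prompt_tokens + CHAT_FIXED_COST
    if overhead + requested > context_len and overhead < context_len:
        return context_len - overhead  # shrink the request so it fits the context window
    elif requested > overhead:
        return requested - overhead  # patch also deducts prompt cost from a fitting request
    # stands in for garak.exception.GarakException in the patch
    raise ValueError("prompt cannot fit in the context window alongside a response")


# e.g. estimate_completion_cap("gpt-4o-mini", "Hello there", 200, 128000, {"gpt-4o-mini": 16384})
# returns 200 - (prompt tokens + 7), mirroring the patch's second branch.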