From fa823b0b4a31b3dfc32d6d6d2123e35a6259d4b6 Mon Sep 17 00:00:00 2001 From: Jeffrey Martin Date: Tue, 18 Feb 2025 12:28:08 -0600 Subject: [PATCH 1/8] estimate token use before sending openai completions Signed-off-by: Jeffrey Martin --- garak/generators/openai.py | 27 +++++++++++++ tests/generators/test_openai_compatible.py | 45 +++++++++++++++++++++- 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/garak/generators/openai.py b/garak/generators/openai.py index c120387c7..ab5a0d7cf 100644 --- a/garak/generators/openai.py +++ b/garak/generators/openai.py @@ -14,6 +14,7 @@ import json import logging import re +import tiktoken from typing import List, Union import openai @@ -223,6 +224,32 @@ def _call_model( if hasattr(self, arg) and arg not in self.suppressed_params: create_args[arg] = getattr(self, arg) + # basic token boundary validation to ensure requests are not rejected for exceeding target context length + generation_max_tokens = create_args.get("max_tokens", None) + if generation_max_tokens is not None: + # count tokens in prompt and ensure max_tokens requested is <= context_len allowed + if ( + hasattr(self, "context_len") + and self.context_len is not None + and generation_max_tokens > self.context_len + ): + logging.warning( + f"Requested max_tokens {generation_max_tokens} exceeds context length {self.context_len}, reducing requested maximum" + ) + generation_max_tokens = self.context_len + prompt_tokens = 0 + try: + encoding = tiktoken.encoding_for_model(self.name) + prompt_tokens = len(encoding.encode(prompt)) + except KeyError as e: + prompt_tokens = len(prompt.split()) # extra naive fallback + generation_max_tokens -= prompt_tokens + create_args["max_tokens"] = generation_max_tokens + if generation_max_tokens < 1: # allow at least a binary result token + raise garak.exception.GarakException( + "A response cannot be created within the available context length" + ) + if self.generator == self.client.completions: if not isinstance(prompt, str): msg = ( diff --git a/tests/generators/test_openai_compatible.py b/tests/generators/test_openai_compatible.py index db676da5c..4a54ce5b0 100644 --- a/tests/generators/test_openai_compatible.py +++ b/tests/generators/test_openai_compatible.py @@ -16,7 +16,11 @@ # GENERATORS = [ # classname for (classname, active) in _plugins.enumerate_plugins("generators") # ] -GENERATORS = ["generators.openai.OpenAIGenerator", "generators.nim.NVOpenAIChat", "generators.groq.GroqChat"] +GENERATORS = [ + "generators.openai.OpenAIGenerator", + "generators.nim.NVOpenAIChat", + "generators.groq.GroqChat", +] MODEL_NAME = "gpt-3.5-turbo-instruct" ENV_VAR = os.path.abspath( @@ -98,3 +102,42 @@ def test_openai_multiprocessing(openai_compat_mocks, classname): with Pool(parallel_attempts) as attempt_pool: for result in attempt_pool.imap_unordered(generate_in_subprocess, prompts): assert result is not None + + +def test_validate_call_model_token_restrictions(openai_compat_mocks): + import lorem + import json + from garak.exception import GarakException + + generator = build_test_instance(OpenAICompatible) + mock_url = getattr(generator, "uri", "https://api.openai.com/v1") + with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock: + mock_response = openai_compat_mocks["chat"] + respx_mock.post("chat/completions").mock( + return_value=httpx.Response( + mock_response["code"], json=mock_response["json"] + ) + ) + generator._call_model("test values") + resp_body = json.loads(respx_mock.routes[0].calls[0].request.content) + assert ( + 
resp_body["max_tokens"] < generator.max_tokens + ), "request max_tokens must account for prompt tokens" + + test_large_context = "" + while len(test_large_context.split()) < generator.max_tokens: + test_large_context += "\n".join(lorem.paragraph()) + large_context_len = len(test_large_context.split()) + with pytest.raises(GarakException) as exc_info: + generator._call_model(test_large_context) + assert "cannot be created" in str( + exc_info.value + ), "a prompt large then max_tokens must raise exception" + + generator.context_len = large_context_len * 2 + generator.max_tokens = generator.context_len - (large_context_len / 2) + generator._call_model("test values") + resp_body = json.loads(respx_mock.routes[0].calls[1].request.content) + assert ( + resp_body["max_tokens"] < generator.context_len + ), "request max_tokens must me less than model context length" From bcca18b6e2ab48246b7246ef794bd96e421dcb78 Mon Sep 17 00:00:00 2001 From: Jeffrey Martin Date: Wed, 26 Feb 2025 09:26:29 -0600 Subject: [PATCH 2/8] update test failure reasons Signed-off-by: Jeffrey Martin --- tests/generators/test_openai_compatible.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/generators/test_openai_compatible.py b/tests/generators/test_openai_compatible.py index 4a54ce5b0..0431ab425 100644 --- a/tests/generators/test_openai_compatible.py +++ b/tests/generators/test_openai_compatible.py @@ -132,7 +132,7 @@ def test_validate_call_model_token_restrictions(openai_compat_mocks): generator._call_model(test_large_context) assert "cannot be created" in str( exc_info.value - ), "a prompt large then max_tokens must raise exception" + ), "a prompt larger than max_tokens must raise exception" generator.context_len = large_context_len * 2 generator.max_tokens = generator.context_len - (large_context_len / 2) @@ -140,4 +140,4 @@ def test_validate_call_model_token_restrictions(openai_compat_mocks): resp_body = json.loads(respx_mock.routes[0].calls[1].request.content) assert ( resp_body["max_tokens"] < generator.context_len - ), "request max_tokens must me less than model context length" + ), "request max_tokens must be less than model context length" From f7fb4812fc93ccbbc45f3548dbe4d511ec128423 Mon Sep 17 00:00:00 2001 From: Jeffrey Martin Date: Fri, 28 Feb 2025 10:46:20 -0600 Subject: [PATCH 3/8] a little better extra naive fallback Signed-off-by: Jeffrey Martin --- garak/generators/openai.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/garak/generators/openai.py b/garak/generators/openai.py index ab5a0d7cf..3eb656f4b 100644 --- a/garak/generators/openai.py +++ b/garak/generators/openai.py @@ -242,7 +242,9 @@ def _call_model( encoding = tiktoken.encoding_for_model(self.name) prompt_tokens = len(encoding.encode(prompt)) except KeyError as e: - prompt_tokens = len(prompt.split()) # extra naive fallback + prompt_tokens = int( + len(prompt.split()) * 4 / 3 + ) # extra naive fallback 1 token ~= 3/4 of a word generation_max_tokens -= prompt_tokens create_args["max_tokens"] = generation_max_tokens if generation_max_tokens < 1: # allow at least a binary result token From 7ac7349dccef55cb48209f57ef8cfb38dc0ad889 Mon Sep 17 00:00:00 2001 From: Leon Derczynski Date: Wed, 5 Mar 2025 09:46:15 +0100 Subject: [PATCH 4/8] update param to reflect deprecated max_tokens; include chat overhead in calculation; add numbers to exception message; adjust algebra to avoid false firing --- garak/generators/openai.py | 40 +++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 
deletions(-) diff --git a/garak/generators/openai.py b/garak/generators/openai.py index 3eb656f4b..bb5a41547 100644 --- a/garak/generators/openai.py +++ b/garak/generators/openai.py @@ -141,7 +141,7 @@ class OpenAICompatible(Generator): "presence_penalty": 0.0, "seed": None, "stop": ["#", ";"], - "suppressed_params": set(), + "suppressed_params": {"max_tokens"}, # deprecated "retry_json": True, } @@ -224,32 +224,50 @@ def _call_model( if hasattr(self, arg) and arg not in self.suppressed_params: create_args[arg] = getattr(self, arg) + if self.max_tokens is not None and not hasattr(self, "max_completion_tokens"): + create_args["max_completion_tokens"] = self.max_tokens + # basic token boundary validation to ensure requests are not rejected for exceeding target context length - generation_max_tokens = create_args.get("max_tokens", None) - if generation_max_tokens is not None: + max_completion_tokens = create_args.get("max_completion_tokens", None) + if max_completion_tokens is not None: # count tokens in prompt and ensure max_tokens requested is <= context_len allowed if ( hasattr(self, "context_len") and self.context_len is not None - and generation_max_tokens > self.context_len + and max_completion_tokens > self.context_len ): logging.warning( - f"Requested max_tokens {generation_max_tokens} exceeds context length {self.context_len}, reducing requested maximum" + f"Requested garak max_tokens {max_completion_tokens} exceeds context length {self.context_len}, reducing requested maximum" ) - generation_max_tokens = self.context_len - prompt_tokens = 0 + max_completion_tokens = self.context_len + + prompt_tokens = 0 # this should apply to messages object try: encoding = tiktoken.encoding_for_model(self.name) prompt_tokens = len(encoding.encode(prompt)) + print("prompt tokens:", prompt_tokens) except KeyError as e: prompt_tokens = int( len(prompt.split()) * 4 / 3 ) # extra naive fallback 1 token ~= 3/4 of a word - generation_max_tokens -= prompt_tokens - create_args["max_tokens"] = generation_max_tokens - if generation_max_tokens < 1: # allow at least a binary result token + + fixed_cost = 0 + if self.generator == self.client.chat.completions: + # every reply is primed with <|start|>assistant<|message|> (3 toks) plus 1 for name change + # see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + # section 6 "Counting tokens for chat completions API calls" + fixed_cost += 7 + + if self.context_len is not None and ( + prompt_tokens + fixed_cost + max_completion_tokens > self.context_len + ): raise garak.exception.GarakException( - "A response cannot be created within the available context length" + "A response of %s toks plus prompt %s toks cannot be generated; API capped at context length %s toks" + % ( + self.max_tokens, + prompt_tokens + fixed_cost, + self.context_len, + ) # max_completion_tokens will already be adjusted down ) if self.generator == self.client.completions: From 577fcf6181c32f29d26e16de54d3ca55351e72ed Mon Sep 17 00:00:00 2001 From: Leon Derczynski Date: Wed, 5 Mar 2025 09:54:20 +0100 Subject: [PATCH 5/8] include some max output values, correct some ctx len values, reduce max_completion_tokens to max output length if known --- garak/generators/openai.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/garak/generators/openai.py b/garak/generators/openai.py index bb5a41547..b7740e471 100644 --- a/garak/generators/openai.py +++ b/garak/generators/openai.py @@ -114,12 +114,24 @@ 
"gpt-4o": 128000, "gpt-4o-2024-05-13": 128000, "gpt-4o-2024-08-06": 128000, - "gpt-4o-mini": 16384, + "gpt-4o-mini": 128000, "gpt-4o-mini-2024-07-18": 16384, - "o1-mini": 65536, + "o1": 200000, + "o1-mini": 128000, "o1-mini-2024-09-12": 65536, "o1-preview": 32768, "o1-preview-2024-09-12": 32768, + "o3-mini": 200000, +} + +output_max = { + "gpt-3.5-turbo": 4096, + "gpt-4": 8192, + "gpt-4o": 16384, + "o3-mini": 100000, + "o1": 100000, + "o1-mini": 65536, + "gpt-4o-mini": 16384, } @@ -230,6 +242,7 @@ def _call_model( # basic token boundary validation to ensure requests are not rejected for exceeding target context length max_completion_tokens = create_args.get("max_completion_tokens", None) if max_completion_tokens is not None: + # count tokens in prompt and ensure max_tokens requested is <= context_len allowed if ( hasattr(self, "context_len") @@ -240,12 +253,21 @@ def _call_model( f"Requested garak max_tokens {max_completion_tokens} exceeds context length {self.context_len}, reducing requested maximum" ) max_completion_tokens = self.context_len + create_args["max_completion_tokens"] = max_completion_tokens + + if self.name in output_max: + if max_completion_tokens > output_max[self.name]: + + logging.warning( + f"Requested max_completion_tokens {max_completion_tokens} exceeds max output {output_max[self.name]}, reducing requested maximum" + ) + max_completion_tokens = output_max[self.name] + create_args["max_completion_tokens"] = max_completion_tokens prompt_tokens = 0 # this should apply to messages object try: encoding = tiktoken.encoding_for_model(self.name) prompt_tokens = len(encoding.encode(prompt)) - print("prompt tokens:", prompt_tokens) except KeyError as e: prompt_tokens = int( len(prompt.split()) * 4 / 3 From f7a65366694a8efb741d08da26d6254fd9e894c2 Mon Sep 17 00:00:00 2001 From: Leon Derczynski Date: Wed, 5 Mar 2025 10:17:37 +0100 Subject: [PATCH 6/8] formatting --- garak/generators/openai.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/garak/generators/openai.py b/garak/generators/openai.py index b7740e471..764ab2ecf 100644 --- a/garak/generators/openai.py +++ b/garak/generators/openai.py @@ -255,14 +255,15 @@ def _call_model( max_completion_tokens = self.context_len create_args["max_completion_tokens"] = max_completion_tokens - if self.name in output_max: - if max_completion_tokens > output_max[self.name]: - - logging.warning( - f"Requested max_completion_tokens {max_completion_tokens} exceeds max output {output_max[self.name]}, reducing requested maximum" - ) - max_completion_tokens = output_max[self.name] - create_args["max_completion_tokens"] = max_completion_tokens + if ( + self.name in output_max + and max_completion_tokens > output_max[self.name] + ): + logging.warning( + f"Requested max_completion_tokens {max_completion_tokens} exceeds max output {output_max[self.name]}, reducing requested maximum" + ) + max_completion_tokens = output_max[self.name] + create_args["max_completion_tokens"] = max_completion_tokens prompt_tokens = 0 # this should apply to messages object try: From f1e5b94f4c54c99cb4180516881ec46dddf3e011 Mon Sep 17 00:00:00 2001 From: Leon Derczynski Date: Wed, 5 Mar 2025 14:25:21 +0100 Subject: [PATCH 7/8] update away from deprecated response limit key name --- tests/generators/test_openai_compatible.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generators/test_openai_compatible.py b/tests/generators/test_openai_compatible.py index 0431ab425..bd5e82731 100644 --- 
a/tests/generators/test_openai_compatible.py +++ b/tests/generators/test_openai_compatible.py @@ -121,7 +121,7 @@ def test_validate_call_model_token_restrictions(openai_compat_mocks): generator._call_model("test values") resp_body = json.loads(respx_mock.routes[0].calls[0].request.content) assert ( - resp_body["max_tokens"] < generator.max_tokens + resp_body["max_completion_tokens"] <= generator.max_tokens ), "request max_tokens must account for prompt tokens" test_large_context = "" From 95452e0f3e862ac410305913b0ffbaede35564ba Mon Sep 17 00:00:00 2001 From: Jeffrey Martin Date: Thu, 6 Mar 2025 07:30:52 -0600 Subject: [PATCH 8/8] more refactor to support max_token and max_completion_tokens Signed-off-by: Jeffrey Martin --- garak/generators/openai.py | 131 ++++++++++++--------- tests/generators/test_openai_compatible.py | 95 +++++++++++++-- 2 files changed, 156 insertions(+), 70 deletions(-) diff --git a/garak/generators/openai.py b/garak/generators/openai.py index 764ab2ecf..88dda858d 100644 --- a/garak/generators/openai.py +++ b/garak/generators/openai.py @@ -153,7 +153,7 @@ class OpenAICompatible(Generator): "presence_penalty": 0.0, "seed": None, "stop": ["#", ";"], - "suppressed_params": {"max_tokens"}, # deprecated + "suppressed_params": set(), "retry_json": True, } @@ -184,6 +184,75 @@ def _clear_client(self): def _validate_config(self): pass + def _validate_token_args(self, create_args: dict, prompt: str) -> dict: + """Ensure maximum token limit compatibility with OpenAI create request""" + token_limit_key = "max_tokens" + fixed_cost = 0 + if ( + self.generator == self.client.chat.completions + and self.max_tokens is not None + ): + token_limit_key = "max_completion_tokens" + if not hasattr(self, "max_completion_tokens"): + create_args["max_completion_tokens"] = self.max_tokens + + create_args.pop( + "max_tokens", None + ) # remove deprecated value, utilize `max_completion_tokens` + # every reply is primed with <|start|>assistant<|message|> (3 toks) plus 1 for name change + # see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + # section 6 "Counting tokens for chat completions API calls" + fixed_cost = 7 + + # basic token boundary validation to ensure requests are not rejected for exceeding target context length + token_limit = create_args.pop(token_limit_key, None) + if token_limit is not None: + # Suppress max_tokens if greater than context_len + if ( + hasattr(self, "context_len") + and self.context_len is not None + and token_limit > self.context_len + ): + logging.warning( + f"Requested garak maximum tokens {token_limit} exceeds context length {self.context_len}, no limit will be applied to the request" + ) + token_limit = None + + if self.name in output_max and token_limit > output_max[self.name]: + logging.warning( + f"Requested maximum tokens {token_limit} exceeds max output {output_max[self.name]}, no limit will be applied to the request" + ) + token_limit = None + + if self.context_len is not None and token_limit is not None: + # count tokens in prompt and ensure token_limit requested is <= context_len or output_max allowed + prompt_tokens = 0 # this should apply to messages object + try: + encoding = tiktoken.encoding_for_model(self.name) + prompt_tokens = len(encoding.encode(prompt)) + except KeyError as e: + prompt_tokens = int( + len(prompt.split()) * 4 / 3 + ) # extra naive fallback 1 token ~= 3/4 of a word + + if (prompt_tokens + fixed_cost + token_limit > self.context_len) and ( + prompt_tokens + fixed_cost < 
self.context_len + ): + token_limit = self.context_len - prompt_tokens - fixed_cost + elif token_limit > prompt_tokens + fixed_cost: + token_limit = token_limit - prompt_tokens - fixed_cost + else: + raise garak.exception.GarakException( + "A response of %s toks plus prompt %s toks cannot be generated; API capped at context length %s toks" + % ( + self.max_tokens, + prompt_tokens + fixed_cost, + self.context_len, + ) + ) + create_args[token_limit_key] = token_limit + return create_args + def __init__(self, name="", config_root=_config): self.name = name self._load_config(config_root) @@ -229,69 +298,15 @@ def _call_model( create_args = {} if "n" not in self.suppressed_params: create_args["n"] = generations_this_call - for arg in inspect.signature(self.generator.create).parameters: + create_params = inspect.signature(self.generator.create).parameters + for arg in create_params: if arg == "model": create_args[arg] = self.name continue if hasattr(self, arg) and arg not in self.suppressed_params: create_args[arg] = getattr(self, arg) - if self.max_tokens is not None and not hasattr(self, "max_completion_tokens"): - create_args["max_completion_tokens"] = self.max_tokens - - # basic token boundary validation to ensure requests are not rejected for exceeding target context length - max_completion_tokens = create_args.get("max_completion_tokens", None) - if max_completion_tokens is not None: - - # count tokens in prompt and ensure max_tokens requested is <= context_len allowed - if ( - hasattr(self, "context_len") - and self.context_len is not None - and max_completion_tokens > self.context_len - ): - logging.warning( - f"Requested garak max_tokens {max_completion_tokens} exceeds context length {self.context_len}, reducing requested maximum" - ) - max_completion_tokens = self.context_len - create_args["max_completion_tokens"] = max_completion_tokens - - if ( - self.name in output_max - and max_completion_tokens > output_max[self.name] - ): - logging.warning( - f"Requested max_completion_tokens {max_completion_tokens} exceeds max output {output_max[self.name]}, reducing requested maximum" - ) - max_completion_tokens = output_max[self.name] - create_args["max_completion_tokens"] = max_completion_tokens - - prompt_tokens = 0 # this should apply to messages object - try: - encoding = tiktoken.encoding_for_model(self.name) - prompt_tokens = len(encoding.encode(prompt)) - except KeyError as e: - prompt_tokens = int( - len(prompt.split()) * 4 / 3 - ) # extra naive fallback 1 token ~= 3/4 of a word - - fixed_cost = 0 - if self.generator == self.client.chat.completions: - # every reply is primed with <|start|>assistant<|message|> (3 toks) plus 1 for name change - # see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - # section 6 "Counting tokens for chat completions API calls" - fixed_cost += 7 - - if self.context_len is not None and ( - prompt_tokens + fixed_cost + max_completion_tokens > self.context_len - ): - raise garak.exception.GarakException( - "A response of %s toks plus prompt %s toks cannot be generated; API capped at context length %s toks" - % ( - self.max_tokens, - prompt_tokens + fixed_cost, - self.context_len, - ) # max_completion_tokens will already be adjusted down - ) + create_args = self._validate_token_args(create_args, prompt) if self.generator == self.client.completions: if not isinstance(prompt, str): diff --git a/tests/generators/test_openai_compatible.py b/tests/generators/test_openai_compatible.py index bd5e82731..342029194 
100644 --- a/tests/generators/test_openai_compatible.py +++ b/tests/generators/test_openai_compatible.py @@ -9,7 +9,7 @@ import inspect from collections.abc import Iterable -from garak.generators.openai import OpenAICompatible +from garak.generators.openai import OpenAICompatible, output_max, context_lengths # TODO: expand this when we have faster loading, currently to process all generator costs 30s for 3 tests @@ -104,9 +104,10 @@ def test_openai_multiprocessing(openai_compat_mocks, classname): assert result is not None -def test_validate_call_model_token_restrictions(openai_compat_mocks): +def test_validate_call_model_chat_token_restrictions(openai_compat_mocks): import lorem import json + import tiktoken from garak.exception import GarakException generator = build_test_instance(OpenAICompatible) @@ -119,25 +120,95 @@ def test_validate_call_model_token_restrictions(openai_compat_mocks): ) ) generator._call_model("test values") - resp_body = json.loads(respx_mock.routes[0].calls[0].request.content) + req_body = json.loads(respx_mock.routes[0].calls[0].request.content) assert ( - resp_body["max_completion_tokens"] <= generator.max_tokens - ), "request max_tokens must account for prompt tokens" + req_body["max_completion_tokens"] <= generator.max_tokens + ), "request max_completion_tokens must account for prompt tokens" test_large_context = "" - while len(test_large_context.split()) < generator.max_tokens: + encoding = tiktoken.encoding_for_model(MODEL_NAME) + while len(encoding.encode(test_large_context)) < generator.max_tokens: test_large_context += "\n".join(lorem.paragraph()) - large_context_len = len(test_large_context.split()) + large_context_len = len(encoding.encode(test_large_context)) + + generator.context_len = large_context_len * 2 + generator.max_tokens = generator.context_len * 2 + generator._call_model("test values") + req_body = json.loads(respx_mock.routes[0].calls[1].request.content) + assert ( + req_body.get("max_completion_tokens", None) is None + and req_body.get("max_tokens", None) is None + ), "request max_completion_tokens is suppressed when larger than context length" + + generator.max_tokens = large_context_len - int(large_context_len / 2) + generator.context_len = large_context_len with pytest.raises(GarakException) as exc_info: generator._call_model(test_large_context) - assert "cannot be created" in str( + assert "API capped" in str( exc_info.value ), "a prompt larger than max_tokens must raise exception" + max_output_model = "gpt-3.5-turbo" + generator.name = max_output_model + generator.max_tokens = output_max[max_output_model] * 2 + generator.context_len = generator.max_tokens * 2 + generator._call_model("test values") + req_body = json.loads(respx_mock.routes[0].calls[2].request.content) + assert ( + req_body.get("max_completion_tokens", None) is None + and req_body.get("max_tokens", None) is None + ), "request max_completion_tokens is suppressed when larger than output_max limited known model" + + generator.max_completion_tokens = int(output_max[max_output_model] / 2) + generator._call_model("test values") + req_body = json.loads(respx_mock.routes[0].calls[3].request.content) + assert ( + req_body["max_completion_tokens"] < generator.max_completion_tokens + and req_body.get("max_tokens", None) is None + ), "request max_completion_tokens is suppressed when larger than output_max limited known model" + + +def test_validate_call_model_completion_token_restrictions(openai_compat_mocks): + import lorem + import json + import tiktoken + from garak.exception 
import GarakException + + generator = build_test_instance(OpenAICompatible) + generator._load_client() + generator.generator = generator.client.completions + mock_url = getattr(generator, "uri", "https://api.openai.com/v1") + with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock: + mock_response = openai_compat_mocks["completion"] + respx_mock.post("/completions").mock( + return_value=httpx.Response( + mock_response["code"], json=mock_response["json"] + ) + ) + generator._call_model("test values") + req_body = json.loads(respx_mock.routes[0].calls[0].request.content) + assert ( + req_body["max_tokens"] <= generator.max_tokens + ), "request max_tokens must account for prompt tokens" + + test_large_context = "" + encoding = tiktoken.encoding_for_model(MODEL_NAME) + while len(encoding.encode(test_large_context)) < generator.max_tokens: + test_large_context += "\n".join(lorem.paragraph()) + large_context_len = len(encoding.encode(test_large_context)) + generator.context_len = large_context_len * 2 - generator.max_tokens = generator.context_len - (large_context_len / 2) + generator.max_tokens = generator.context_len * 2 generator._call_model("test values") - resp_body = json.loads(respx_mock.routes[0].calls[1].request.content) + req_body = json.loads(respx_mock.routes[0].calls[1].request.content) assert ( - resp_body["max_tokens"] < generator.context_len - ), "request max_tokens must be less than model context length" + req_body.get("max_tokens", None) is None + ), "request max_tokens is suppressed when larger than context length" + + generator.max_tokens = large_context_len - int(large_context_len / 2) + generator.context_len = large_context_len + with pytest.raises(GarakException) as exc_info: + generator._call_model(test_large_context) + assert "API capped" in str( + exc_info.value + ), "a prompt larger than max_tokens must raise exception"
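
---

For readers skimming the series, the condensed logic below is a standalone sketch of the token-budgeting approach the final patch arrives at: estimate prompt tokens with tiktoken (with a rough words*4/3 fallback), add the fixed chat-priming overhead, drop the requested limit when it exceeds the context window or a known per-model output cap, and otherwise shrink it to fit. It is illustrative only, not garak's actual `_validate_token_args`; the helper name, the small `MODEL_OUTPUT_MAX` table, and the use of ValueError instead of GarakException are assumptions made for a self-contained example.

    # Sketch of the token-budget check built up across patches 1-8.
    # Only tiktoken's public API (encoding_for_model, encode) is assumed to exist.
    import logging
    from typing import Optional

    import tiktoken

    # Illustrative per-model output ceilings, standing in for the patch's output_max table.
    MODEL_OUTPUT_MAX = {"gpt-3.5-turbo": 4096, "gpt-4o": 16384}

    # Chat replies are primed with <|start|>assistant<|message|> plus a name-change token;
    # the patch budgets a fixed 7-token overhead for this (OpenAI cookbook, token counting).
    CHAT_FIXED_COST = 7


    def clamp_completion_tokens(
        model: str,
        prompt: str,
        requested: int,
        context_len: Optional[int],
        chat: bool = True,
    ) -> Optional[int]:
        """Return a completion-token limit that fits the model, or None to send no limit."""
        # Suppress the limit outright when it exceeds what the model can honour.
        if context_len is not None and requested > context_len:
            logging.warning(
                "requested %d exceeds context length %d; dropping limit",
                requested, context_len,
            )
            return None
        if model in MODEL_OUTPUT_MAX and requested > MODEL_OUTPUT_MAX[model]:
            logging.warning(
                "requested %d exceeds model output max %d; dropping limit",
                requested, MODEL_OUTPUT_MAX[model],
            )
            return None

        if context_len is None:
            return requested

        # Estimate prompt tokens with tiktoken, falling back to ~4/3 tokens per word
        # when the model name is unknown to tiktoken (KeyError).
        try:
            encoding = tiktoken.encoding_for_model(model)
            prompt_tokens = len(encoding.encode(prompt))
        except KeyError:
            prompt_tokens = int(len(prompt.split()) * 4 / 3)

        overhead = CHAT_FIXED_COST if chat else 0
        if prompt_tokens + overhead >= context_len:
            # Mirrors the "API capped at context length" failure in the patch,
            # raised here as a plain ValueError for the sake of the sketch.
            raise ValueError(
                f"prompt of {prompt_tokens + overhead} tokens leaves no room for a "
                f"response within a {context_len}-token context"
            )

        # Shrink the limit so prompt + overhead + completion fits the context window.
        return min(requested, context_len - prompt_tokens - overhead)

A caller would place the returned value under `max_completion_tokens` for chat completions (the non-deprecated key) or `max_tokens` for legacy completions, and omit the key entirely when the function returns None, which matches the behaviour the tests in patch 8 assert.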