any_llm_client/clients/openai.py (14 changes: 8 additions & 6 deletions)
@@ -166,6 +166,12 @@ def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn
    raise LLMError(response_content=content)


def _handle_validation_error(*, content: bytes, original_error: pydantic.ValidationError) -> typing.NoReturn:
    if b"is too long to fit into the model" in content: # vLLM
        raise OutOfTokensOrSymbolsError(response_content=content)
    raise LLMResponseValidationError(response_content=content, original_error=original_error)


@dataclasses.dataclass(slots=True, init=False)
class OpenAIClient(LLMClient):
    config: OpenAIConfig
@@ -243,9 +249,7 @@ async def request_llm_message(
                ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message
            )
        except pydantic.ValidationError as validation_error:
            raise LLMResponseValidationError(
                response_content=response.content, original_error=validation_error
            ) from validation_error
            _handle_validation_error(content=response.content, original_error=validation_error)
        finally:
            await response.aclose()

@@ -262,9 +266,7 @@ async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncI
            try:
                validated_response = ChatCompletionsStreamingEvent.model_validate_json(event.data)
            except pydantic.ValidationError as validation_error:
                raise LLMResponseValidationError(
                    response_content=event.data.encode(), original_error=validation_error
                ) from validation_error
                _handle_validation_error(content=event.data.encode(), original_error=validation_error)

            if not (
                (validated_delta := validated_response.choices[0].delta)
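The change above routes pydantic validation failures through `_handle_validation_error`, so vLLM's context-overflow message is recognized even when it arrives in a 200 body rather than as a 400 status. Below is a minimal, self-contained sketch (not part of this diff) of that dispatch, using simplified stand-in exception classes and a stand-in response schema; the real ones live in any_llm_client and are not reproduced here.

```python
# Standalone sketch of the dispatch added in openai.py; the exception classes and
# the response schema below are simplified stand-ins, not the library's own.
import typing

import pydantic


class LLMResponseValidationError(Exception):  # stand-in
    def __init__(self, *, response_content: bytes, original_error: pydantic.ValidationError) -> None:
        super().__init__(response_content)
        self.original_error = original_error


class OutOfTokensOrSymbolsError(Exception):  # stand-in
    def __init__(self, *, response_content: bytes) -> None:
        super().__init__(response_content)


def _handle_validation_error(*, content: bytes, original_error: pydantic.ValidationError) -> typing.NoReturn:
    # vLLM reports context overflow inside the response body, not via a dedicated status code.
    if b"is too long to fit into the model" in content:
        raise OutOfTokensOrSymbolsError(response_content=content)
    raise LLMResponseValidationError(response_content=content, original_error=original_error)


class _FakeResponseSchema(pydantic.BaseModel):  # stand-in for ChatCompletionsNotStreamingResponse
    choices: list[dict[str, typing.Any]]


# A trimmed vLLM error payload returned with HTTP 200; it fails schema validation.
vllm_error_body: typing.Final = (
    b'{"object": "error", "message": "The prompt (total length 43431) is too long '
    b'to fit into the model (context length 8192).", "type": "BadRequestError", "param": null, "code": 400}'
)

try:
    _FakeResponseSchema.model_validate_json(vllm_error_body)
except pydantic.ValidationError as validation_error:
    try:
        _handle_validation_error(content=vllm_error_body, original_error=validation_error)
    except OutOfTokensOrSymbolsError:
        print("classified as context overflow")  # this branch is taken for the body above
```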
tests/test_openai_client.py (30 changes: 29 additions & 1 deletion)
@@ -149,7 +149,7 @@ async def test_fails_with_unknown_error(self, stream: bool, status_code: int) ->
            b'{"object":"error","message":"This model\'s maximum context length is 16384 tokens. However, you requested 100000 tokens in the messages, Please reduce the length of the messages.","type":"BadRequestError","param":null,"code":400}', # noqa: E501
        ],
    )
    async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes | None) -> None:
    async def test_fails_with_out_of_tokens_error_on_status(self, stream: bool, content: bytes) -> None:
        response: typing.Final = httpx.Response(400, content=content)
        client: typing.Final = any_llm_client.get_client(
            OpenAIConfigFactory.build(),
@@ -165,6 +165,34 @@ async def test_fails_with_out_of_tokens_error(self, stream: bool, content: bytes
        with pytest.raises(any_llm_client.OutOfTokensOrSymbolsError):
            await coroutine

@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize(
"content",
[
b'{"error": {"object": "error", "message": "The prompt (total length 6287) is too long to fit into the model (context length 4096). Make sure that `max_model_len` is no smaller than the number of text tokens plus multimodal tokens. For image inputs, the number of image tokens depends on the number of images, and possibly their aspect ratios as well.", "type": "BadRequestError", "param": null, "code": 400}}\n', # noqa: E501
b'{"object": "error", "message": "The prompt (total length 43431) is too long to fit into the model (context length 8192). Make sure that `max_model_len` is no smaller than the number of text tokens plus multimodal tokens. For image inputs, the number of image tokens depends on the number of images, and possibly their aspect ratios as well.", "type": "BadRequestError", "param": null, "code": 400}\n', # noqa: E501
],
)
async def test_fails_with_out_of_tokens_error_on_validation(self, stream: bool, content: bytes) -> None:
response: typing.Final = httpx.Response(
200,
content=f"data: {content.decode()}\n\n" if stream else content,
headers={"Content-Type": "text/event-stream"} if stream else None,
)
client: typing.Final = any_llm_client.get_client(
OpenAIConfigFactory.build(),
transport=httpx.MockTransport(lambda _: response),
)

coroutine: typing.Final = (
consume_llm_message_chunks(client.stream_llm_message_chunks(**LLMFuncRequestFactory.build()))
if stream
else client.request_llm_message(**LLMFuncRequestFactory.build())
)

with pytest.raises(any_llm_client.OutOfTokensOrSymbolsError):
await coroutine


class TestOpenAIMessageAlternation:
    @pytest.mark.parametrize(
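For callers, the practical effect is that `any_llm_client.OutOfTokensOrSymbolsError` now also covers vLLM "prompt is too long" bodies that fail response validation on a 200 status, not only 400 responses. A hedged usage sketch, assuming a client obtained from `any_llm_client.get_client(...)` and leaving the request parameters unspecified, since they are not part of this diff:

```python
# Hypothetical caller-side handling; `client` is assumed to come from
# any_llm_client.get_client(...) and `request_kwargs` stands for whatever
# message parameters the caller normally passes.
import typing

import any_llm_client


async def request_or_none(client: typing.Any, **request_kwargs: typing.Any) -> typing.Any:
    try:
        return await client.request_llm_message(**request_kwargs)
    except any_llm_client.OutOfTokensOrSymbolsError:
        # With this change, vLLM "is too long to fit into the model" bodies returned
        # with a 200 status are classified here too, not only 400 status errors.
        return None
```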