3 changes: 3 additions & 0 deletions README.md
@@ -182,6 +182,9 @@ async with any_llm_client.OpenAIClient(config, ...) as client:

- `any_llm_client.LLMError` or `any_llm_client.OutOfTokensOrSymbolsError` when the LLM API responds with a failed HTTP status,
- `any_llm_client.LLMRequestValidationError` when images are passed to the YandexGPT client,
- `any_llm_client.LLMResponseValidationError` when an invalid response comes from the LLM API (raised from the underlying `pydantic.ValidationError`).

All these exceptions inherit from the base class `any_llm_client.AnyLLMClientError`.

#### Timeouts, proxy & other HTTP settings

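For orientation, here is a minimal sketch of how a caller might handle this hierarchy. The config values and the `request_llm_message` call shape are assumptions based on the README excerpt above, not something this diff defines; the URL is a placeholder for a local OpenAI-compatible server.

```python
import asyncio

import any_llm_client


async def main() -> None:
    # Hypothetical config; the README excerpt elides the real arguments.
    config = any_llm_client.OpenAIConfig(url="http://127.0.0.1:11434/v1/chat/completions", model_name="qwen2.5")
    async with any_llm_client.OpenAIClient(config) as client:
        try:
            print(await client.request_llm_message("Hello!"))  # assumed call shape
        except any_llm_client.LLMResponseValidationError as error:
            print("Invalid LLM response:", error.original_error)
        except any_llm_client.AnyLLMClientError as error:
            # The new base class catches LLMError, OutOfTokensOrSymbolsError,
            # LLMRequestValidationError and LLMResponseValidationError alike.
            print("LLM request failed:", error)


asyncio.run(main())
```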
4 changes: 4 additions & 0 deletions any_llm_client/__init__.py
@@ -3,6 +3,7 @@
from any_llm_client.clients.yandexgpt import YandexGPTClient, YandexGPTConfig
from any_llm_client.core import (
AnyContentItem,
AnyLLMClientError,
AssistantMessage,
ContentItemList,
ImageContentItem,
@@ -11,6 +12,7 @@
LLMError,
LLMRequestValidationError,
LLMResponse,
LLMResponseValidationError,
Message,
MessageRole,
OutOfTokensOrSymbolsError,
@@ -24,6 +26,7 @@

__all__ = [
"AnyContentItem",
"AnyLLMClientError",
"AnyLLMConfig",
"AssistantMessage",
"ContentItemList",
@@ -33,6 +36,7 @@
"LLMError",
"LLMRequestValidationError",
"LLMResponse",
"LLMResponseValidationError",
"Message",
"MessageRole",
"MockLLMClient",
25 changes: 20 additions & 5 deletions any_llm_client/clients/openai.py
@@ -17,6 +17,7 @@
LLMConfigValue,
LLMError,
LLMResponse,
LLMResponseValidationError,
Message,
MessageRole,
OutOfTokensOrSymbolsError,
@@ -236,27 +237,41 @@ async def request_llm_message(
)
except httpx.HTTPStatusError as exception:
_handle_status_error(status_code=exception.response.status_code, content=exception.response.content)

try:
validated_message_model: typing.Final = (
ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message
)
return LLMResponse(
content=validated_message_model.content,
reasoning_content=validated_message_model.reasoning_content,
)
except pydantic.ValidationError as validation_error:
raise LLMResponseValidationError(
response_content=response.content, original_error=validation_error
) from validation_error
finally:
await response.aclose()

return LLMResponse(
content=validated_message_model.content,
reasoning_content=validated_message_model.reasoning_content,
)

async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncIterable[LLMResponse]:
async for event in httpx_sse.EventSource(response).aiter_sse():
if event.data == "[DONE]":
break
validated_response = ChatCompletionsStreamingEvent.model_validate_json(event.data)

try:
validated_response = ChatCompletionsStreamingEvent.model_validate_json(event.data)
except pydantic.ValidationError as validation_error:
raise LLMResponseValidationError(
response_content=response.content, original_error=validation_error
) from validation_error

if not (
(validated_delta := validated_response.choices[0].delta)
and (validated_delta.content or validated_delta.reasoning_content)
):
continue

yield LLMResponse(content=validated_delta.content, reasoning_content=validated_delta.reasoning_content)

@contextlib.asynccontextmanager
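Both clients now follow the same wrap-and-reraise pattern: validate with pydantic, and on failure raise the library error chained to the original one (moving the `return` inside the `try` lets the new `except` clause see the validation failure while `finally` still closes the response). A self-contained sketch of the idea, with stand-in names rather than the library's own classes:

```python
import pydantic


class ExampleResponse(pydantic.BaseModel):
    content: str


class ExampleValidationError(Exception):
    """Stand-in for LLMResponseValidationError."""

    def __init__(self, response_content: bytes, original_error: pydantic.ValidationError) -> None:
        super().__init__(response_content, original_error)
        self.response_content = response_content
        self.original_error = original_error


def validate_payload(raw: bytes) -> ExampleResponse:
    try:
        return ExampleResponse.model_validate_json(raw)
    except pydantic.ValidationError as validation_error:
        # `raise ... from` keeps the pydantic traceback chained onto the library error.
        raise ExampleValidationError(raw, validation_error) from validation_error


validate_payload(b'{"content": "ok"}')  # passes
# validate_payload(b'{"oops": 1}')  # would raise ExampleValidationError
```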
20 changes: 16 additions & 4 deletions any_llm_client/clients/yandexgpt.py
@@ -18,6 +18,7 @@
LLMError,
LLMRequestValidationError,
LLMResponse,
LLMResponseValidationError,
Message,
MessageRole,
OutOfTokensOrSymbolsError,
@@ -173,14 +174,25 @@ async def request_llm_message(
except httpx.HTTPStatusError as exception:
_handle_status_error(status_code=exception.response.status_code, content=exception.response.content)

return LLMResponse(
content=YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text,
)
try:
validated_response: typing.Final = YandexGPTResponse.model_validate_json(response.content)
except pydantic.ValidationError as validation_error:
raise LLMResponseValidationError(
response_content=response.content, original_error=validation_error
) from validation_error

return LLMResponse(content=validated_response.result.alternatives[0].message.text)

async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncIterable[LLMResponse]:
previous_cursor = 0
async for one_line in response.aiter_lines():
validated_response = YandexGPTResponse.model_validate_json(one_line)
try:
validated_response = YandexGPTResponse.model_validate_json(one_line)
except pydantic.ValidationError as validation_error:
raise LLMResponseValidationError(
response_content=response.content, original_error=validation_error
) from validation_error

response_text = validated_response.result.alternatives[0].message.text
yield LLMResponse(content=response_text[previous_cursor:])
previous_cursor = len(response_text)
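One detail worth noting in `_iter_response_chunks` above: YandexGPT streams cumulative text, so the client tracks `previous_cursor` and yields only the new suffix of each payload. A tiny standalone sketch of that logic:

```python
import typing


def cumulative_to_deltas(cumulative_texts: typing.Iterable[str]) -> typing.Iterator[str]:
    # Each streamed payload contains the whole message so far;
    # yield only the tail that has not been seen before.
    previous_cursor = 0
    for text in cumulative_texts:
        yield text[previous_cursor:]
        previous_cursor = len(text)


assert list(cumulative_to_deltas(["He", "Hello", "Hello, world"])) == ["He", "llo", ", world"]
```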
6 changes: 6 additions & 0 deletions any_llm_client/core.py
@@ -147,3 +147,9 @@ class OutOfTokensOrSymbolsError(LLMError): ...
@dataclasses.dataclass
class LLMRequestValidationError(AnyLLMClientError):
message: str


@dataclasses.dataclass
class LLMResponseValidationError(AnyLLMClientError):
response_content: bytes
original_error: pydantic.ValidationError
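Because `LLMResponseValidationError` is a dataclass, both fields are available on the caught instance (and appear in its generated repr). A short sketch of inspecting them; the elided call is hypothetical:

```python
import any_llm_client

try:
    ...  # some `await client.request_llm_message(...)` call
except any_llm_client.LLMResponseValidationError as error:
    print(error.response_content)  # raw bytes returned by the LLM API
    for issue in error.original_error.errors():  # pydantic's structured report
        print(issue["loc"], issue["msg"])
```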
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ dynamic = ["version"]
dev = [
"anyio",
"faker",
"polyfactory",
"polyfactory==2.20.0",
"pydantic-settings",
"pytest-cov",
"pytest",
6 changes: 3 additions & 3 deletions tests/test_openai_client.py
@@ -2,7 +2,6 @@

import faker
import httpx
import pydantic
import pytest
from polyfactory.factories.pydantic_factory import ModelFactory

@@ -17,6 +16,7 @@
OneStreamingChoice,
OneStreamingChoiceDelta,
)
from any_llm_client.core import LLMResponseValidationError
from tests.conftest import LLMFuncRequest, LLMFuncRequestFactory, consume_llm_message_chunks


@@ -58,7 +58,7 @@ async def test_fails_without_alternatives(self) -> None:
transport=httpx.MockTransport(lambda _: response),
)

with pytest.raises(pydantic.ValidationError):
with pytest.raises(LLMResponseValidationError):
await client.request_llm_message(**LLMFuncRequestFactory.build())


@@ -118,7 +118,7 @@ async def test_fails_without_alternatives(self) -> None:
transport=httpx.MockTransport(lambda _: response),
)

with pytest.raises(pydantic.ValidationError):
with pytest.raises(LLMResponseValidationError):
await consume_llm_message_chunks(client.stream_llm_message_chunks(**LLMFuncRequestFactory.build()))


5 changes: 3 additions & 2 deletions tests/test_yandexgpt_client.py
@@ -9,6 +9,7 @@

import any_llm_client
from any_llm_client.clients.yandexgpt import YandexGPTAlternative, YandexGPTMessage, YandexGPTResponse, YandexGPTResult
from any_llm_client.core import LLMResponseValidationError
from tests.conftest import LLMFuncRequest, LLMFuncRequestFactory, consume_llm_message_chunks


@@ -95,7 +96,7 @@ async def test_fails_without_alternatives(self) -> None:
transport=httpx.MockTransport(lambda _: response),
)

with pytest.raises(pydantic.ValidationError):
with pytest.raises(LLMResponseValidationError):
await client.request_llm_message(**LLMFuncRequestWithTextContentMessagesFactory.build())


@@ -153,7 +154,7 @@ async def test_fails_without_alternatives(self) -> None:
transport=httpx.MockTransport(lambda _: response),
)

with pytest.raises(pydantic.ValidationError):
with pytest.raises(LLMResponseValidationError):
await consume_llm_message_chunks(
client.stream_llm_message_chunks(**LLMFuncRequestWithTextContentMessagesFactory.build()),
)
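If a test also wants to assert on what got wrapped, `pytest.raises` exposes the instance via `exc_info`. A small sketch reusing the fixtures from the tests above (it assumes `pydantic` is imported); this is not part of the diff:

```python
with pytest.raises(LLMResponseValidationError) as exc_info:
    await client.request_llm_message(**LLMFuncRequestWithTextContentMessagesFactory.build())

assert isinstance(exc_info.value.original_error, pydantic.ValidationError)
assert exc_info.value.response_content  # raw payload is kept for debugging
```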