Merged
Changes from all commits
45 commits
b2bd5c2
Start moving to niquests
vrslev Nov 22, 2024
afe94a8
Fix openai
vrslev Nov 22, 2024
a2af92f
Update
vrslev Nov 22, 2024
f8bcd39
Fix http integration
vrslev Nov 22, 2024
dae44e6
Add abstraction
vrslev Nov 24, 2024
01f7afd
Update
vrslev Nov 24, 2024
614b336
Update
vrslev Nov 24, 2024
817e17b
Update
vrslev Nov 24, 2024
d114a69
Update
vrslev Nov 24, 2024
d56b544
Update
vrslev Nov 24, 2024
a63624d
Update
vrslev Nov 24, 2024
aac2045
Update
vrslev Nov 24, 2024
4ff94aa
Fix typing
vrslev Nov 25, 2024
acca600
Update test
vrslev Nov 25, 2024
6b4a68c
Update
vrslev Nov 25, 2024
fd3beb9
Update
vrslev Nov 25, 2024
571ed37
Update
vrslev Nov 25, 2024
cbfece0
Update
vrslev Nov 25, 2024
f4f7c62
Update
vrslev Nov 25, 2024
8cd3863
Update
vrslev Nov 25, 2024
62d9d36
Update
vrslev Nov 25, 2024
114cf7f
Update
vrslev Nov 25, 2024
bf2acb3
Update
vrslev Nov 25, 2024
b0bfb39
Update
vrslev Nov 25, 2024
403587e
Update
vrslev Nov 25, 2024
b5bc67f
Add integration test for http.py
vrslev Nov 25, 2024
806c2ea
Add test recipe
vrslev Nov 25, 2024
bcb10fd
Update
vrslev Nov 25, 2024
e1be923
Update
vrslev Nov 25, 2024
b711d63
Update
vrslev Nov 25, 2024
3849680
Update
vrslev Nov 25, 2024
df83a5d
Merge branch 'main' into niquests
vrslev Nov 25, 2024
ff814ab
Put sse to http.py
vrslev Nov 25, 2024
1cbb793
Rename httpx to http or niquests
vrslev Nov 25, 2024
14d8661
Update
vrslev Nov 25, 2024
f729d4a
Update
vrslev Nov 25, 2024
192ceec
Update
vrslev Nov 25, 2024
f9d560c
Make HttpClient methods more efficient
vrslev Nov 25, 2024
58d5b3d
Remove asserts
vrslev Nov 25, 2024
1ca5279
Update
vrslev Nov 25, 2024
7dbbdde
Update
vrslev Nov 25, 2024
0db7fc3
Drop Python 3.13 (litestar doesn't support it)
vrslev Nov 25, 2024
ddc43a5
Update setup-uv
vrslev Nov 25, 2024
c8cfa76
Fix tests?
vrslev Nov 25, 2024
a10b0ac
Fix CI tests?
vrslev Nov 25, 2024
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -11,7 +11,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: extractions/setup-just@v2
- uses: astral-sh/setup-uv@v3
- uses: astral-sh/setup-uv@v4
with:
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
7 changes: 3 additions & 4 deletions .github/workflows/test.yml
@@ -17,7 +17,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: extractions/setup-just@v2
- uses: astral-sh/setup-uv@v3
- uses: astral-sh/setup-uv@v4
with:
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
@@ -32,13 +32,12 @@ jobs:
- "3.10"
- "3.11"
- "3.12"
- "3.13"
steps:
- uses: actions/checkout@v4
- uses: extractions/setup-just@v2
- uses: astral-sh/setup-uv@v3
- uses: astral-sh/setup-uv@v4
with:
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
- run: uv python install ${{ matrix.python-version }}
- run: uv venv --python ${{ matrix.python-version }}
- run: just test -vv
10 changes: 10 additions & 0 deletions Justfile
@@ -10,8 +10,18 @@ lint:
uv run --group lint ruff format
uv run --group lint mypy .

_test-no-http *args:
uv run pytest --ignore tests/test_http.py {{ args }}

test *args:
#!/bin/bash
uv run litestar --app tests.testing_app:app run &
APP_PID=$!
uv run pytest {{ args }}
TEST_RESULT=$?
kill $APP_PID
wait $APP_PID 2>/dev/null
exit $TEST_RESULT

publish:
rm -rf dist
10 changes: 5 additions & 5 deletions README.md
@@ -162,23 +162,23 @@ async with any_llm_client.OpenAIClient(config, ...) as client:
#### Timeouts, proxy & other HTTP settings


Pass custom [HTTPX](https://www.python-httpx.org) kwargs to `any_llm_client.get_client()`:
Pass custom [niquests](https://niquests.readthedocs.io) kwargs to `any_llm_client.get_client()`:

```python
import httpx
import urllib3

import any_llm_client


async with any_llm_client.get_client(
...,
mounts={"https://api.openai.com": httpx.AsyncHTTPTransport(proxy="http://localhost:8030")},
timeout=httpx.Timeout(None, connect=5.0),
proxies={"https://api.openai.com": "http://localhost:8030"},
timeout=urllib3.Timeout(total=10.0, connect=5.0),
) as client:
...
```

Default timeout is `httpx.Timeout(None, connect=5.0)` (5 seconds on connect, unlimited on read, write or pool).
Default timeout is `urllib3.Timeout(total=None, connect=5.0)`.
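For instance, to keep the 5-second connect timeout while also capping the total request time, pass your own `urllib3.Timeout` — a sketch based on the example above, with the other `get_client()` arguments elided:

```python
import urllib3

import any_llm_client


async with any_llm_client.get_client(
    ...,  # config and other arguments as in the example above
    timeout=urllib3.Timeout(total=30.0, connect=5.0),
) as client:
    ...
```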

#### Retries

68 changes: 29 additions & 39 deletions any_llm_client/clients/openai.py
@@ -6,8 +6,7 @@
from http import HTTPStatus

import annotated_types
import httpx
import httpx_sse
import niquests
import pydantic
import typing_extensions

@@ -20,8 +19,9 @@
OutOfTokensOrSymbolsError,
UserMessage,
)
from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
from any_llm_client.http import HttpClient, HttpStatusError
from any_llm_client.retry import RequestRetryConfig
from any_llm_client.sse import parse_sse_events


OPENAI_AUTH_TOKEN_ENV_NAME: typing.Final = "ANY_LLM_CLIENT_OPENAI_AUTH_TOKEN"
@@ -99,31 +99,34 @@ def _make_user_assistant_alternate_messages(
yield ChatCompletionsMessage(role=current_message_role, content="\n\n".join(current_message_content_chunks))


def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn:
if status_code == HTTPStatus.BAD_REQUEST and b"Please reduce the length of the messages" in content: # vLLM
raise OutOfTokensOrSymbolsError(response_content=content)
raise LLMError(response_content=content)
def _handle_status_error(error: HttpStatusError) -> typing.NoReturn:
if (
error.status_code == HTTPStatus.BAD_REQUEST and b"Please reduce the length of the messages" in error.content
): # vLLM
raise OutOfTokensOrSymbolsError(response_content=error.content)
raise LLMError(response_content=error.content)


@dataclasses.dataclass(slots=True, init=False)
class OpenAIClient(LLMClient):
config: OpenAIConfig
httpx_client: httpx.AsyncClient
http_client: HttpClient
request_retry: RequestRetryConfig

def __init__(
self,
config: OpenAIConfig,
*,
request_retry: RequestRetryConfig | None = None,
**httpx_kwargs: typing.Any, # noqa: ANN401
**niquests_kwargs: typing.Any, # noqa: ANN401
) -> None:
self.config = config
self.request_retry = request_retry or RequestRetryConfig()
self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)
self.http_client = HttpClient(
request_retry=request_retry or RequestRetryConfig(), niquests_kwargs=niquests_kwargs
)

def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
return self.httpx_client.build_request(
def _build_request(self, payload: dict[str, typing.Any]) -> niquests.Request:
return niquests.Request(
method="POST",
url=str(self.config.url),
json=payload,
@@ -152,24 +155,17 @@ async def request_llm_message(
**extra or {},
).model_dump(mode="json")
try:
response: typing.Final = await make_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
)
except httpx.HTTPStatusError as exception:
_handle_status_error(status_code=exception.response.status_code, content=exception.response.content)
try:
return ChatCompletionsNotStreamingResponse.model_validate_json(response.content).choices[0].message.content
finally:
await response.aclose()
response: typing.Final = await self.http_client.request(self._build_request(payload))
except HttpStatusError as exception:
_handle_status_error(exception)
return ChatCompletionsNotStreamingResponse.model_validate_json(response).choices[0].message.content

async def _iter_partial_responses(self, response: httpx.Response) -> typing.AsyncIterable[str]:
async def _iter_partial_responses(self, response: typing.AsyncIterable[bytes]) -> typing.AsyncIterable[str]:
text_chunks: typing.Final = []
async for event in httpx_sse.EventSource(response).aiter_sse():
if event.data == "[DONE]":
async for one_event in parse_sse_events(response):
if one_event.data == "[DONE]":
break
validated_response = ChatCompletionsStreamingEvent.model_validate_json(event.data)
validated_response = ChatCompletionsStreamingEvent.model_validate_json(one_event.data)
if not (one_chunk := validated_response.choices[0].delta.content):
continue
text_chunks.append(one_chunk)
@@ -187,19 +183,13 @@ async def stream_llm_partial_messages(
**extra or {},
).model_dump(mode="json")
try:
async with make_streaming_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
) as response:
async with self.http_client.stream(request=self._build_request(payload)) as response:
yield self._iter_partial_responses(response)
except httpx.HTTPStatusError as exception:
content: typing.Final = await exception.response.aread()
await exception.response.aclose()
_handle_status_error(status_code=exception.response.status_code, content=content)
except HttpStatusError as exception:
_handle_status_error(exception)

async def __aenter__(self) -> typing_extensions.Self:
await self.httpx_client.__aenter__()
await self.http_client.__aenter__()
return self

async def __aexit__(
@@ -208,4 +198,4 @@ async def __aexit__(
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
await self.http_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
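The `HttpClient`, `HttpStatusError`, and `parse_sse_events` names imported above live in the new `any_llm_client/http.py` and `any_llm_client/sse.py` modules, which are not shown in this excerpt. Below is a minimal sketch of the interface these clients rely on — the names and signatures are inferred from the call sites in this diff, not copied from the actual implementation:

```python
import contextlib
import dataclasses
import types
import typing

import niquests

from any_llm_client.retry import RequestRetryConfig


@dataclasses.dataclass
class HttpStatusError(Exception):
    """Raised for non-2xx responses; the clients above read these two fields."""

    status_code: int
    content: bytes


class HttpClient:
    """Wraps a niquests async session, applying retries and the passed-through niquests kwargs."""

    def __init__(self, *, request_retry: RequestRetryConfig, niquests_kwargs: dict[str, typing.Any]) -> None: ...

    async def request(self, request: niquests.Request) -> bytes:
        """Send the request (with retries) and return the response body."""
        ...

    def stream(
        self, *, request: niquests.Request
    ) -> contextlib.AbstractAsyncContextManager[typing.AsyncIterable[bytes]]:
        """Async context manager yielding the response body as a stream of byte chunks."""
        ...

    async def __aenter__(self) -> "HttpClient": ...

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> None: ...


@dataclasses.dataclass
class ServerSentEvent:
    """Only the `data` field is used by `OpenAIClient._iter_partial_responses`."""

    data: str


def parse_sse_events(response: typing.AsyncIterable[bytes]) -> typing.AsyncIterator[ServerSentEvent]:
    """Parse a server-sent-events byte stream into events; implementation omitted in this sketch."""
    ...
```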
66 changes: 28 additions & 38 deletions any_llm_client/clients/yandexgpt.py
@@ -6,12 +6,12 @@
from http import HTTPStatus

import annotated_types
import httpx
import niquests
import pydantic
import typing_extensions

from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError, UserMessage
from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
from any_llm_client.http import HttpClient, HttpStatusError
from any_llm_client.retry import RequestRetryConfig


@@ -61,34 +61,34 @@ class YandexGPTResponse(pydantic.BaseModel):
result: YandexGPTResult


def _handle_status_error(*, status_code: int, content: bytes) -> typing.NoReturn:
if status_code == HTTPStatus.BAD_REQUEST and (
b"number of input tokens must be no more than" in content
or (b"text length is" in content and b"which is outside the range" in content)
def _handle_status_error(error: HttpStatusError) -> typing.NoReturn:
if error.status_code == HTTPStatus.BAD_REQUEST and (
b"number of input tokens must be no more than" in error.content
or (b"text length is" in error.content and b"which is outside the range" in error.content)
):
raise OutOfTokensOrSymbolsError(response_content=content)
raise LLMError(response_content=content)
raise OutOfTokensOrSymbolsError(response_content=error.content)
raise LLMError(response_content=error.content)


@dataclasses.dataclass(slots=True, init=False)
class YandexGPTClient(LLMClient):
config: YandexGPTConfig
httpx_client: httpx.AsyncClient
request_retry: RequestRetryConfig
http_client: HttpClient

def __init__(
self,
config: YandexGPTConfig,
*,
request_retry: RequestRetryConfig | None = None,
**httpx_kwargs: typing.Any, # noqa: ANN401
**niquests_kwargs: typing.Any, # noqa: ANN401
) -> None:
self.config = config
self.request_retry = request_retry or RequestRetryConfig()
self.httpx_client = get_http_client_from_kwargs(httpx_kwargs)
self.http_client = HttpClient(
request_retry=request_retry or RequestRetryConfig(), niquests_kwargs=niquests_kwargs
)

def _build_request(self, payload: dict[str, typing.Any]) -> httpx.Request:
return self.httpx_client.build_request(
def _build_request(self, payload: dict[str, typing.Any]) -> niquests.Request:
return niquests.Request(
method="POST",
url=str(self.config.url),
json=payload,
@@ -121,18 +121,14 @@ async def request_llm_message(
)

try:
response: typing.Final = await make_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
)
except httpx.HTTPStatusError as exception:
_handle_status_error(status_code=exception.response.status_code, content=exception.response.content)

return YandexGPTResponse.model_validate_json(response.content).result.alternatives[0].message.text

async def _iter_completion_messages(self, response: httpx.Response) -> typing.AsyncIterable[str]:
async for one_line in response.aiter_lines():
response: typing.Final = await self.http_client.request(self._build_request(payload))
except HttpStatusError as exception:
_handle_status_error(exception)

return YandexGPTResponse.model_validate_json(response).result.alternatives[0].message.text

async def _iter_completion_messages(self, response: typing.AsyncIterable[bytes]) -> typing.AsyncIterable[str]:
async for one_line in response:
validated_response = YandexGPTResponse.model_validate_json(one_line)
yield validated_response.result.alternatives[0].message.text

@@ -145,19 +141,13 @@ async def stream_llm_partial_messages(
)

try:
async with make_streaming_http_request(
httpx_client=self.httpx_client,
request_retry=self.request_retry,
build_request=lambda: self._build_request(payload),
) as response:
async with self.http_client.stream(request=self._build_request(payload)) as response:
yield self._iter_completion_messages(response)
except httpx.HTTPStatusError as exception:
content: typing.Final = await exception.response.aread()
await exception.response.aclose()
_handle_status_error(status_code=exception.response.status_code, content=content)
except HttpStatusError as exception:
_handle_status_error(exception)

async def __aenter__(self) -> typing_extensions.Self:
await self.httpx_client.__aenter__()
await self.http_client.__aenter__()
return self

async def __aexit__(
@@ -166,4 +156,4 @@ async def __aexit__(
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
await self.httpx_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
await self.http_client.__aexit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
24 changes: 12 additions & 12 deletions any_llm_client/core.py
@@ -48,18 +48,6 @@ def AssistantMessage(text: str) -> Message: # noqa: N802
return Message(role=MessageRole.assistant, text=text)


@dataclasses.dataclass
class LLMError(Exception):
response_content: bytes

def __str__(self) -> str:
return self.__repr__().removeprefix(self.__class__.__name__)


@dataclasses.dataclass
class OutOfTokensOrSymbolsError(LLMError): ...


class LLMConfig(pydantic.BaseModel):
model_config = pydantic.ConfigDict(protected_namespaces=())
api_type: str
@@ -83,3 +71,15 @@ async def __aexit__(
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None: ...


@dataclasses.dataclass
class LLMError(Exception):
response_content: bytes

def __str__(self) -> str:
return self.__repr__().removeprefix(self.__class__.__name__)


@dataclasses.dataclass
class OutOfTokensOrSymbolsError(LLMError): ...