6 changes: 3 additions & 3 deletions any_llm_client/clients/mock.py
@@ -6,7 +6,7 @@
import pydantic
import typing_extensions

from any_llm_client.core import LLMClient, LLMConfig, Message
from any_llm_client.core import LLMClient, LLMConfig, LLMConfigValue, Message


class MockLLMConfig(LLMConfig):
@@ -23,7 +23,7 @@ async def request_llm_message(
self,
messages: str | list[Message], # noqa: ARG002
*,
temperature: float = 0.2, # noqa: ARG002
temperature: float = LLMConfigValue(attr="temperature"), # noqa: ARG002
extra: dict[str, typing.Any] | None = None, # noqa: ARG002
) -> str:
return self.config.response_message
@@ -37,7 +37,7 @@ async def stream_llm_message_chunks(
self,
messages: str | list[Message], # noqa: ARG002
*,
temperature: float = 0.2, # noqa: ARG002
temperature: float = LLMConfigValue(attr="temperature"), # noqa: ARG002
extra: dict[str, typing.Any] | None = None, # noqa: ARG002
) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
yield self._iter_config_stream_messages()
42 changes: 27 additions & 15 deletions any_llm_client/clients/openai.py
@@ -14,6 +14,7 @@
from any_llm_client.core import (
LLMClient,
LLMConfig,
LLMConfigValue,
LLMError,
Message,
MessageRole,
@@ -50,7 +51,7 @@ class ChatCompletionsRequest(pydantic.BaseModel):
stream: bool
model: str
messages: list[ChatCompletionsMessage]
temperature: float = 0.2
temperature: float


class OneStreamingChoiceDelta(pydantic.BaseModel):
@@ -142,16 +143,27 @@ def _prepare_messages(self, messages: str | list[Message]) -> list[ChatCompletio
else list(initial_messages)
)

async def request_llm_message(
self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
) -> str:
payload: typing.Final = ChatCompletionsRequest(
stream=False,
def _prepare_payload(
self, *, messages: str | list[Message], temperature: float, stream: bool, extra: dict[str, typing.Any] | None
) -> dict[str, typing.Any]:
return ChatCompletionsRequest(
stream=stream,
model=self.config.model_name,
messages=self._prepare_messages(messages),
temperature=temperature,
temperature=self.config._resolve_request_temperature(temperature), # noqa: SLF001
**self.config.request_extra | (extra or {}),
).model_dump(mode="json")

async def request_llm_message(
self,
messages: str | list[Message],
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> str:
payload: typing.Final = self._prepare_payload(
messages=messages, temperature=temperature, stream=False, extra=extra
)
try:
response: typing.Final = await make_http_request(
httpx_client=self.httpx_client,
@@ -176,15 +188,15 @@ async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncI

@contextlib.asynccontextmanager
async def stream_llm_message_chunks(
self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
self,
messages: str | list[Message],
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
payload: typing.Final = ChatCompletionsRequest(
stream=True,
model=self.config.model_name,
messages=self._prepare_messages(messages),
temperature=temperature,
**self.config.request_extra | (extra or {}),
).model_dump(mode="json")
payload: typing.Final = self._prepare_payload(
messages=messages, temperature=temperature, stream=True, extra=extra
)
try:
async with make_streaming_http_request(
httpx_client=self.httpx_client,
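With both request paths delegating to `_prepare_payload`, the `stream` flag is the only difference between the blocking and streaming calls, and temperature resolution plus the `request_extra | (extra or {})` merge happen in one place (the YandexGPT client below receives the same merge, where previously only the per-call `extra` was applied). A simplified sketch of the merge semantics, using hypothetical names rather than the client's real request model:

```python
import typing


def prepare_payload(
    *,
    model: str,
    temperature: float,
    stream: bool,
    request_extra: dict[str, typing.Any],
    extra: dict[str, typing.Any] | None,
) -> dict[str, typing.Any]:
    # Per-call `extra` is merged last: `dict | dict` keeps the right-hand
    # operand's values on key conflicts, so call-level settings override
    # config-level `request_extra`.
    return {
        "stream": stream,
        "model": model,
        "temperature": temperature,
        **(request_extra | (extra or {})),
    }


payload = prepare_payload(
    model="some-model",
    temperature=0.2,
    stream=False,
    request_extra={"max_tokens": 64, "seed": 1},
    extra={"seed": 42},
)
assert payload["seed"] == 42        # per-call extra wins the conflict
assert payload["max_tokens"] == 64  # config-level extras still apply
```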
32 changes: 25 additions & 7 deletions any_llm_client/clients/yandexgpt.py
@@ -10,7 +10,15 @@
import pydantic
import typing_extensions

from any_llm_client.core import LLMClient, LLMConfig, LLMError, Message, OutOfTokensOrSymbolsError, UserMessage
from any_llm_client.core import (
LLMClient,
LLMConfig,
LLMConfigValue,
LLMError,
Message,
OutOfTokensOrSymbolsError,
UserMessage,
)
from any_llm_client.http import get_http_client_from_kwargs, make_http_request, make_streaming_http_request
from any_llm_client.retry import RequestRetryConfig

@@ -38,7 +46,7 @@ class YandexGPTConfig(LLMConfig):

class YandexGPTCompletionOptions(pydantic.BaseModel):
stream: bool
temperature: float = 0.2
temperature: float
max_tokens: int = pydantic.Field(gt=0, alias="maxTokens")


@@ -99,22 +107,28 @@ def _prepare_payload(
self,
*,
messages: str | list[Message],
temperature: float = 0.2,
temperature: float,
stream: bool,
extra: dict[str, typing.Any] | None,
) -> dict[str, typing.Any]:
messages = [UserMessage(messages)] if isinstance(messages, str) else messages
return YandexGPTRequest(
modelUri=f"gpt://{self.config.folder_id}/{self.config.model_name}/{self.config.model_version}",
completionOptions=YandexGPTCompletionOptions(
stream=stream, temperature=temperature, maxTokens=self.config.max_tokens
stream=stream,
temperature=self.config._resolve_request_temperature(temperature), # noqa: SLF001
maxTokens=self.config.max_tokens,
),
messages=messages,
**extra or {},
**self.config.request_extra | (extra or {}),
).model_dump(mode="json", by_alias=True)

async def request_llm_message(
self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
self,
messages: str | list[Message],
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> str:
payload: typing.Final = self._prepare_payload(
messages=messages, temperature=temperature, stream=False, extra=extra
@@ -141,7 +155,11 @@ async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncI

@contextlib.asynccontextmanager
async def stream_llm_message_chunks(
self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
self,
messages: str | list[Message],
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> typing.AsyncIterator[typing.AsyncIterable[str]]:
payload: typing.Final = self._prepare_payload(
messages=messages, temperature=temperature, stream=True, extra=extra
34 changes: 32 additions & 2 deletions any_llm_client/core.py
@@ -51,17 +51,47 @@ def AssistantMessage(text: str) -> Message: # noqa: N802
class LLMConfig(pydantic.BaseModel):
model_config = pydantic.ConfigDict(protected_namespaces=())
api_type: str
temperature: float = 0.2
request_extra: dict[str, typing.Any] = pydantic.Field(default_factory=dict)

def _resolve_request_temperature(self, temperature_arg_value: float) -> float:
return (
self.temperature
if isinstance(temperature_arg_value, LLMConfigValue) # type: ignore[arg-type]
else temperature_arg_value
)


if typing.TYPE_CHECKING:

def LLMConfigValue(*, attr: str) -> typing.Any: # noqa: ANN401, N802
"""Defaults to value from LLMConfig."""
else:

@dataclasses.dataclass(kw_only=True, frozen=True, slots=True)
class LLMConfigValue:
"""Defaults to value from LLMConfig."""

attr: str


@dataclasses.dataclass(slots=True, init=False)
class LLMClient(typing.Protocol):
async def request_llm_message(
self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
self,
messages: str | list[Message],
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> str: ... # raises LLMError

@contextlib.asynccontextmanager
def stream_llm_message_chunks(
self, messages: str | list[Message], *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
self,
messages: str | list[Message],
*,
temperature: float = LLMConfigValue(attr="temperature"),
extra: dict[str, typing.Any] | None = None,
) -> typing.AsyncIterator[typing.AsyncIterable[str]]: ... # raises LLMError

async def __aenter__(self) -> typing_extensions.Self: ...
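The `LLMConfigValue` construct above is the heart of the change: under `typing.TYPE_CHECKING` it is declared as a function returning `typing.Any`, so method signatures keep their plain `float` annotation, while at runtime it is a frozen dataclass whose instances `_resolve_request_temperature` detects via `isinstance` and swaps for the config's own `temperature`. A minimal standalone sketch of the same pattern, with simplified stand-in names rather than the library's classes:

```python
import dataclasses
import typing


if typing.TYPE_CHECKING:
    # Type checkers see a callable returning Any, so a `float` parameter
    # may use ConfigValue(...) as its default without complaint.
    def ConfigValue(*, attr: str) -> typing.Any: ...

else:

    @dataclasses.dataclass(kw_only=True, frozen=True, slots=True)
    class ConfigValue:
        attr: str


@dataclasses.dataclass
class Config:
    temperature: float = 0.2

    def resolve_temperature(self, value: float) -> float:
        # The sentinel instance survives to runtime, where isinstance() spots it
        # and falls back to the config-level value.
        return self.temperature if isinstance(value, ConfigValue) else value  # type: ignore[arg-type]


def request(config: Config, *, temperature: float = ConfigValue(attr="temperature")) -> float:
    return config.resolve_temperature(temperature)


config = Config(temperature=0.9)
assert request(config) == 0.9                    # omitted -> config value wins
assert request(config, temperature=0.1) == 0.1   # explicit argument wins
```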
25 changes: 22 additions & 3 deletions tests/conftest.py
@@ -3,6 +3,7 @@

import pytest
import stamina
import typing_extensions
from polyfactory.factories.typed_dict_factory import TypedDictFactory

import any_llm_client
@@ -20,11 +21,29 @@ def _deactivate_retries() -> None:

class LLMFuncRequest(typing.TypedDict):
messages: str | list[any_llm_client.Message]
temperature: float
extra: dict[str, typing.Any] | None
temperature: typing_extensions.NotRequired[float]
extra: typing_extensions.NotRequired[dict[str, typing.Any] | None]


class LLMFuncRequestFactory(TypedDictFactory[LLMFuncRequest]): ...
class LLMFuncRequestFactory(TypedDictFactory[LLMFuncRequest]):
# Polyfactory ignores `NotRequired`:
# https://github.com/litestar-org/polyfactory/issues/656
@classmethod
def coverage(cls, **kwargs: typing.Any) -> typing.Iterator[LLMFuncRequest]: # noqa: ANN401
yield from super().coverage(**kwargs)

first_additional_example: typing.Final = cls.build(**kwargs)
first_additional_example.pop("temperature")
yield first_additional_example

second_additional_example: typing.Final = cls.build(**kwargs)
second_additional_example.pop("extra")
yield second_additional_example

third_additional_example: typing.Final = cls.build(**kwargs)
third_additional_example.pop("extra")
third_additional_example.pop("temperature")
yield third_additional_example


async def consume_llm_message_chunks(
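The `NotRequired` markers are what let the test suite exercise the new defaults: a `LLMFuncRequest` without `temperature` or `extra` still unpacks cleanly into the client methods, so those calls hit the `LLMConfigValue` fallback. Because polyfactory's `coverage()` ignores `NotRequired` (the issue linked above), the factory adds the missing-key combinations by hand. A small illustration of the unpacking behaviour itself, with a hypothetical stand-in function rather than a real client:

```python
import typing

import typing_extensions


class FuncRequest(typing.TypedDict):
    messages: str
    temperature: typing_extensions.NotRequired[float]
    extra: typing_extensions.NotRequired[dict[str, typing.Any] | None]


def request_llm_message(
    messages: str, *, temperature: float = 0.2, extra: dict[str, typing.Any] | None = None
) -> str:
    return f"{messages} (temperature={temperature})"


full: FuncRequest = {"messages": "hi", "temperature": 0.9, "extra": None}
minimal: FuncRequest = {"messages": "hi"}  # NotRequired keys may simply be absent

assert request_llm_message(**full) == "hi (temperature=0.9)"
assert request_llm_message(**minimal) == "hi (temperature=0.2)"  # the callee's default is exercised
```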
15 changes: 8 additions & 7 deletions tests/test_mock_client.py
@@ -1,25 +1,26 @@
import typing

import pytest
from polyfactory.factories.pydantic_factory import ModelFactory

import any_llm_client
from tests.conftest import LLMFuncRequestFactory, consume_llm_message_chunks
from tests.conftest import LLMFuncRequest, LLMFuncRequestFactory, consume_llm_message_chunks


class MockLLMConfigFactory(ModelFactory[any_llm_client.MockLLMConfig]): ...


async def test_mock_client_request_llm_message_returns_config_value() -> None:
@pytest.mark.parametrize("func_request", LLMFuncRequestFactory.coverage())
async def test_mock_client_request_llm_message_returns_config_value(func_request: LLMFuncRequest) -> None:
config: typing.Final = MockLLMConfigFactory.build()
response: typing.Final = await any_llm_client.get_client(config).request_llm_message(
**LLMFuncRequestFactory.build()
)
response: typing.Final = await any_llm_client.get_client(config).request_llm_message(**func_request)
assert response == config.response_message


async def test_mock_client_stream_llm_message_chunks_returns_config_value() -> None:
@pytest.mark.parametrize("func_request", LLMFuncRequestFactory.coverage())
async def test_mock_client_stream_llm_message_chunks_returns_config_value(func_request: LLMFuncRequest) -> None:
config: typing.Final = MockLLMConfigFactory.build()
response: typing.Final = await consume_llm_message_chunks(
any_llm_client.get_client(config).stream_llm_message_chunks(**LLMFuncRequestFactory.build())
any_llm_client.get_client(config).stream_llm_message_chunks(**func_request)
)
assert response == config.stream_messages
11 changes: 6 additions & 5 deletions tests/test_openai_client.py
@@ -15,14 +15,15 @@
OneStreamingChoice,
OneStreamingChoiceDelta,
)
from tests.conftest import LLMFuncRequestFactory, consume_llm_message_chunks
from tests.conftest import LLMFuncRequest, LLMFuncRequestFactory, consume_llm_message_chunks


class OpenAIConfigFactory(ModelFactory[any_llm_client.OpenAIConfig]): ...


class TestOpenAIRequestLLMResponse:
async def test_ok(self, faker: faker.Faker) -> None:
@pytest.mark.parametrize("func_request", LLMFuncRequestFactory.coverage())
async def test_ok(self, faker: faker.Faker, func_request: LLMFuncRequest) -> None:
expected_result: typing.Final = faker.pystr()
response: typing.Final = httpx.Response(
200,
@@ -39,7 +40,7 @@ async def test_ok(self, faker: faker.Faker) -> None:

result: typing.Final = await any_llm_client.get_client(
OpenAIConfigFactory.build(), transport=httpx.MockTransport(lambda _: response)
).request_llm_message(**LLMFuncRequestFactory.build())
).request_llm_message(**func_request)

assert result == expected_result

@@ -57,7 +58,8 @@ async def test_fails_without_alternatives(self) -> None:


class TestOpenAIRequestLLMMessageChunks:
async def test_ok(self, faker: faker.Faker) -> None:
@pytest.mark.parametrize("func_request", LLMFuncRequestFactory.coverage())
async def test_ok(self, faker: faker.Faker, func_request: LLMFuncRequest) -> None:
generated_messages: typing.Final = [
OneStreamingChoiceDelta(role=any_llm_client.MessageRole.assistant),
OneStreamingChoiceDelta(content="H"),
@@ -78,7 +80,6 @@ async def test_ok(self, faker: faker.Faker) -> None:
"r day?",
]
config: typing.Final = OpenAIConfigFactory.build()
func_request: typing.Final = LLMFuncRequestFactory.build()
response_content: typing.Final = (
"\n\n".join(
"data: "
17 changes: 0 additions & 17 deletions tests/test_static.py
@@ -10,7 +10,6 @@
import any_llm_client
from any_llm_client.clients.openai import ChatCompletionsRequest
from any_llm_client.clients.yandexgpt import YandexGPTRequest
from tests.conftest import LLMFuncRequest


def test_request_retry_config_default_kwargs_match() -> None:
Expand All @@ -31,22 +30,6 @@ def test_llm_error_str(faker: faker.Faker) -> None:
assert str(any_llm_client.LLMError(response_content=response_content)) == f"(response_content={response_content!r})"


def test_llm_func_request_has_same_annotations_as_llm_client_methods() -> None:
all_objects: typing.Final = (
any_llm_client.LLMClient.request_llm_message,
any_llm_client.LLMClient.stream_llm_message_chunks,
LLMFuncRequest,
)
all_annotations: typing.Final = [typing.get_type_hints(one_object) for one_object in all_objects]

for one_ignored_prop in ("return",):
for annotations in all_annotations:
if one_ignored_prop in annotations:
annotations.pop(one_ignored_prop)

assert all(annotations == all_annotations[0] for annotations in all_annotations)


@pytest.mark.parametrize("model_type", [YandexGPTRequest, ChatCompletionsRequest])
def test_dumped_llm_request_payload_dump_has_extra_data(model_type: type[pydantic.BaseModel]) -> None:
extra: typing.Final = {"hi": "there", "hi-hi": "there-there"}
Expand Down
Loading