26 changes: 17 additions & 9 deletions python/mirascope/llm/clients/anthropic/clients.py
@@ -4,7 +4,7 @@
 from collections.abc import Sequence
 from contextvars import ContextVar
 from functools import lru_cache
-from typing import overload
+from typing import TYPE_CHECKING, overload
 from typing_extensions import Unpack
 
 from anthropic import Anthropic, AsyncAnthropic
@@ -36,6 +36,9 @@
 from . import _utils
 from .model_ids import AnthropicModelId
 
+if TYPE_CHECKING:
+    from ..providers import Provider
+
 ANTHROPIC_CLIENT_CONTEXT: ContextVar["AnthropicClient | None"] = ContextVar(
     "ANTHROPIC_CLIENT_CONTEXT", default=None
 )
@@ -87,6 +90,11 @@ class AnthropicClient(BaseClient[AnthropicModelId, Anthropic]):
     def _context_var(self) -> ContextVar["AnthropicClient | None"]:
         return ANTHROPIC_CLIENT_CONTEXT
 
+    @property
+    def provider(self) -> "Provider":
+        """Return the provider name for this client."""
+        return "anthropic"
+
     def __init__(
         self, *, api_key: str | None = None, base_url: str | None = None
     ) -> None:
@@ -170,7 +178,7 @@ def call(
 
         return Response(
             raw=anthropic_response,
-            provider="anthropic",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -269,7 +277,7 @@ def context_call(
 
         return ContextResponse(
             raw=anthropic_response,
-            provider="anthropic",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -355,7 +363,7 @@ async def call_async(
 
         return AsyncResponse(
             raw=anthropic_response,
-            provider="anthropic",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -454,7 +462,7 @@ async def context_call_async(
 
         return AsyncContextResponse(
             raw=anthropic_response,
-            provider="anthropic",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -537,7 +545,7 @@ def stream(
         chunk_iterator = _utils.decode_stream(anthropic_stream)
 
         return StreamResponse(
-            provider="anthropic",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -632,7 +640,7 @@ def context_stream(
         chunk_iterator = _utils.decode_stream(anthropic_stream)
 
         return ContextStreamResponse(
-            provider="anthropic",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -714,7 +722,7 @@ async def stream_async(
         chunk_iterator = _utils.decode_async_stream(anthropic_stream)
 
         return AsyncStreamResponse(
-            provider="anthropic",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -809,7 +817,7 @@ async def context_stream_async(
         chunk_iterator = _utils.decode_async_stream(anthropic_stream)
 
         return AsyncContextStreamResponse(
-            provider="anthropic",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
15 changes: 14 additions & 1 deletion python/mirascope/llm/clients/base/client.py
@@ -6,7 +6,7 @@
 from collections.abc import Sequence
 from contextvars import ContextVar, Token
 from types import TracebackType
-from typing import Generic, overload
+from typing import TYPE_CHECKING, Generic, overload
 from typing_extensions import Self, TypeVar, Unpack
 
 from ...context import Context, DepsT
@@ -34,6 +34,9 @@
 )
 from .params import Params
 
+if TYPE_CHECKING:
+    from ..providers import Provider
+
 ModelIdT = TypeVar("ModelIdT", bound=str)
 ProviderClientT = TypeVar("ProviderClientT")
 
@@ -57,6 +60,16 @@ def _context_var(self) -> ContextVar:
         """The ContextVar for this client type."""
         ...
 
+    @property
+    @abstractmethod
+    def provider(self) -> Provider:
+        """The provider name for this client.
+
+        This property provides the name of the provider and is available for
+        overriding by subclasses in the case of a mirrored or wrapped client.
+        """
+        ...
+
     def __enter__(self) -> Self:
         """Sets the client context and stores the token."""
         self._token = self._context_var.set(self)
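
The abstract `provider` property added here is the extension point its docstring describes: every `call`, `context_call`, and `stream` variant in the concrete clients now passes `provider=self.provider` instead of a hard-coded string, so a mirrored or wrapped client only has to override this one property to re-tag all of its responses. Below is a minimal sketch of such an override, assuming the module paths shown in these diffs; the subclass and its "openai:mirror" label are illustrative assumptions, not part of this PR.

# Hypothetical sketch, not part of this PR: a client that mirrors the OpenAI
# completions API through a proxy while reusing all inherited call/stream logic.
from typing import TYPE_CHECKING

from mirascope.llm.clients.openai.completions.clients import OpenAICompletionsClient

if TYPE_CHECKING:
    from mirascope.llm.clients.providers import Provider


class MirroredCompletionsClient(OpenAICompletionsClient):
    """OpenAI-compatible client pointed at an internal mirror endpoint."""

    def __init__(self, *, api_key: str | None = None) -> None:
        # base_url is accepted by the parent __init__ shown in the diff above;
        # the mirror URL is a placeholder.
        super().__init__(api_key=api_key, base_url="https://mirror.example/v1")

    @property
    def provider(self) -> "Provider":
        # A single override re-tags every Response/StreamResponse the inherited
        # methods build, since they all read `self.provider` now.
        return "openai:mirror"  # type: ignore[return-value]  # illustrative label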
26 changes: 17 additions & 9 deletions python/mirascope/llm/clients/google/clients.py
@@ -4,7 +4,7 @@
 from collections.abc import Sequence
 from contextvars import ContextVar
 from functools import lru_cache
-from typing import overload
+from typing import TYPE_CHECKING, overload
 from typing_extensions import Unpack
 
 from google.genai import Client
@@ -37,6 +37,9 @@
 from . import _utils
 from .model_ids import GoogleModelId
 
+if TYPE_CHECKING:
+    from ..providers import Provider
+
 GOOGLE_CLIENT_CONTEXT: ContextVar["GoogleClient | None"] = ContextVar(
     "GOOGLE_CLIENT_CONTEXT", default=None
 )
@@ -86,6 +89,11 @@ class GoogleClient(BaseClient[GoogleModelId, Client]):
     def _context_var(self) -> ContextVar["GoogleClient | None"]:
         return GOOGLE_CLIENT_CONTEXT
 
+    @property
+    def provider(self) -> "Provider":
+        """Return the provider name for this client."""
+        return "google"
+
     def __init__(
         self, *, api_key: str | None = None, base_url: str | None = None
     ) -> None:
@@ -176,7 +184,7 @@ def call(
 
         return Response(
             raw=google_response,
-            provider="google",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -279,7 +287,7 @@ def context_call(
 
         return ContextResponse(
             raw=google_response,
-            provider="google",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -369,7 +377,7 @@ async def call_async(
 
         return AsyncResponse(
             raw=google_response,
-            provider="google",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -472,7 +480,7 @@ async def context_call_async(
 
         return AsyncContextResponse(
             raw=google_response,
-            provider="google",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -559,7 +567,7 @@ def stream(
         chunk_iterator = _utils.decode_stream(google_stream)
 
         return StreamResponse(
-            provider="google",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -658,7 +666,7 @@ def context_stream(
         chunk_iterator = _utils.decode_stream(google_stream)
 
         return ContextStreamResponse(
-            provider="google",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -744,7 +752,7 @@ async def stream_async(
         chunk_iterator = _utils.decode_async_stream(google_stream)
 
         return AsyncStreamResponse(
-            provider="google",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -843,7 +851,7 @@ async def context_stream_async(
         chunk_iterator = _utils.decode_async_stream(google_stream)
 
         return AsyncContextStreamResponse(
-            provider="google",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
26 changes: 17 additions & 9 deletions python/mirascope/llm/clients/openai/completions/clients.py
@@ -4,7 +4,7 @@
 from collections.abc import Sequence
 from contextvars import ContextVar
 from functools import lru_cache
-from typing import overload
+from typing import TYPE_CHECKING, overload
 from typing_extensions import Unpack
 
 from openai import AsyncOpenAI, OpenAI
@@ -36,6 +36,9 @@
 from . import _utils
 from .model_ids import OpenAICompletionsModelId
 
+if TYPE_CHECKING:
+    from ...providers import Provider
+
 OPENAI_COMPLETIONS_CLIENT_CONTEXT: ContextVar["OpenAICompletionsClient | None"] = (
     ContextVar("OPENAI_COMPLETIONS_CLIENT_CONTEXT", default=None)
 )
@@ -87,6 +90,11 @@ class OpenAICompletionsClient(BaseClient[OpenAICompletionsModelId, OpenAI]):
     def _context_var(self) -> ContextVar["OpenAICompletionsClient | None"]:
         return OPENAI_COMPLETIONS_CLIENT_CONTEXT
 
+    @property
+    def provider(self) -> "Provider":
+        """Return the provider name for this client."""
+        return "openai:completions"
+
     def __init__(
         self, *, api_key: str | None = None, base_url: str | None = None
     ) -> None:
@@ -170,7 +178,7 @@ def call(
 
         return Response(
             raw=openai_response,
-            provider="openai:completions",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -269,7 +277,7 @@ def context_call(
 
         return ContextResponse(
             raw=openai_response,
-            provider="openai:completions",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -356,7 +364,7 @@ async def call_async(
 
         return AsyncResponse(
             raw=openai_response,
-            provider="openai:completions",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -455,7 +463,7 @@ async def context_call_async(
 
         return AsyncContextResponse(
             raw=openai_response,
-            provider="openai:completions",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -541,7 +549,7 @@ def stream(
         chunk_iterator = _utils.decode_stream(openai_stream)
 
         return StreamResponse(
-            provider="openai:completions",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -639,7 +647,7 @@ def context_stream(
         chunk_iterator = _utils.decode_stream(openai_stream)
 
         return ContextStreamResponse(
-            provider="openai:completions",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -725,7 +733,7 @@ async def stream_async(
         chunk_iterator = _utils.decode_async_stream(openai_stream)
 
         return AsyncStreamResponse(
-            provider="openai:completions",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
@@ -823,7 +831,7 @@ async def context_stream_async(
         chunk_iterator = _utils.decode_async_stream(openai_stream)
 
        return AsyncContextStreamResponse(
-            provider="openai:completions",
+            provider=self.provider,
             model_id=model_id,
             params=params,
             tools=tools,
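
Taken together, the three concrete-client diffs make the same mechanical change: each client states its provider name exactly once, and all eight response constructors consume it. A quick usage sketch, assuming the internal module paths shown above (the public import surface may differ):

from mirascope.llm.clients.anthropic.clients import AnthropicClient
from mirascope.llm.clients.google.clients import GoogleClient
from mirascope.llm.clients.openai.completions.clients import OpenAICompletionsClient

# Each concrete client now reports its name through the new property...
assert AnthropicClient(api_key="sk-test").provider == "anthropic"
assert GoogleClient(api_key="test").provider == "google"
assert OpenAICompletionsClient(api_key="sk-test").provider == "openai:completions"

# ...and every response derives its `provider` field from the client rather
# than a hard-coded literal, so the two always agree (call signature elided):
# response = client.call(model_id=..., params=..., tools=...)
# assert response.provider == client.provider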