Skip to content

Commit b7766d1

Browse files
committed
Add public provider property to BaseClient
1 parent 603e32d commit b7766d1

File tree

5 files changed

+82
-37
lines changed

5 files changed

+82
-37
lines changed

python/mirascope/llm/clients/anthropic/clients.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Sequence
55
from contextvars import ContextVar
66
from functools import lru_cache
7-
from typing import overload
7+
from typing import TYPE_CHECKING, overload
88
from typing_extensions import Unpack
99

1010
from anthropic import Anthropic, AsyncAnthropic
@@ -36,6 +36,9 @@
3636
from . import _utils
3737
from .model_ids import AnthropicModelId
3838

39+
if TYPE_CHECKING:
40+
from ..providers import Provider
41+
3942
ANTHROPIC_CLIENT_CONTEXT: ContextVar["AnthropicClient | None"] = ContextVar(
4043
"ANTHROPIC_CLIENT_CONTEXT", default=None
4144
)
@@ -87,6 +90,11 @@ class AnthropicClient(BaseClient[AnthropicModelId, Anthropic]):
8790
def _context_var(self) -> ContextVar["AnthropicClient | None"]:
8891
return ANTHROPIC_CLIENT_CONTEXT
8992

93+
@property
94+
def provider(self) -> "Provider":
95+
"""Return the provider name for this client."""
96+
return "anthropic"
97+
9098
def __init__(
9199
self, *, api_key: str | None = None, base_url: str | None = None
92100
) -> None:
@@ -170,7 +178,7 @@ def call(
170178

171179
return Response(
172180
raw=anthropic_response,
173-
provider="anthropic",
181+
provider=self.provider,
174182
model_id=model_id,
175183
params=params,
176184
tools=tools,
@@ -269,7 +277,7 @@ def context_call(
269277

270278
return ContextResponse(
271279
raw=anthropic_response,
272-
provider="anthropic",
280+
provider=self.provider,
273281
model_id=model_id,
274282
params=params,
275283
tools=tools,
@@ -355,7 +363,7 @@ async def call_async(
355363

356364
return AsyncResponse(
357365
raw=anthropic_response,
358-
provider="anthropic",
366+
provider=self.provider,
359367
model_id=model_id,
360368
params=params,
361369
tools=tools,
@@ -454,7 +462,7 @@ async def context_call_async(
454462

455463
return AsyncContextResponse(
456464
raw=anthropic_response,
457-
provider="anthropic",
465+
provider=self.provider,
458466
model_id=model_id,
459467
params=params,
460468
tools=tools,
@@ -537,7 +545,7 @@ def stream(
537545
chunk_iterator = _utils.decode_stream(anthropic_stream)
538546

539547
return StreamResponse(
540-
provider="anthropic",
548+
provider=self.provider,
541549
model_id=model_id,
542550
params=params,
543551
tools=tools,
@@ -632,7 +640,7 @@ def context_stream(
632640
chunk_iterator = _utils.decode_stream(anthropic_stream)
633641

634642
return ContextStreamResponse(
635-
provider="anthropic",
643+
provider=self.provider,
636644
model_id=model_id,
637645
params=params,
638646
tools=tools,
@@ -714,7 +722,7 @@ async def stream_async(
714722
chunk_iterator = _utils.decode_async_stream(anthropic_stream)
715723

716724
return AsyncStreamResponse(
717-
provider="anthropic",
725+
provider=self.provider,
718726
model_id=model_id,
719727
params=params,
720728
tools=tools,
@@ -809,7 +817,7 @@ async def context_stream_async(
809817
chunk_iterator = _utils.decode_async_stream(anthropic_stream)
810818

811819
return AsyncContextStreamResponse(
812-
provider="anthropic",
820+
provider=self.provider,
813821
model_id=model_id,
814822
params=params,
815823
tools=tools,

python/mirascope/llm/clients/base/client.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import Sequence
77
from contextvars import ContextVar, Token
88
from types import TracebackType
9-
from typing import Generic, overload
9+
from typing import TYPE_CHECKING, Generic, overload
1010
from typing_extensions import Self, TypeVar, Unpack
1111

1212
from ...context import Context, DepsT
@@ -34,6 +34,9 @@
3434
)
3535
from .params import Params
3636

37+
if TYPE_CHECKING:
38+
from ..providers import Provider
39+
3740
ModelIdT = TypeVar("ModelIdT", bound=str)
3841
ProviderClientT = TypeVar("ProviderClientT")
3942

@@ -57,6 +60,16 @@ def _context_var(self) -> ContextVar:
5760
"""The ContextVar for this client type."""
5861
...
5962

63+
@property
64+
@abstractmethod
65+
def provider(self) -> "Provider":
66+
"""The provider name for this client.
67+
68+
This property provides the name of the provider and is available for
69+
overriding by subclasses in the case of a mirrored or wrapped client.
70+
"""
71+
...
72+
6073
def __enter__(self) -> Self:
6174
"""Sets the client context and stores the token."""
6275
self._token = self._context_var.set(self)

python/mirascope/llm/clients/google/clients.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Sequence
55
from contextvars import ContextVar
66
from functools import lru_cache
7-
from typing import overload
7+
from typing import TYPE_CHECKING, overload
88
from typing_extensions import Unpack
99

1010
from google.genai import Client
@@ -37,6 +37,9 @@
3737
from . import _utils
3838
from .model_ids import GoogleModelId
3939

40+
if TYPE_CHECKING:
41+
from ..providers import Provider
42+
4043
GOOGLE_CLIENT_CONTEXT: ContextVar["GoogleClient | None"] = ContextVar(
4144
"GOOGLE_CLIENT_CONTEXT", default=None
4245
)
@@ -86,6 +89,11 @@ class GoogleClient(BaseClient[GoogleModelId, Client]):
8689
def _context_var(self) -> ContextVar["GoogleClient | None"]:
8790
return GOOGLE_CLIENT_CONTEXT
8891

92+
@property
93+
def provider(self) -> "Provider":
94+
"""Return the provider name for this client."""
95+
return "google"
96+
8997
def __init__(
9098
self, *, api_key: str | None = None, base_url: str | None = None
9199
) -> None:
@@ -176,7 +184,7 @@ def call(
176184

177185
return Response(
178186
raw=google_response,
179-
provider="google",
187+
provider=self.provider,
180188
model_id=model_id,
181189
params=params,
182190
tools=tools,
@@ -279,7 +287,7 @@ def context_call(
279287

280288
return ContextResponse(
281289
raw=google_response,
282-
provider="google",
290+
provider=self.provider,
283291
model_id=model_id,
284292
params=params,
285293
tools=tools,
@@ -369,7 +377,7 @@ async def call_async(
369377

370378
return AsyncResponse(
371379
raw=google_response,
372-
provider="google",
380+
provider=self.provider,
373381
model_id=model_id,
374382
params=params,
375383
tools=tools,
@@ -472,7 +480,7 @@ async def context_call_async(
472480

473481
return AsyncContextResponse(
474482
raw=google_response,
475-
provider="google",
483+
provider=self.provider,
476484
model_id=model_id,
477485
params=params,
478486
tools=tools,
@@ -559,7 +567,7 @@ def stream(
559567
chunk_iterator = _utils.decode_stream(google_stream)
560568

561569
return StreamResponse(
562-
provider="google",
570+
provider=self.provider,
563571
model_id=model_id,
564572
params=params,
565573
tools=tools,
@@ -658,7 +666,7 @@ def context_stream(
658666
chunk_iterator = _utils.decode_stream(google_stream)
659667

660668
return ContextStreamResponse(
661-
provider="google",
669+
provider=self.provider,
662670
model_id=model_id,
663671
params=params,
664672
tools=tools,
@@ -744,7 +752,7 @@ async def stream_async(
744752
chunk_iterator = _utils.decode_async_stream(google_stream)
745753

746754
return AsyncStreamResponse(
747-
provider="google",
755+
provider=self.provider,
748756
model_id=model_id,
749757
params=params,
750758
tools=tools,
@@ -843,7 +851,7 @@ async def context_stream_async(
843851
chunk_iterator = _utils.decode_async_stream(google_stream)
844852

845853
return AsyncContextStreamResponse(
846-
provider="google",
854+
provider=self.provider,
847855
model_id=model_id,
848856
params=params,
849857
tools=tools,

python/mirascope/llm/clients/openai/completions/clients.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Sequence
55
from contextvars import ContextVar
66
from functools import lru_cache
7-
from typing import overload
7+
from typing import TYPE_CHECKING, overload
88
from typing_extensions import Unpack
99

1010
from openai import AsyncOpenAI, OpenAI
@@ -36,6 +36,9 @@
3636
from . import _utils
3737
from .model_ids import OpenAICompletionsModelId
3838

39+
if TYPE_CHECKING:
40+
from ...providers import Provider
41+
3942
OPENAI_COMPLETIONS_CLIENT_CONTEXT: ContextVar["OpenAICompletionsClient | None"] = (
4043
ContextVar("OPENAI_COMPLETIONS_CLIENT_CONTEXT", default=None)
4144
)
@@ -87,6 +90,11 @@ class OpenAICompletionsClient(BaseClient[OpenAICompletionsModelId, OpenAI]):
8790
def _context_var(self) -> ContextVar["OpenAICompletionsClient | None"]:
8891
return OPENAI_COMPLETIONS_CLIENT_CONTEXT
8992

93+
@property
94+
def provider(self) -> "Provider":
95+
"""Return the provider name for this client."""
96+
return "openai:completions"
97+
9098
def __init__(
9199
self, *, api_key: str | None = None, base_url: str | None = None
92100
) -> None:
@@ -170,7 +178,7 @@ def call(
170178

171179
return Response(
172180
raw=openai_response,
173-
provider="openai:completions",
181+
provider=self.provider,
174182
model_id=model_id,
175183
params=params,
176184
tools=tools,
@@ -269,7 +277,7 @@ def context_call(
269277

270278
return ContextResponse(
271279
raw=openai_response,
272-
provider="openai:completions",
280+
provider=self.provider,
273281
model_id=model_id,
274282
params=params,
275283
tools=tools,
@@ -356,7 +364,7 @@ async def call_async(
356364

357365
return AsyncResponse(
358366
raw=openai_response,
359-
provider="openai:completions",
367+
provider=self.provider,
360368
model_id=model_id,
361369
params=params,
362370
tools=tools,
@@ -455,7 +463,7 @@ async def context_call_async(
455463

456464
return AsyncContextResponse(
457465
raw=openai_response,
458-
provider="openai:completions",
466+
provider=self.provider,
459467
model_id=model_id,
460468
params=params,
461469
tools=tools,
@@ -541,7 +549,7 @@ def stream(
541549
chunk_iterator = _utils.decode_stream(openai_stream)
542550

543551
return StreamResponse(
544-
provider="openai:completions",
552+
provider=self.provider,
545553
model_id=model_id,
546554
params=params,
547555
tools=tools,
@@ -639,7 +647,7 @@ def context_stream(
639647
chunk_iterator = _utils.decode_stream(openai_stream)
640648

641649
return ContextStreamResponse(
642-
provider="openai:completions",
650+
provider=self.provider,
643651
model_id=model_id,
644652
params=params,
645653
tools=tools,
@@ -725,7 +733,7 @@ async def stream_async(
725733
chunk_iterator = _utils.decode_async_stream(openai_stream)
726734

727735
return AsyncStreamResponse(
728-
provider="openai:completions",
736+
provider=self.provider,
729737
model_id=model_id,
730738
params=params,
731739
tools=tools,
@@ -823,7 +831,7 @@ async def context_stream_async(
823831
chunk_iterator = _utils.decode_async_stream(openai_stream)
824832

825833
return AsyncContextStreamResponse(
826-
provider="openai:completions",
834+
provider=self.provider,
827835
model_id=model_id,
828836
params=params,
829837
tools=tools,

0 commit comments

Comments
 (0)