Skip to content

Commit 8a1c63b

Browse files
committed
Implement mirrored provider clients (AnthropicVertex)
1 parent d9f514d commit 8a1c63b

File tree

5 files changed

+185
-1
lines changed

5 files changed

+185
-1
lines changed

python/mirascope/llm/clients/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
AnthropicModelId,
66
)
77
from .anthropic_bedrock import AnthropicBedrockClient
8+
from .anthropic_vertex import AnthropicVertexClient
89
from .azure_openai.completions import AzureOpenAICompletionsClient
910
from .azure_openai.responses import AzureOpenAIResponsesClient
1011
from .base import BaseClient, ClientT, Params
@@ -22,6 +23,7 @@
2223
"AnthropicBedrockClient",
2324
"AnthropicClient",
2425
"AnthropicModelId",
26+
"AnthropicVertexClient",
2527
"AzureOpenAICompletionsClient",
2628
"AzureOpenAIResponsesClient",
2729
"BaseClient",

python/mirascope/llm/clients/anthropic/base_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ async def _prepare_messages_async(
8585

8686
@property
8787
@abstractmethod
88-
def provider(self) -> Literal["anthropic", "anthropic-bedrock"]:
88+
def provider(self) -> Literal["anthropic", "anthropic-bedrock", "anthropic-vertex"]:
8989
"""Return the provider name for Anthropic-compatible clients."""
9090
...
9191

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Anthropic Vertex AI client implementations."""
2+
3+
from .clients import (
4+
AnthropicVertexClient,
5+
client,
6+
get_client,
7+
)
8+
9+
__all__ = [
10+
"AnthropicVertexClient",
11+
"client",
12+
"get_client",
13+
]
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""Anthropic Vertex AI client implementation."""
2+
3+
import os
4+
from contextvars import ContextVar
5+
from functools import lru_cache
6+
from typing import Literal
7+
8+
from anthropic import NOT_GIVEN
9+
from anthropic.lib.vertex._client import AnthropicVertex, AsyncAnthropicVertex
10+
11+
from ..anthropic import BaseAnthropicClient
12+
13+
ANTHROPIC_VERTEX_CLIENT_CONTEXT: ContextVar["AnthropicVertexClient | None"] = (
14+
ContextVar("ANTHROPIC_VERTEX_CLIENT_CONTEXT", default=None)
15+
)
16+
17+
18+
@lru_cache(maxsize=256)
19+
def _anthropic_vertex_singleton(
20+
project_id: str | None,
21+
region: str | None,
22+
) -> "AnthropicVertexClient":
23+
"""Return a cached AnthropicVertexClient instance for the given parameters."""
24+
return AnthropicVertexClient(
25+
project_id=project_id,
26+
region=region,
27+
)
28+
29+
30+
def client(
31+
*,
32+
project_id: str | None = None,
33+
region: str | None = None,
34+
) -> "AnthropicVertexClient":
35+
"""Return an `AnthropicVertexClient`.
36+
37+
Args:
38+
project_id: GCP project ID. If None, uses GOOGLE_CLOUD_PROJECT, GCLOUD_PROJECT,
39+
CLOUD_ML_PROJECT_ID, or GCP_PROJECT_ID env vars (in that order).
40+
region: GCP region. If None, uses CLOUD_ML_REGION, GOOGLE_CLOUD_REGION, or
41+
GOOGLE_CLOUD_LOCATION env vars (in that order).
42+
43+
Returns:
44+
An `AnthropicVertexClient` instance.
45+
46+
Examples:
47+
# Use environment variables
48+
client = client()
49+
50+
# Use explicit parameters
51+
client = client(
52+
project_id="my-gcp-project",
53+
region="us-central1"
54+
)
55+
"""
56+
project_id = (
57+
project_id
58+
or os.getenv("GOOGLE_CLOUD_PROJECT")
59+
or os.getenv("GCLOUD_PROJECT")
60+
or os.getenv("CLOUD_ML_PROJECT_ID")
61+
or os.getenv("GCP_PROJECT_ID")
62+
)
63+
region = (
64+
region
65+
or os.getenv("CLOUD_ML_REGION")
66+
or os.getenv("GOOGLE_CLOUD_REGION")
67+
or os.getenv("GOOGLE_CLOUD_LOCATION")
68+
)
69+
70+
return _anthropic_vertex_singleton(
71+
project_id,
72+
region,
73+
)
74+
75+
76+
def clear_cache() -> None:
77+
"""Clear the client singleton cache.
78+
79+
This is useful for testing or when you need to force recreation
80+
of clients with updated configuration.
81+
"""
82+
_anthropic_vertex_singleton.cache_clear()
83+
84+
85+
def get_client() -> "AnthropicVertexClient":
86+
"""Retrieve the current Anthropic Vertex client from context, or a global default.
87+
88+
Returns:
89+
The current Anthropic Vertex client from context if available, otherwise
90+
a global default client based on environment variables.
91+
"""
92+
ctx_client = ANTHROPIC_VERTEX_CLIENT_CONTEXT.get()
93+
return ctx_client or client()
94+
95+
96+
class AnthropicVertexClient(
97+
BaseAnthropicClient[AnthropicVertex, AsyncAnthropicVertex, "AnthropicVertexClient"]
98+
):
99+
"""Anthropic Vertex AI client that inherits from BaseAnthropicClient.
100+
101+
Only overrides initialization to use Vertex-specific SDK classes and
102+
provider naming to return 'anthropic-vertex'.
103+
"""
104+
105+
@property
106+
def _context_var(self) -> ContextVar["AnthropicVertexClient | None"]:
107+
return ANTHROPIC_VERTEX_CLIENT_CONTEXT
108+
109+
def __init__(
110+
self,
111+
*,
112+
project_id: str | None = None,
113+
region: str | None = None,
114+
) -> None:
115+
"""Initialize the Anthropic Vertex AI client.
116+
117+
Args:
118+
project_id: GCP project ID.
119+
region: GCP region for Vertex AI.
120+
"""
121+
self.client = AnthropicVertex(
122+
project_id=project_id or NOT_GIVEN,
123+
region=region or NOT_GIVEN,
124+
)
125+
self.async_client = AsyncAnthropicVertex(
126+
project_id=project_id or NOT_GIVEN,
127+
region=region or NOT_GIVEN,
128+
)
129+
130+
@property
131+
def provider(self) -> Literal["anthropic-vertex"]:
132+
"""Return the provider name for Anthropic Vertex AI."""
133+
return "anthropic-vertex"

python/mirascope/llm/clients/providers.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
client as anthropic_bedrock_client,
1414
get_client as get_anthropic_bedrock_client,
1515
)
16+
from .anthropic_vertex import (
17+
AnthropicVertexClient,
18+
client as anthropic_vertex_client,
19+
get_client as get_anthropic_vertex_client,
20+
)
1621
from .azure_openai.completions import (
1722
AzureOpenAICompletionsClient,
1823
client as azure_openai_completions_client,
@@ -43,6 +48,7 @@
4348
Provider: TypeAlias = Literal[
4449
"anthropic",
4550
"anthropic-bedrock", # AnthropicBedrockClient
51+
"anthropic-vertex", # AnthropicVertexClient
4652
"azure-openai:completions", # AzureOpenAICompletionsClient
4753
"azure-openai:responses", # AzureOpenAIResponsesClient
4854
"google",
@@ -73,6 +79,12 @@ def get_client(provider: Literal["anthropic-bedrock"]) -> AnthropicBedrockClient
7379
...
7480

7581

82+
@overload
83+
def get_client(provider: Literal["anthropic-vertex"]) -> AnthropicVertexClient:
84+
"""Get an Anthropic Vertex AI client instance."""
85+
...
86+
87+
7688
@overload
7789
def get_client(
7890
provider: Literal["azure-openai:completions"],
@@ -114,6 +126,7 @@ def get_client(
114126
) -> (
115127
AnthropicClient
116128
| AnthropicBedrockClient
129+
| AnthropicVertexClient
117130
| AzureOpenAICompletionsClient
118131
| AzureOpenAIResponsesClient
119132
| GoogleClient
@@ -132,6 +145,7 @@ def get_client(
132145
- "openai:responses" returns `OpenAIResponsesClient` (Responses API)
133146
- "anthropic" returns `AnthropicClient`
134147
- "anthropic-bedrock" returns `AnthropicBedrockClient`
148+
- "anthropic-vertex" returns `AnthropicVertexClient`
135149
- "azure-openai:completions" returns `AzureOpenAICompletionsClient`
136150
- "azure-openai:responses" returns `AzureOpenAIResponsesClient`
137151
- "google" returns `GoogleClient`
@@ -147,6 +161,8 @@ def get_client(
147161
return get_anthropic_client()
148162
case "anthropic-bedrock":
149163
return get_anthropic_bedrock_client()
164+
case "anthropic-vertex":
165+
return get_anthropic_vertex_client()
150166
case "azure-openai:completions":
151167
return get_azure_openai_completions_client()
152168
case "azure-openai:responses":
@@ -209,6 +225,17 @@ def client(
209225
...
210226

211227

228+
@overload
229+
def client(
230+
provider: Literal["anthropic-vertex"],
231+
*,
232+
project_id: str | None = None,
233+
region: str | None = None,
234+
) -> AnthropicVertexClient:
235+
"""Create a cached Anthropic Vertex AI client with the given parameters."""
236+
...
237+
238+
212239
@overload
213240
def client(
214241
provider: Literal["google"],
@@ -255,6 +282,7 @@ def client(
255282
) -> (
256283
AnthropicClient
257284
| AnthropicBedrockClient
285+
| AnthropicVertexClient
258286
| AzureOpenAICompletionsClient
259287
| AzureOpenAIResponsesClient
260288
| GoogleClient
@@ -278,6 +306,9 @@ def client(
278306
- aws_secret_key: AWS secret key
279307
- aws_session_token: AWS session token
280308
- aws_profile: AWS profile name
309+
- For Vertex AI provider:
310+
- project_id: GCP project ID
311+
- region: GCP region
281312
282313
Returns:
283314
A cached client instance for the specified provider with the given parameters.
@@ -297,6 +328,11 @@ def client(
297328
aws_session_token=kwargs.get("aws_session_token"),
298329
aws_profile=kwargs.get("aws_profile"),
299330
)
331+
case "anthropic-vertex":
332+
return anthropic_vertex_client(
333+
project_id=kwargs.get("project_id"),
334+
region=kwargs.get("region"),
335+
)
300336
case "azure-openai:completions":
301337
return azure_openai_completions_client(
302338
api_key=api_key,

0 commit comments

Comments
 (0)