Skip to content

Commit a2ebb7b

Browse files
committed
Implement mirrored provider clients (AnthropicVertex)
1 parent fda97da commit a2ebb7b

File tree

5 files changed

+247
-1
lines changed

5 files changed

+247
-1
lines changed

python/mirascope/llm/clients/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
AnthropicModelId,
66
)
77
from .anthropic_bedrock import AnthropicBedrockClient
8+
from .anthropic_vertex import AnthropicVertexClient
89
from .azure_openai.completions import AzureOpenAICompletionsClient
910
from .azure_openai.responses import AzureOpenAIResponsesClient
1011
from .base import BaseClient, ClientT, Params
@@ -22,6 +23,7 @@
2223
"AnthropicBedrockClient",
2324
"AnthropicClient",
2425
"AnthropicModelId",
26+
"AnthropicVertexClient",
2527
"AzureOpenAICompletionsClient",
2628
"AzureOpenAIResponsesClient",
2729
"BaseClient",

python/mirascope/llm/clients/anthropic/base_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ async def _prepare_messages_async(
8585

8686
@property
8787
@abstractmethod
88-
def provider(self) -> Literal["anthropic", "anthropic-bedrock"]:
88+
def provider(self) -> Literal["anthropic", "anthropic-bedrock", "anthropic-vertex"]:
8989
"""Return the provider name for Anthropic-compatible clients."""
9090
...
9191

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Anthropic Vertex AI client implementations."""
2+
3+
from .clients import (
4+
AnthropicVertexClient,
5+
client,
6+
get_client,
7+
)
8+
9+
__all__ = [
10+
"AnthropicVertexClient",
11+
"client",
12+
"get_client",
13+
]
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
"""Anthropic Vertex AI client implementation."""
2+
3+
import os
4+
from collections.abc import Sequence
5+
from contextvars import ContextVar
6+
from dataclasses import replace
7+
from functools import lru_cache
8+
from typing import Literal
9+
10+
from anthropic import NOT_GIVEN
11+
from anthropic.lib.vertex._client import AnthropicVertex, AsyncAnthropicVertex
12+
13+
from ...content import Image
14+
from ...messages import Message, UserMessage
15+
from ..anthropic import BaseAnthropicClient
16+
17+
ANTHROPIC_VERTEX_CLIENT_CONTEXT: ContextVar["AnthropicVertexClient | None"] = (
18+
ContextVar("ANTHROPIC_VERTEX_CLIENT_CONTEXT", default=None)
19+
)
20+
21+
22+
@lru_cache(maxsize=256)
23+
def _anthropic_vertex_singleton(
24+
project_id: str | None,
25+
region: str | None,
26+
) -> "AnthropicVertexClient":
27+
"""Return a cached AnthropicVertexClient instance for the given parameters."""
28+
return AnthropicVertexClient(
29+
project_id=project_id,
30+
region=region,
31+
)
32+
33+
34+
def client(
35+
*,
36+
project_id: str | None = None,
37+
region: str | None = None,
38+
) -> "AnthropicVertexClient":
39+
"""Return an `AnthropicVertexClient`.
40+
41+
Args:
42+
project_id: GCP project ID. If None, uses GOOGLE_CLOUD_PROJECT, GCLOUD_PROJECT,
43+
CLOUD_ML_PROJECT_ID, or GCP_PROJECT_ID env vars (in that order).
44+
region: GCP region. If None, uses CLOUD_ML_REGION, GOOGLE_CLOUD_REGION, or
45+
GOOGLE_CLOUD_LOCATION env vars (in that order).
46+
47+
Returns:
48+
An `AnthropicVertexClient` instance.
49+
50+
Examples:
51+
# Use environment variables
52+
client = client()
53+
54+
# Use explicit parameters
55+
client = client(
56+
project_id="my-gcp-project",
57+
region="us-central1"
58+
)
59+
"""
60+
project_id = (
61+
project_id
62+
or os.getenv("GOOGLE_CLOUD_PROJECT")
63+
or os.getenv("GCLOUD_PROJECT")
64+
or os.getenv("CLOUD_ML_PROJECT_ID")
65+
or os.getenv("GCP_PROJECT_ID")
66+
)
67+
region = (
68+
region
69+
or os.getenv("CLOUD_ML_REGION")
70+
or os.getenv("GOOGLE_CLOUD_REGION")
71+
or os.getenv("GOOGLE_CLOUD_LOCATION")
72+
)
73+
74+
return _anthropic_vertex_singleton(
75+
project_id,
76+
region,
77+
)
78+
79+
80+
def clear_cache() -> None:
81+
"""Clear the client singleton cache.
82+
83+
This is useful for testing or when you need to force recreation
84+
of clients with updated configuration.
85+
"""
86+
_anthropic_vertex_singleton.cache_clear()
87+
88+
89+
def get_client() -> "AnthropicVertexClient":
90+
"""Retrieve the current Anthropic Vertex client from context, or a global default.
91+
92+
Returns:
93+
The current Anthropic Vertex client from context if available, otherwise
94+
a global default client based on environment variables.
95+
"""
96+
ctx_client = ANTHROPIC_VERTEX_CLIENT_CONTEXT.get()
97+
return ctx_client or client()
98+
99+
100+
class AnthropicVertexClient(
101+
BaseAnthropicClient[AnthropicVertex, AsyncAnthropicVertex, "AnthropicVertexClient"]
102+
):
103+
"""Anthropic Vertex AI client that inherits from BaseAnthropicClient.
104+
105+
Only overrides initialization to use Vertex-specific SDK classes and
106+
provider naming to return 'anthropic-vertex'.
107+
"""
108+
109+
@property
110+
def _context_var(self) -> ContextVar["AnthropicVertexClient | None"]:
111+
return ANTHROPIC_VERTEX_CLIENT_CONTEXT
112+
113+
def __init__(
114+
self,
115+
*,
116+
project_id: str | None = None,
117+
region: str | None = None,
118+
) -> None:
119+
"""Initialize the Anthropic Vertex AI client.
120+
121+
Args:
122+
project_id: GCP project ID.
123+
region: GCP region for Vertex AI.
124+
"""
125+
self.client = AnthropicVertex(
126+
project_id=project_id or NOT_GIVEN,
127+
region=region or NOT_GIVEN,
128+
)
129+
self.async_client = AsyncAnthropicVertex(
130+
project_id=project_id or NOT_GIVEN,
131+
region=region or NOT_GIVEN,
132+
)
133+
134+
@property
135+
def provider(self) -> Literal["anthropic-vertex"]:
136+
"""Return the provider name for Anthropic Vertex AI."""
137+
return "anthropic-vertex"
138+
139+
def _prepare_messages(self, messages: Sequence[Message]) -> Sequence[Message]:
140+
return _ensure_base64_images(messages)
141+
142+
async def _prepare_messages_async(
143+
self, messages: Sequence[Message]
144+
) -> Sequence[Message]:
145+
return await _ensure_base64_images_async(messages)
146+
147+
148+
def _ensure_base64_images(messages: Sequence[Message]) -> Sequence[Message]:
149+
"""Convert URL-sourced images to base64 for Vertex AI."""
150+
updated_messages: list[Message] = []
151+
any_updates = False
152+
153+
for message in messages:
154+
if isinstance(message, UserMessage):
155+
converted_content = []
156+
content_changed = False
157+
for part in message.content:
158+
if isinstance(part, Image) and part.source.type == "url_image_source":
159+
converted_content.append(Image.download(part.source.url))
160+
content_changed = True
161+
else:
162+
converted_content.append(part)
163+
if content_changed:
164+
message = replace(message, content=converted_content)
165+
any_updates = True
166+
updated_messages.append(message)
167+
168+
return updated_messages if any_updates else messages
169+
170+
171+
async def _ensure_base64_images_async(
172+
messages: Sequence[Message],
173+
) -> Sequence[Message]:
174+
"""Convert URL-sourced images to base64 for Vertex AI (async)."""
175+
updated_messages: list[Message] = []
176+
any_updates = False
177+
178+
for message in messages:
179+
if isinstance(message, UserMessage):
180+
converted_content = []
181+
content_changed = False
182+
for part in message.content:
183+
if isinstance(part, Image) and part.source.type == "url_image_source":
184+
converted_content.append(
185+
await Image.download_async(part.source.url)
186+
)
187+
content_changed = True
188+
else:
189+
converted_content.append(part)
190+
if content_changed:
191+
message = replace(message, content=converted_content)
192+
any_updates = True
193+
updated_messages.append(message)
194+
195+
return updated_messages if any_updates else messages

python/mirascope/llm/clients/providers.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
client as anthropic_bedrock_client,
1414
get_client as get_anthropic_bedrock_client,
1515
)
16+
from .anthropic_vertex import (
17+
AnthropicVertexClient,
18+
client as anthropic_vertex_client,
19+
get_client as get_anthropic_vertex_client,
20+
)
1621
from .azure_openai.completions import (
1722
AzureOpenAICompletionsClient,
1823
client as azure_openai_completions_client,
@@ -43,6 +48,7 @@
4348
Provider: TypeAlias = Literal[
4449
"anthropic",
4550
"anthropic-bedrock", # AnthropicBedrockClient
51+
"anthropic-vertex", # AnthropicVertexClient
4652
"azure-openai:completions", # AzureOpenAICompletionsClient
4753
"azure-openai:responses", # AzureOpenAIResponsesClient
4854
"google",
@@ -73,6 +79,12 @@ def get_client(provider: Literal["anthropic-bedrock"]) -> AnthropicBedrockClient
7379
...
7480

7581

82+
@overload
83+
def get_client(provider: Literal["anthropic-vertex"]) -> AnthropicVertexClient:
84+
"""Get an Anthropic Vertex AI client instance."""
85+
...
86+
87+
7688
@overload
7789
def get_client(
7890
provider: Literal["azure-openai:completions"],
@@ -114,6 +126,7 @@ def get_client(
114126
) -> (
115127
AnthropicClient
116128
| AnthropicBedrockClient
129+
| AnthropicVertexClient
117130
| AzureOpenAICompletionsClient
118131
| AzureOpenAIResponsesClient
119132
| GoogleClient
@@ -132,6 +145,7 @@ def get_client(
132145
- "openai:responses" returns `OpenAIResponsesClient` (Responses API)
133146
- "anthropic" returns `AnthropicClient`
134147
- "anthropic-bedrock" returns `AnthropicBedrockClient`
148+
- "anthropic-vertex" returns `AnthropicVertexClient`
135149
- "azure-openai:completions" returns `AzureOpenAICompletionsClient`
136150
- "azure-openai:responses" returns `AzureOpenAIResponsesClient`
137151
- "google" returns `GoogleClient`
@@ -147,6 +161,8 @@ def get_client(
147161
return get_anthropic_client()
148162
case "anthropic-bedrock":
149163
return get_anthropic_bedrock_client()
164+
case "anthropic-vertex":
165+
return get_anthropic_vertex_client()
150166
case "azure-openai:completions":
151167
return get_azure_openai_completions_client()
152168
case "azure-openai:responses":
@@ -209,6 +225,17 @@ def client(
209225
...
210226

211227

228+
@overload
229+
def client(
230+
provider: Literal["anthropic-vertex"],
231+
*,
232+
project_id: str | None = None,
233+
region: str | None = None,
234+
) -> AnthropicVertexClient:
235+
"""Create a cached Anthropic Vertex AI client with the given parameters."""
236+
...
237+
238+
212239
@overload
213240
def client(
214241
provider: Literal["google"],
@@ -255,6 +282,7 @@ def client(
255282
) -> (
256283
AnthropicClient
257284
| AnthropicBedrockClient
285+
| AnthropicVertexClient
258286
| AzureOpenAICompletionsClient
259287
| AzureOpenAIResponsesClient
260288
| GoogleClient
@@ -278,6 +306,9 @@ def client(
278306
- aws_secret_key: AWS secret key
279307
- aws_session_token: AWS session token
280308
- aws_profile: AWS profile name
309+
- For Vertex AI provider:
310+
- project_id: GCP project ID
311+
- region: GCP region
281312
282313
Returns:
283314
A cached client instance for the specified provider with the given parameters.
@@ -297,6 +328,11 @@ def client(
297328
aws_session_token=kwargs.get("aws_session_token"),
298329
aws_profile=kwargs.get("aws_profile"),
299330
)
331+
case "anthropic-vertex":
332+
return anthropic_vertex_client(
333+
project_id=kwargs.get("project_id"),
334+
region=kwargs.get("region"),
335+
)
300336
case "azure-openai:completions":
301337
return azure_openai_completions_client(
302338
api_key=api_key,

0 commit comments

Comments
 (0)