|  | 
|  | 1 | +"""Anthropic Vertex AI client implementation.""" | 
|  | 2 | + | 
|  | 3 | +import os | 
|  | 4 | +from collections.abc import Sequence | 
|  | 5 | +from contextvars import ContextVar | 
|  | 6 | +from dataclasses import replace | 
|  | 7 | +from functools import lru_cache | 
|  | 8 | +from typing import Literal | 
|  | 9 | + | 
|  | 10 | +from anthropic import NOT_GIVEN | 
|  | 11 | +from anthropic.lib.vertex._client import AnthropicVertex, AsyncAnthropicVertex | 
|  | 12 | + | 
|  | 13 | +from ...content import Image | 
|  | 14 | +from ...messages import Message, UserMessage | 
|  | 15 | +from ..anthropic import BaseAnthropicClient | 
|  | 16 | + | 
|  | 17 | +ANTHROPIC_VERTEX_CLIENT_CONTEXT: ContextVar["AnthropicVertexClient | None"] = ( | 
|  | 18 | +    ContextVar("ANTHROPIC_VERTEX_CLIENT_CONTEXT", default=None) | 
|  | 19 | +) | 
|  | 20 | + | 
|  | 21 | + | 
|  | 22 | +@lru_cache(maxsize=256) | 
|  | 23 | +def _anthropic_vertex_singleton( | 
|  | 24 | +    project_id: str | None, | 
|  | 25 | +    region: str | None, | 
|  | 26 | +) -> "AnthropicVertexClient": | 
|  | 27 | +    """Return a cached AnthropicVertexClient instance for the given parameters.""" | 
|  | 28 | +    return AnthropicVertexClient( | 
|  | 29 | +        project_id=project_id, | 
|  | 30 | +        region=region, | 
|  | 31 | +    ) | 
|  | 32 | + | 
|  | 33 | + | 
|  | 34 | +def client( | 
|  | 35 | +    *, | 
|  | 36 | +    project_id: str | None = None, | 
|  | 37 | +    region: str | None = None, | 
|  | 38 | +) -> "AnthropicVertexClient": | 
|  | 39 | +    """Return an `AnthropicVertexClient`. | 
|  | 40 | +
 | 
|  | 41 | +    Args: | 
|  | 42 | +        project_id: GCP project ID. If None, uses GOOGLE_CLOUD_PROJECT, GCLOUD_PROJECT, | 
|  | 43 | +            CLOUD_ML_PROJECT_ID, or GCP_PROJECT_ID env vars (in that order). | 
|  | 44 | +        region: GCP region. If None, uses CLOUD_ML_REGION, GOOGLE_CLOUD_REGION, or | 
|  | 45 | +            GOOGLE_CLOUD_LOCATION env vars (in that order). | 
|  | 46 | +
 | 
|  | 47 | +    Returns: | 
|  | 48 | +        An `AnthropicVertexClient` instance. | 
|  | 49 | +
 | 
|  | 50 | +    Examples: | 
|  | 51 | +        # Use environment variables | 
|  | 52 | +        client = client() | 
|  | 53 | +
 | 
|  | 54 | +        # Use explicit parameters | 
|  | 55 | +        client = client( | 
|  | 56 | +            project_id="my-gcp-project", | 
|  | 57 | +            region="us-central1" | 
|  | 58 | +        ) | 
|  | 59 | +    """ | 
|  | 60 | +    project_id = ( | 
|  | 61 | +        project_id | 
|  | 62 | +        or os.getenv("GOOGLE_CLOUD_PROJECT") | 
|  | 63 | +        or os.getenv("GCLOUD_PROJECT") | 
|  | 64 | +        or os.getenv("CLOUD_ML_PROJECT_ID") | 
|  | 65 | +        or os.getenv("GCP_PROJECT_ID") | 
|  | 66 | +    ) | 
|  | 67 | +    region = ( | 
|  | 68 | +        region | 
|  | 69 | +        or os.getenv("CLOUD_ML_REGION") | 
|  | 70 | +        or os.getenv("GOOGLE_CLOUD_REGION") | 
|  | 71 | +        or os.getenv("GOOGLE_CLOUD_LOCATION") | 
|  | 72 | +    ) | 
|  | 73 | + | 
|  | 74 | +    return _anthropic_vertex_singleton( | 
|  | 75 | +        project_id, | 
|  | 76 | +        region, | 
|  | 77 | +    ) | 
|  | 78 | + | 
|  | 79 | + | 
|  | 80 | +def clear_cache() -> None: | 
|  | 81 | +    """Clear the client singleton cache. | 
|  | 82 | +
 | 
|  | 83 | +    This is useful for testing or when you need to force recreation | 
|  | 84 | +    of clients with updated configuration. | 
|  | 85 | +    """ | 
|  | 86 | +    _anthropic_vertex_singleton.cache_clear() | 
|  | 87 | + | 
|  | 88 | + | 
|  | 89 | +def get_client() -> "AnthropicVertexClient": | 
|  | 90 | +    """Retrieve the current Anthropic Vertex client from context, or a global default. | 
|  | 91 | +
 | 
|  | 92 | +    Returns: | 
|  | 93 | +        The current Anthropic Vertex client from context if available, otherwise | 
|  | 94 | +        a global default client based on environment variables. | 
|  | 95 | +    """ | 
|  | 96 | +    ctx_client = ANTHROPIC_VERTEX_CLIENT_CONTEXT.get() | 
|  | 97 | +    return ctx_client or client() | 
|  | 98 | + | 
|  | 99 | + | 
|  | 100 | +class AnthropicVertexClient( | 
|  | 101 | +    BaseAnthropicClient[AnthropicVertex, AsyncAnthropicVertex, "AnthropicVertexClient"] | 
|  | 102 | +): | 
|  | 103 | +    """Anthropic Vertex AI client that inherits from BaseAnthropicClient. | 
|  | 104 | +
 | 
|  | 105 | +    Only overrides initialization to use Vertex-specific SDK classes and | 
|  | 106 | +    provider naming to return 'anthropic-vertex'. | 
|  | 107 | +    """ | 
|  | 108 | + | 
|  | 109 | +    @property | 
|  | 110 | +    def _context_var(self) -> ContextVar["AnthropicVertexClient | None"]: | 
|  | 111 | +        return ANTHROPIC_VERTEX_CLIENT_CONTEXT | 
|  | 112 | + | 
|  | 113 | +    def __init__( | 
|  | 114 | +        self, | 
|  | 115 | +        *, | 
|  | 116 | +        project_id: str | None = None, | 
|  | 117 | +        region: str | None = None, | 
|  | 118 | +    ) -> None: | 
|  | 119 | +        """Initialize the Anthropic Vertex AI client. | 
|  | 120 | +
 | 
|  | 121 | +        Args: | 
|  | 122 | +            project_id: GCP project ID. | 
|  | 123 | +            region: GCP region for Vertex AI. | 
|  | 124 | +        """ | 
|  | 125 | +        self.client = AnthropicVertex( | 
|  | 126 | +            project_id=project_id or NOT_GIVEN, | 
|  | 127 | +            region=region or NOT_GIVEN, | 
|  | 128 | +        ) | 
|  | 129 | +        self.async_client = AsyncAnthropicVertex( | 
|  | 130 | +            project_id=project_id or NOT_GIVEN, | 
|  | 131 | +            region=region or NOT_GIVEN, | 
|  | 132 | +        ) | 
|  | 133 | + | 
|  | 134 | +    @property | 
|  | 135 | +    def provider(self) -> Literal["anthropic-vertex"]: | 
|  | 136 | +        """Return the provider name for Anthropic Vertex AI.""" | 
|  | 137 | +        return "anthropic-vertex" | 
|  | 138 | + | 
|  | 139 | +    def _prepare_messages(self, messages: Sequence[Message]) -> Sequence[Message]: | 
|  | 140 | +        return _ensure_base64_images(messages) | 
|  | 141 | + | 
|  | 142 | +    async def _prepare_messages_async( | 
|  | 143 | +        self, messages: Sequence[Message] | 
|  | 144 | +    ) -> Sequence[Message]: | 
|  | 145 | +        return await _ensure_base64_images_async(messages) | 
|  | 146 | + | 
|  | 147 | + | 
|  | 148 | +def _ensure_base64_images(messages: Sequence[Message]) -> Sequence[Message]: | 
|  | 149 | +    """Convert URL-sourced images to base64 for Vertex AI.""" | 
|  | 150 | +    updated_messages: list[Message] = [] | 
|  | 151 | +    any_updates = False | 
|  | 152 | + | 
|  | 153 | +    for message in messages: | 
|  | 154 | +        if isinstance(message, UserMessage): | 
|  | 155 | +            converted_content = [] | 
|  | 156 | +            content_changed = False | 
|  | 157 | +            for part in message.content: | 
|  | 158 | +                if isinstance(part, Image) and part.source.type == "url_image_source": | 
|  | 159 | +                    converted_content.append(Image.download(part.source.url)) | 
|  | 160 | +                    content_changed = True | 
|  | 161 | +                else: | 
|  | 162 | +                    converted_content.append(part) | 
|  | 163 | +            if content_changed: | 
|  | 164 | +                message = replace(message, content=converted_content) | 
|  | 165 | +                any_updates = True | 
|  | 166 | +        updated_messages.append(message) | 
|  | 167 | + | 
|  | 168 | +    return updated_messages if any_updates else messages | 
|  | 169 | + | 
|  | 170 | + | 
|  | 171 | +async def _ensure_base64_images_async( | 
|  | 172 | +    messages: Sequence[Message], | 
|  | 173 | +) -> Sequence[Message]: | 
|  | 174 | +    """Convert URL-sourced images to base64 for Vertex AI (async).""" | 
|  | 175 | +    updated_messages: list[Message] = [] | 
|  | 176 | +    any_updates = False | 
|  | 177 | + | 
|  | 178 | +    for message in messages: | 
|  | 179 | +        if isinstance(message, UserMessage): | 
|  | 180 | +            converted_content = [] | 
|  | 181 | +            content_changed = False | 
|  | 182 | +            for part in message.content: | 
|  | 183 | +                if isinstance(part, Image) and part.source.type == "url_image_source": | 
|  | 184 | +                    converted_content.append( | 
|  | 185 | +                        await Image.download_async(part.source.url) | 
|  | 186 | +                    ) | 
|  | 187 | +                    content_changed = True | 
|  | 188 | +                else: | 
|  | 189 | +                    converted_content.append(part) | 
|  | 190 | +            if content_changed: | 
|  | 191 | +                message = replace(message, content=converted_content) | 
|  | 192 | +                any_updates = True | 
|  | 193 | +        updated_messages.append(message) | 
|  | 194 | + | 
|  | 195 | +    return updated_messages if any_updates else messages | 
0 commit comments