Commit 2af4db6

Add Vercel AI Gateway provider (#2277)
Co-authored-by: Douwe Maan <douwe@pydantic.dev>
1 parent a0c3abb commit 2af4db6

10 files changed: +306 -1 lines changed

docs/api/providers.md

Lines changed: 2 additions & 0 deletions

@@ -32,4 +32,6 @@

 ::: pydantic_ai.providers.openrouter.OpenRouterProvider

+::: pydantic_ai.providers.vercel.VercelProvider
+
 ::: pydantic_ai.providers.huggingface.HuggingFaceProvider

docs/models/index.md

Lines changed: 1 addition & 0 deletions

@@ -19,6 +19,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used
 - [Grok (xAI)](openai.md#grok-xai)
 - [Ollama](openai.md#ollama)
 - [OpenRouter](openai.md#openrouter)
+- [Vercel AI Gateway](openai.md#vercel-ai-gateway)
 - [Perplexity](openai.md#perplexity)
 - [Fireworks AI](openai.md#fireworks-ai)
 - [Together AI](openai.md#together-ai)

docs/models/openai.md

Lines changed: 35 additions & 0 deletions

@@ -348,6 +348,41 @@ agent = Agent(model)
 ...
 ```

+### Vercel AI Gateway
+
+To use [Vercel's AI Gateway](https://vercel.com/docs/ai-gateway), first follow the [documentation](https://vercel.com/docs/ai-gateway) instructions on obtaining an API key or OIDC token.
+
+You can set your credentials using one of these environment variables:
+
+```bash
+export VERCEL_AI_GATEWAY_API_KEY='your-ai-gateway-api-key'
+# OR
+export VERCEL_OIDC_TOKEN='your-oidc-token'
+```
+
+Once you have set the environment variable, you can use it with the [`VercelProvider`][pydantic_ai.providers.vercel.VercelProvider]:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.providers.vercel import VercelProvider
+
+# Uses environment variable automatically
+model = OpenAIModel(
+    'anthropic/claude-4-sonnet',
+    provider=VercelProvider(),
+)
+agent = Agent(model)
+
+# Or pass the API key directly
+model = OpenAIModel(
+    'anthropic/claude-4-sonnet',
+    provider=VercelProvider(api_key='your-vercel-ai-gateway-api-key'),
+)
+agent = Agent(model)
+...
+```
+
 ### Grok (xAI)

 Go to [xAI API Console](https://console.x.ai/) and create an API key.

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -615,6 +615,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:  # noqa: C901
         'deepseek',
         'azure',
         'openrouter',
+        'vercel',
         'grok',
         'fireworks',
         'together',
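With 'vercel' added to the known provider prefixes, model strings in the usual `provider:model` form should be resolvable by `infer_model`. A minimal sketch, assuming credentials are set in the environment and using an illustrative gateway model id:

```python
from pydantic_ai import Agent
from pydantic_ai.models import infer_model

# Assumes VERCEL_AI_GATEWAY_API_KEY (or VERCEL_OIDC_TOKEN) is set in the environment.
# The 'vercel:' prefix selects VercelProvider; the rest is the gateway model id.
model = infer_model('vercel:anthropic/claude-4-sonnet')
agent = Agent(model)
```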

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 10 additions & 1 deletion

@@ -191,7 +191,16 @@ def __init__(
         model_name: OpenAIModelName,
         *,
         provider: Literal[
-            'openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku', 'github'
+            'openai',
+            'deepseek',
+            'azure',
+            'openrouter',
+            'vercel',
+            'grok',
+            'fireworks',
+            'together',
+            'heroku',
+            'github',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -62,6 +62,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .openrouter import OpenRouterProvider

         return OpenRouterProvider
+    elif provider == 'vercel':
+        from .vercel import VercelProvider
+
+        return VercelProvider
     elif provider == 'azure':
         from .azure import AzureProvider
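As a quick illustration of the dispatch above, `infer_provider_class('vercel')` should now return the class added in the new `vercel.py` module below. A minimal sketch, assuming the `openai` extra is installed:

```python
from pydantic_ai.providers import infer_provider_class
from pydantic_ai.providers.vercel import VercelProvider

# The string 'vercel' is routed to VercelProvider by the elif branch added above.
assert infer_provider_class('vercel') is VercelProvider
```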

pydantic_ai_slim/pydantic_ai/providers/vercel.py

Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
from __future__ import annotations as _annotations

import os
from typing import overload

from httpx import AsyncClient as AsyncHTTPClient

from pydantic_ai.exceptions import UserError
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.profiles import ModelProfile
from pydantic_ai.profiles.amazon import amazon_model_profile
from pydantic_ai.profiles.anthropic import anthropic_model_profile
from pydantic_ai.profiles.cohere import cohere_model_profile
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.google import google_model_profile
from pydantic_ai.profiles.grok import grok_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
from pydantic_ai.providers import Provider

try:
    from openai import AsyncOpenAI
except ImportError as _import_error:  # pragma: no cover
    raise ImportError(
        'Please install the `openai` package to use the Vercel provider, '
        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
    ) from _import_error


class VercelProvider(Provider[AsyncOpenAI]):
    """Provider for Vercel AI Gateway API."""

    @property
    def name(self) -> str:
        return 'vercel'

    @property
    def base_url(self) -> str:
        return 'https://ai-gateway.vercel.sh/v1'

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        provider_to_profile = {
            'anthropic': anthropic_model_profile,
            'bedrock': amazon_model_profile,
            'cohere': cohere_model_profile,
            'deepseek': deepseek_model_profile,
            'mistral': mistral_model_profile,
            'openai': openai_model_profile,
            'vertex': google_model_profile,
            'xai': grok_model_profile,
        }

        profile = None

        try:
            provider, model_name = model_name.split('/', 1)
        except ValueError:
            raise UserError(f"Model name must be in 'provider/model' format, got: {model_name!r}")

        if provider in provider_to_profile:
            profile = provider_to_profile[provider](model_name)

        # As VercelProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
        # we need to maintain that behavior unless json_schema_transformer is set explicitly
        return OpenAIModelProfile(
            json_schema_transformer=OpenAIJsonSchemaTransformer,
        ).update(profile)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        # Support Vercel AI Gateway's standard environment variables
        api_key = api_key or os.getenv('VERCEL_AI_GATEWAY_API_KEY') or os.getenv('VERCEL_OIDC_TOKEN')

        if not api_key and openai_client is None:
            raise UserError(
                'Set the `VERCEL_AI_GATEWAY_API_KEY` or `VERCEL_OIDC_TOKEN` environment variable '
                'or pass the API key via `VercelProvider(api_key=...)` to use the Vercel provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        elif http_client is not None:
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
        else:
            http_client = cached_async_http_client(provider='vercel')
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
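To illustrate the `model_profile` behavior above: the gateway model id is split on the first `/`, the vendor prefix selects a per-vendor profile (falling back to OpenAI defaults for unknown vendors), and ids without a slash are rejected. A small sketch, with a placeholder API key:

```python
from pydantic_ai.exceptions import UserError
from pydantic_ai.providers.vercel import VercelProvider

provider = VercelProvider(api_key='placeholder-key')

# 'anthropic/...' picks the Anthropic profile, overlaid with the OpenAI JSON schema transformer.
profile = provider.model_profile('anthropic/claude-4-sonnet')

# An unknown vendor prefix still returns the OpenAI-flavored default profile.
fallback = provider.model_profile('unknown/some-model')

# An id without a 'provider/model' slash raises a UserError.
try:
    provider.model_profile('not-a-gateway-id')
except UserError as exc:
    print(exc)
```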

tests/providers/test_provider_names.py

Lines changed: 2 additions & 0 deletions

@@ -29,12 +29,14 @@
 from pydantic_ai.providers.openai import OpenAIProvider
 from pydantic_ai.providers.openrouter import OpenRouterProvider
 from pydantic_ai.providers.together import TogetherProvider
+from pydantic_ai.providers.vercel import VercelProvider

 test_infer_provider_params = [
     ('anthropic', AnthropicProvider, 'ANTHROPIC_API_KEY'),
     ('cohere', CohereProvider, 'CO_API_KEY'),
     ('deepseek', DeepSeekProvider, 'DEEPSEEK_API_KEY'),
     ('openrouter', OpenRouterProvider, 'OPENROUTER_API_KEY'),
+    ('vercel', VercelProvider, 'VERCEL_AI_GATEWAY_API_KEY'),
     ('openai', OpenAIProvider, 'OPENAI_API_KEY'),
     ('azure', AzureProvider, 'AZURE_OPENAI'),
     ('google-vertex', GoogleVertexProvider, None),

tests/providers/test_vercel.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
import re

import httpx
import pytest
from pytest_mock import MockerFixture

from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer
from pydantic_ai.profiles.amazon import amazon_model_profile
from pydantic_ai.profiles.anthropic import anthropic_model_profile
from pydantic_ai.profiles.cohere import cohere_model_profile
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.google import GoogleJsonSchemaTransformer, google_model_profile
from pydantic_ai.profiles.grok import grok_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, openai_model_profile

from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
    import openai

    from pydantic_ai.providers.vercel import VercelProvider


pytestmark = [
    pytest.mark.skipif(not imports_successful(), reason='openai not installed'),
    pytest.mark.vcr,
    pytest.mark.anyio,
]


def test_vercel_provider():
    provider = VercelProvider(api_key='api-key')
    assert provider.name == 'vercel'
    assert provider.base_url == 'https://ai-gateway.vercel.sh/v1'
    assert isinstance(provider.client, openai.AsyncOpenAI)
    assert provider.client.api_key == 'api-key'


def test_vercel_provider_need_api_key(env: TestEnv) -> None:
    env.remove('VERCEL_AI_GATEWAY_API_KEY')
    env.remove('VERCEL_OIDC_TOKEN')
    with pytest.raises(
        UserError,
        match=re.escape(
            'Set the `VERCEL_AI_GATEWAY_API_KEY` or `VERCEL_OIDC_TOKEN` environment variable '
            'or pass the API key via `VercelProvider(api_key=...)` to use the Vercel provider.'
        ),
    ):
        VercelProvider()


def test_vercel_pass_openai_client() -> None:
    openai_client = openai.AsyncOpenAI(api_key='api-key')
    provider = VercelProvider(openai_client=openai_client)
    assert provider.client == openai_client


def test_vercel_provider_model_profile(mocker: MockerFixture):
    provider = VercelProvider(api_key='api-key')

    ns = 'pydantic_ai.providers.vercel'

    # Mock all profile functions
    anthropic_mock = mocker.patch(f'{ns}.anthropic_model_profile', wraps=anthropic_model_profile)
    amazon_mock = mocker.patch(f'{ns}.amazon_model_profile', wraps=amazon_model_profile)
    cohere_mock = mocker.patch(f'{ns}.cohere_model_profile', wraps=cohere_model_profile)
    deepseek_mock = mocker.patch(f'{ns}.deepseek_model_profile', wraps=deepseek_model_profile)
    google_mock = mocker.patch(f'{ns}.google_model_profile', wraps=google_model_profile)
    grok_mock = mocker.patch(f'{ns}.grok_model_profile', wraps=grok_model_profile)
    mistral_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile)
    openai_mock = mocker.patch(f'{ns}.openai_model_profile', wraps=openai_model_profile)

    # Test openai provider
    profile = provider.model_profile('openai/gpt-4o')
    openai_mock.assert_called_with('gpt-4o')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test anthropic provider
    profile = provider.model_profile('anthropic/claude-3-sonnet')
    anthropic_mock.assert_called_with('claude-3-sonnet')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test bedrock provider
    profile = provider.model_profile('bedrock/anthropic.claude-3-sonnet')
    amazon_mock.assert_called_with('anthropic.claude-3-sonnet')
    assert profile is not None
    assert profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    # Test cohere provider
    profile = provider.model_profile('cohere/command-r-plus')
    cohere_mock.assert_called_with('command-r-plus')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test deepseek provider
    profile = provider.model_profile('deepseek/deepseek-chat')
    deepseek_mock.assert_called_with('deepseek-chat')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test mistral provider
    profile = provider.model_profile('mistral/mistral-large')
    mistral_mock.assert_called_with('mistral-large')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    # Test vertex provider
    profile = provider.model_profile('vertex/gemini-1.5-pro')
    google_mock.assert_called_with('gemini-1.5-pro')
    assert profile is not None
    assert profile.json_schema_transformer == GoogleJsonSchemaTransformer

    # Test xai provider
    profile = provider.model_profile('xai/grok-beta')
    grok_mock.assert_called_with('grok-beta')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer


def test_vercel_with_http_client():
    http_client = httpx.AsyncClient()
    provider = VercelProvider(api_key='test-key', http_client=http_client)
    assert provider.client.api_key == 'test-key'
    assert str(provider.client.base_url) == 'https://ai-gateway.vercel.sh/v1/'


def test_vercel_provider_invalid_model_name():
    provider = VercelProvider(api_key='api-key')

    with pytest.raises(UserError, match="Model name must be in 'provider/model' format"):
        provider.model_profile('invalid-model-name')


def test_vercel_provider_unknown_provider():
    provider = VercelProvider(api_key='api-key')

    profile = provider.model_profile('unknown/gpt-4')
    assert profile is not None
    assert profile.json_schema_transformer == OpenAIJsonSchemaTransformer

tests/test_examples.py

Lines changed: 1 addition & 0 deletions

@@ -156,6 +156,7 @@ def print(self, *args: Any, **kwargs: Any) -> None:
     env.set('AWS_ACCESS_KEY_ID', 'testing')
     env.set('AWS_SECRET_ACCESS_KEY', 'testing')
     env.set('AWS_DEFAULT_REGION', 'us-east-1')
+    env.set('VERCEL_AI_GATEWAY_API_KEY', 'testing')

     prefix_settings = example.prefix_settings()
     opt_test = prefix_settings.get('test', '')
