Commit dab34b7: Add support for HerokuProvider (#1933)

1 parent 41130b5

18 files changed: +370, -18 lines

docs/api/providers.md (8 additions, 4 deletions)

@@ -18,10 +18,14 @@
 
 ::: pydantic_ai.providers.cohere
 
-::: pydantic_ai.providers.mistral
+::: pydantic_ai.providers.mistral.MistralProvider
 
-::: pydantic_ai.providers.fireworks
+::: pydantic_ai.providers.fireworks.FireworksProvider
 
-::: pydantic_ai.providers.grok
+::: pydantic_ai.providers.grok.GrokProvider
 
-::: pydantic_ai.providers.together
+::: pydantic_ai.providers.together.TogetherProvider
+
+::: pydantic_ai.providers.heroku.HerokuProvider
+
+::: pydantic_ai.providers.openrouter.OpenRouterProvider

docs/models/index.md (1 addition, 0 deletions)

@@ -22,6 +22,7 @@ In addition, many providers are compatible with the OpenAI API, and can be used
 * [Fireworks AI](openai.md#fireworks-ai)
 * [Together AI](openai.md#together-ai)
 * [Azure AI Foundry](openai.md#azure-ai-foundry)
+* [Heroku](openai.md#heroku-ai)
 
 PydanticAI also comes with [`TestModel`](../api/models/test.md) and [`FunctionModel`](../api/models/function.md)
 for testing and development.

docs/models/openai.md (31 additions, 6 deletions)

@@ -183,7 +183,7 @@ agent = Agent(model)
 ### DeepSeek
 
 To use the [DeepSeek](https://deepseek.com) provider, first create an API key by following the [Quick Start guide](https://api-docs.deepseek.com/).
-Once you have the API key, you can use it with the `DeepSeekProvider`:
+Once you have the API key, you can use it with the [`DeepSeekProvider`][pydantic_ai.providers.deepseek.DeepSeekProvider]:
 
 ```python
 from pydantic_ai import Agent
@@ -295,7 +295,8 @@ print(result.usage())
 
 ### Azure AI Foundry
 
-If you want to use [Azure AI Foundry](https://ai.azure.com/) as your provider, you can do so by using the `AzureProvider` class.
+If you want to use [Azure AI Foundry](https://ai.azure.com/) as your provider, you can do so by using the
+[`AzureProvider`][pydantic_ai.providers.azure.AzureProvider] class.
 
 ```python
 from pydantic_ai import Agent
@@ -318,7 +319,7 @@ agent = Agent(model)
 
 To use [OpenRouter](https://openrouter.ai), first create an API key at [openrouter.ai/keys](https://openrouter.ai/keys).
 
-Once you have the API key, you can use it with the `OpenRouterProvider`:
+Once you have the API key, you can use it with the [`OpenRouterProvider`][pydantic_ai.providers.openrouter.OpenRouterProvider]:
 
 ```python
 from pydantic_ai import Agent
@@ -336,7 +337,7 @@ agent = Agent(model)
 ### Grok (xAI)
 
 Go to [xAI API Console](https://console.x.ai/) and create an API key.
-Once you have the API key, you can use it with the `GrokProvider`:
+Once you have the API key, you can use it with the [`GrokProvider`][pydantic_ai.providers.grok.GrokProvider]:
 
 ```python
 from pydantic_ai import Agent
@@ -375,7 +376,7 @@ agent = Agent(model)
 ### Fireworks AI
 
 Go to [Fireworks.AI](https://fireworks.ai/) and create an API key in your account settings.
-Once you have the API key, you can use it with the `FireworksProvider`:
+Once you have the API key, you can use it with the [`FireworksProvider`][pydantic_ai.providers.fireworks.FireworksProvider]:
 
 ```python
 from pydantic_ai import Agent
@@ -393,7 +394,7 @@ agent = Agent(model)
 ### Together AI
 
 Go to [Together.ai](https://www.together.ai/) and create an API key in your account settings.
-Once you have the API key, you can use it with the `TogetherProvider`:
+Once you have the API key, you can use it with the [`TogetherProvider`][pydantic_ai.providers.together.TogetherProvider]:
 
 ```python
 from pydantic_ai import Agent
@@ -407,3 +408,27 @@ model = OpenAIModel(
 agent = Agent(model)
 ...
 ```
+
+### Heroku AI
+
+To use [Heroku AI](https://www.heroku.com/ai), you can use the [`HerokuProvider`][pydantic_ai.providers.heroku.HerokuProvider]:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.providers.heroku import HerokuProvider
+
+model = OpenAIModel(
+    'claude-3-7-sonnet',
+    provider=HerokuProvider(api_key='your-heroku-inference-key'),
+)
+agent = Agent(model)
+...
+```
+
+You can set the `HEROKU_INFERENCE_KEY` and `HEROKU_INFERENCE_URL` environment variables to set the API key and base URL, respectively:
+
+```bash
+export HEROKU_INFERENCE_KEY='your-heroku-inference-key'
+export HEROKU_INFERENCE_URL='https://us.inference.heroku.com'
+```
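A quick cross-check of the documented environment-variable flow: once `HEROKU_INFERENCE_KEY` is exported, `HerokuProvider()` can be constructed with no arguments. A minimal sketch (not part of the commit), assuming the variables are set as in the `bash` block above:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.heroku import HerokuProvider

# With HEROKU_INFERENCE_KEY exported, no api_key argument is needed;
# HEROKU_INFERENCE_URL is optional and defaults to https://us.inference.heroku.com.
model = OpenAIModel('claude-3-7-sonnet', provider=HerokuProvider())
agent = Agent(model)
```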

pydantic_ai_slim/pydantic_ai/models/__init__.py (6 additions, 1 deletion)

@@ -211,6 +211,11 @@
     'groq:llama-3.2-3b-preview',
     'groq:llama-3.2-11b-vision-preview',
     'groq:llama-3.2-90b-vision-preview',
+    'heroku:claude-3-5-haiku',
+    'heroku:claude-3-5-sonnet-latest',
+    'heroku:claude-3-7-sonnet',
+    'heroku:claude-4-sonnet',
+    'heroku:claude-3-haiku',
     'mistral:codestral-latest',
     'mistral:mistral-large-latest',
     'mistral:mistral-moderation-latest',
@@ -543,7 +548,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:
         from .cohere import CohereModel
 
         return CohereModel(model_name, provider=provider)
-    elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'):
+    elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku'):
         from .openai import OpenAIModel
 
         return OpenAIModel(model_name, provider=provider)
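With the new `heroku:` entries in the known-model list and the extra branch in `infer_model`, the `heroku:` prefix now resolves to the OpenAI-compatible model class. A minimal sketch of the effect (not part of the commit), assuming `HEROKU_INFERENCE_KEY` is set in the environment:

```python
from pydantic_ai.models import infer_model
from pydantic_ai.models.openai import OpenAIModel

# The prefix selects the provider; the remainder is passed through as the model name.
model = infer_model('heroku:claude-3-7-sonnet')
assert isinstance(model, OpenAIModel)
```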

pydantic_ai_slim/pydantic_ai/models/gemini.py (1 addition, 1 deletion)

@@ -228,7 +228,7 @@ async def _make_request(
 
         if gemini_labels := model_settings.get('gemini_labels'):
             if self._system == 'google-vertex':
-                request_data['labels'] = gemini_labels
+                request_data['labels'] = gemini_labels  # pragma: lax no cover
 
         headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
         url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'

pydantic_ai_slim/pydantic_ai/models/openai.py (1 addition, 1 deletion)

@@ -170,7 +170,7 @@ def __init__(
         self,
         model_name: OpenAIModelName,
         *,
-        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
+        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku']
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
         system_prompt_role: OpenAISystemPromptRole | None = None,
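The widened `Literal` means the plain string `'heroku'` is now accepted as the `provider` argument to `OpenAIModel`, with the concrete provider inferred by name. A small sketch under the same environment-variable assumption:

```python
from pydantic_ai.models.openai import OpenAIModel

# Equivalent to passing HerokuProvider() explicitly; requires HEROKU_INFERENCE_KEY.
model = OpenAIModel('claude-3-7-sonnet', provider='heroku')
```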

pydantic_ai_slim/pydantic_ai/providers/__init__.py (5 additions, 1 deletion)

@@ -48,7 +48,7 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
         return None  # pragma: no cover
 
 
-def infer_provider(provider: str) -> Provider[Any]:
+def infer_provider(provider: str) -> Provider[Any]:  # noqa: C901
     """Infer the provider from the provider name."""
     if provider == 'openai':
         from .openai import OpenAIProvider
@@ -107,5 +107,9 @@ def infer_provider(provider: str) -> Provider[Any]:
         from .together import TogetherProvider
 
         return TogetherProvider()
+    elif provider == 'heroku':
+        from .heroku import HerokuProvider
+
+        return HerokuProvider()
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')
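`infer_provider('heroku')` now returns a `HerokuProvider` instance. A short illustration (not part of the commit), again assuming `HEROKU_INFERENCE_KEY` is set:

```python
from pydantic_ai.providers import infer_provider
from pydantic_ai.providers.heroku import HerokuProvider

provider = infer_provider('heroku')
assert isinstance(provider, HerokuProvider)
# With the default HEROKU_INFERENCE_URL, base_url points at the US inference endpoint plus '/v1'.
print(provider.base_url)
```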

pydantic_ai_slim/pydantic_ai/providers/google_vertex.py (1 addition, 1 deletion)

@@ -50,7 +50,7 @@ def client(self) -> httpx.AsyncClient:
         return self._client
 
     def model_profile(self, model_name: str) -> ModelProfile | None:
-        return google_model_profile(model_name)
+        return google_model_profile(model_name)  # pragma: lax no cover
 
     @overload
     def __init__(

pydantic_ai_slim/pydantic_ai/providers/heroku.py (new file, 82 additions, 0 deletions)

@@ -0,0 +1,82 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from openai import AsyncOpenAI
+
+from pydantic_ai.exceptions import UserError
+from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.profiles import ModelProfile
+from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
+from pydantic_ai.providers import Provider
+
+try:
+    from openai import AsyncOpenAI
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `openai` package to use the Heroku provider, '
+        'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
+    ) from _import_error
+
+
+class HerokuProvider(Provider[AsyncOpenAI]):
+    """Provider for Heroku API."""
+
+    @property
+    def name(self) -> str:
+        return 'heroku'
+
+    @property
+    def base_url(self) -> str:
+        return str(self.client.base_url)
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        return self._client
+
+    def model_profile(self, model_name: str) -> ModelProfile | None:
+        # As the Heroku API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer.
+        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer)
+
+    @overload
+    def __init__(self) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str) -> None: ...
+
+    @overload
+    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
+
+    @overload
+    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
+
+    def __init__(
+        self,
+        *,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        openai_client: AsyncOpenAI | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ) -> None:
+        if openai_client is not None:
+            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
+            self._client = openai_client
+        else:
+            api_key = api_key or os.environ.get('HEROKU_INFERENCE_KEY')
+            if not api_key:
+                raise UserError(
+                    'Set the `HEROKU_INFERENCE_KEY` environment variable or pass it via `HerokuProvider(api_key=...)`'
+                    'to use the Heroku provider.'
+                )
+
+            base_url = base_url or os.environ.get('HEROKU_INFERENCE_URL', 'https://us.inference.heroku.com')
+            base_url = base_url.rstrip('/') + '/v1'
+
+            if http_client is not None:
+                self._client = AsyncOpenAI(api_key=api_key, http_client=http_client, base_url=base_url)
+            else:
+                http_client = cached_async_http_client(provider='heroku')
+                self._client = AsyncOpenAI(api_key=api_key, http_client=http_client, base_url=base_url)
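The overloads above also accept a custom `httpx` client or a pre-configured `AsyncOpenAI` client in place of a bare API key. A sketch of those two paths (not part of the commit; the key shown is a placeholder):

```python
import httpx
from openai import AsyncOpenAI

from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.heroku import HerokuProvider

# Explicit API key plus a custom httpx client, e.g. to raise the timeout.
provider = HerokuProvider(
    api_key='your-heroku-inference-key',
    http_client=httpx.AsyncClient(timeout=60),
)

# Or wrap an existing AsyncOpenAI client; api_key and http_client must then be omitted.
provider = HerokuProvider(
    openai_client=AsyncOpenAI(
        api_key='your-heroku-inference-key',
        base_url='https://us.inference.heroku.com/v1',
    ),
)

model = OpenAIModel('claude-3-7-sonnet', provider=provider)
```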

tests/conftest.py (5 additions, 0 deletions)

@@ -281,6 +281,11 @@ def openrouter_api_key() -> str:
     return os.getenv('OPENROUTER_API_KEY', 'mock-api-key')
 
 
+@pytest.fixture(scope='session')
+def heroku_inference_key() -> str:
+    return os.getenv('HEROKU_INFERENCE_KEY', 'mock-api-key')
+
+
 @pytest.fixture(scope='session')
 def bedrock_provider():
     try:
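A hypothetical test using the new fixture (not in this commit) could look like the following; since the fixture falls back to `'mock-api-key'`, it needs no real credentials or network access:

```python
from pydantic_ai.providers.heroku import HerokuProvider


def test_heroku_provider_construction(heroku_inference_key: str):
    provider = HerokuProvider(api_key=heroku_inference_key)
    assert provider.name == 'heroku'
    # Assumes HEROKU_INFERENCE_URL is not overridden in the environment.
    assert provider.base_url.startswith('https://us.inference.heroku.com')
```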
