Skip to content

Commit 18a86c7

Browse files
authored
Add new provider classes for Together AI, Fireworks AI, and Grok with automatic model profile selection (#1842)
1 parent 8637608 commit 18a86c7

File tree

12 files changed

+576
-18
lines changed

12 files changed

+576
-18
lines changed

docs/api/providers.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,9 @@
1919
::: pydantic_ai.providers.cohere
2020

2121
::: pydantic_ai.providers.mistral
22+
23+
::: pydantic_ai.providers.fireworks
24+
25+
::: pydantic_ai.providers.grok
26+
27+
::: pydantic_ai.providers.together

docs/models/openai.md

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -336,16 +336,16 @@ agent = Agent(model)
336336
### Grok (xAI)
337337

338338
Go to [xAI API Console](https://console.x.ai/) and create an API key.
339-
Once you have the API key, you can use it with the `OpenAIProvider`:
339+
Once you have the API key, you can use it with the `GrokProvider`:
340340

341341
```python
342342
from pydantic_ai import Agent
343343
from pydantic_ai.models.openai import OpenAIModel
344-
from pydantic_ai.providers.openai import OpenAIProvider
344+
from pydantic_ai.providers.grok import GrokProvider
345345

346346
model = OpenAIModel(
347347
'grok-2-1212',
348-
provider=OpenAIProvider(base_url='https://api.x.ai/v1', api_key='your-xai-api-key'),
348+
provider=GrokProvider(api_key='your-xai-api-key'),
349349
)
350350
agent = Agent(model)
351351
...
@@ -375,19 +375,16 @@ agent = Agent(model)
375375
### Fireworks AI
376376

377377
Go to [Fireworks.AI](https://fireworks.ai/) and create an API key in your account settings.
378-
Once you have the API key, you can use it with the `OpenAIProvider`:
378+
Once you have the API key, you can use it with the `FireworksProvider`:
379379

380380
```python
381381
from pydantic_ai import Agent
382382
from pydantic_ai.models.openai import OpenAIModel
383-
from pydantic_ai.providers.openai import OpenAIProvider
383+
from pydantic_ai.providers.fireworks import FireworksProvider
384384

385385
model = OpenAIModel(
386386
'accounts/fireworks/models/qwq-32b', # model library available at https://fireworks.ai/models
387-
provider=OpenAIProvider(
388-
base_url='https://api.fireworks.ai/inference/v1',
389-
api_key='your-fireworks-api-key',
390-
),
387+
provider=FireworksProvider(api_key='your-fireworks-api-key'),
391388
)
392389
agent = Agent(model)
393390
...
@@ -396,19 +393,16 @@ agent = Agent(model)
396393
### Together AI
397394

398395
Go to [Together.ai](https://www.together.ai/) and create an API key in your account settings.
399-
Once you have the API key, you can use it with the `OpenAIProvider`:
396+
Once you have the API key, you can use it with the `TogetherProvider`:
400397

401398
```python
402399
from pydantic_ai import Agent
403400
from pydantic_ai.models.openai import OpenAIModel
404-
from pydantic_ai.providers.openai import OpenAIProvider
401+
from pydantic_ai.providers.together import TogetherProvider
405402

406403
model = OpenAIModel(
407404
'meta-llama/Llama-3.3-70B-Instruct-Turbo-Free', # model library available at https://www.together.ai/models
408-
provider=OpenAIProvider(
409-
base_url='https://api.together.xyz/v1',
410-
api_key='your-together-api-key',
411-
),
405+
provider=TogetherProvider(api_key='your-together-api-key'),
412406
)
413407
agent = Agent(model)
414408
...

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ def infer_model(model: Model | KnownModelName | str) -> Model:
543543
from .cohere import CohereModel
544544

545545
return CohereModel(model_name, provider=provider)
546-
elif provider in ('deepseek', 'openai', 'azure', 'openrouter'):
546+
elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together'):
547547
from .openai import OpenAIModel
548548

549549
return OpenAIModel(model_name, provider=provider)

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ def __init__(
170170
self,
171171
model_name: OpenAIModelName,
172172
*,
173-
provider: Literal['openai', 'deepseek', 'azure', 'openrouter'] | Provider[AsyncOpenAI] = 'openai',
173+
provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
174+
| Provider[AsyncOpenAI] = 'openai',
174175
profile: ModelProfileSpec | None = None,
175176
system_prompt_role: OpenAISystemPromptRole | None = None,
176177
):
@@ -534,7 +535,8 @@ def __init__(
534535
self,
535536
model_name: OpenAIModelName,
536537
*,
537-
provider: Literal['openai', 'deepseek', 'azure', 'openrouter'] | Provider[AsyncOpenAI] = 'openai',
538+
provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together']
539+
| Provider[AsyncOpenAI] = 'openai',
538540
profile: ModelProfileSpec | None = None,
539541
):
540542
"""Initialize an OpenAI Responses model.

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,5 +95,17 @@ def infer_provider(provider: str) -> Provider[Any]:
9595
from .cohere import CohereProvider
9696

9797
return CohereProvider()
98+
elif provider == 'grok':
99+
from .grok import GrokProvider
100+
101+
return GrokProvider()
102+
elif provider == 'fireworks':
103+
from .fireworks import FireworksProvider
104+
105+
return FireworksProvider()
106+
elif provider == 'together':
107+
from .together import TogetherProvider
108+
109+
return TogetherProvider()
98110
else: # pragma: no cover
99111
raise ValueError(f'Unknown provider: {provider}')
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from __future__ import annotations as _annotations
2+
3+
import os
4+
from typing import overload
5+
6+
from httpx import AsyncClient as AsyncHTTPClient
7+
from openai import AsyncOpenAI
8+
9+
from pydantic_ai.exceptions import UserError
10+
from pydantic_ai.models import cached_async_http_client
11+
from pydantic_ai.profiles import ModelProfile
12+
from pydantic_ai.profiles.deepseek import deepseek_model_profile
13+
from pydantic_ai.profiles.google import google_model_profile
14+
from pydantic_ai.profiles.meta import meta_model_profile
15+
from pydantic_ai.profiles.mistral import mistral_model_profile
16+
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
17+
from pydantic_ai.profiles.qwen import qwen_model_profile
18+
from pydantic_ai.providers import Provider
19+
20+
try:
21+
from openai import AsyncOpenAI
22+
except ImportError as _import_error: # pragma: no cover
23+
raise ImportError(
24+
'Please install the `openai` package to use the Fireworks AI provider, '
25+
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
26+
) from _import_error
27+
28+
29+
class FireworksProvider(Provider[AsyncOpenAI]):
30+
"""Provider for Fireworks AI API."""
31+
32+
@property
33+
def name(self) -> str:
34+
return 'fireworks'
35+
36+
@property
37+
def base_url(self) -> str:
38+
return 'https://api.fireworks.ai/inference/v1'
39+
40+
@property
41+
def client(self) -> AsyncOpenAI:
42+
return self._client
43+
44+
def model_profile(self, model_name: str) -> ModelProfile | None:
45+
prefix_to_profile = {
46+
'llama': meta_model_profile,
47+
'qwen': qwen_model_profile,
48+
'deepseek': deepseek_model_profile,
49+
'mistral': mistral_model_profile,
50+
'gemma': google_model_profile,
51+
}
52+
53+
prefix = 'accounts/fireworks/models/'
54+
55+
profile = None
56+
if model_name.startswith(prefix):
57+
model_name = model_name[len(prefix) :]
58+
for provider, profile_func in prefix_to_profile.items():
59+
if model_name.startswith(provider):
60+
profile = profile_func(model_name)
61+
break
62+
63+
# As the Fireworks API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
64+
# unless json_schema_transformer is set explicitly
65+
return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
66+
67+
@overload
68+
def __init__(self) -> None: ...
69+
70+
@overload
71+
def __init__(self, *, api_key: str) -> None: ...
72+
73+
@overload
74+
def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
75+
76+
@overload
77+
def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
78+
79+
def __init__(
80+
self,
81+
*,
82+
api_key: str | None = None,
83+
openai_client: AsyncOpenAI | None = None,
84+
http_client: AsyncHTTPClient | None = None,
85+
) -> None:
86+
api_key = api_key or os.getenv('FIREWORKS_API_KEY')
87+
if not api_key and openai_client is None:
88+
raise UserError(
89+
'Set the `FIREWORKS_API_KEY` environment variable or pass it via `FireworksProvider(api_key=...)`'
90+
'to use the Fireworks AI provider.'
91+
)
92+
93+
if openai_client is not None:
94+
self._client = openai_client
95+
elif http_client is not None:
96+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
97+
else:
98+
http_client = cached_async_http_client(provider='fireworks')
99+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from __future__ import annotations as _annotations
2+
3+
import os
4+
from typing import overload
5+
6+
from httpx import AsyncClient as AsyncHTTPClient
7+
from openai import AsyncOpenAI
8+
9+
from pydantic_ai.exceptions import UserError
10+
from pydantic_ai.models import cached_async_http_client
11+
from pydantic_ai.profiles import ModelProfile
12+
from pydantic_ai.profiles.grok import grok_model_profile
13+
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
14+
from pydantic_ai.providers import Provider
15+
16+
try:
17+
from openai import AsyncOpenAI
18+
except ImportError as _import_error: # pragma: no cover
19+
raise ImportError(
20+
'Please install the `openai` package to use the Grok provider, '
21+
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
22+
) from _import_error
23+
24+
25+
class GrokProvider(Provider[AsyncOpenAI]):
26+
"""Provider for Grok API."""
27+
28+
@property
29+
def name(self) -> str:
30+
return 'grok'
31+
32+
@property
33+
def base_url(self) -> str:
34+
return 'https://api.x.ai/v1'
35+
36+
@property
37+
def client(self) -> AsyncOpenAI:
38+
return self._client
39+
40+
def model_profile(self, model_name: str) -> ModelProfile | None:
41+
profile = grok_model_profile(model_name)
42+
43+
# As the Grok API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer,
44+
# unless json_schema_transformer is set explicitly.
45+
# Also, Grok does not support strict tool definitions: https://github.com/pydantic/pydantic-ai/issues/1846
46+
return OpenAIModelProfile(
47+
json_schema_transformer=OpenAIJsonSchemaTransformer, openai_supports_strict_tool_definition=False
48+
).update(profile)
49+
50+
@overload
51+
def __init__(self) -> None: ...
52+
53+
@overload
54+
def __init__(self, *, api_key: str) -> None: ...
55+
56+
@overload
57+
def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
58+
59+
@overload
60+
def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
61+
62+
def __init__(
63+
self,
64+
*,
65+
api_key: str | None = None,
66+
openai_client: AsyncOpenAI | None = None,
67+
http_client: AsyncHTTPClient | None = None,
68+
) -> None:
69+
api_key = api_key or os.getenv('GROK_API_KEY')
70+
if not api_key and openai_client is None:
71+
raise UserError(
72+
'Set the `GROK_API_KEY` environment variable or pass it via `GrokProvider(api_key=...)`'
73+
'to use the Grok provider.'
74+
)
75+
76+
if openai_client is not None:
77+
self._client = openai_client
78+
elif http_client is not None:
79+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
80+
else:
81+
http_client = cached_async_http_client(provider='grok')
82+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)

0 commit comments

Comments
 (0)