Skip to content

Commit cbbd966

Browse files
authored
Add GitHub Models provider (#2114)
1 parent e4e4999 commit cbbd966

File tree

8 files changed

+279
-2
lines changed

8 files changed

+279
-2
lines changed

docs/models/openai.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,33 @@ agent = Agent(model)
366366
...
367367
```
368368

369+
### GitHub Models

To use [GitHub Models](https://docs.github.com/en/github-models), you'll need a GitHub personal access token with the `models: read` permission.

Once you have the token, you can use it with the [`GitHubProvider`][pydantic_ai.providers.github.GitHubProvider]:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.github import GitHubProvider

model = OpenAIModel(
    'xai/grok-3-mini',  # GitHub Models uses prefixed model names
    provider=GitHubProvider(api_key='your-github-token'),
)
agent = Agent(model)
...
```

You can also set the `GITHUB_API_KEY` environment variable:

```bash
export GITHUB_API_KEY='your-github-token'
```

GitHub Models supports various model families with different prefixes. You can see the full list on the [GitHub Marketplace](https://github.com/marketplace?type=models) or the public [catalog endpoint](https://models.github.ai/catalog/models).
395+
369396
### Perplexity
370397

371398
Follow the Perplexity [getting started](https://docs.perplexity.ai/guides/getting-started)

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,17 @@ def infer_model(model: Model | KnownModelName | str) -> Model:
569569
from .cohere import CohereModel
570570

571571
return CohereModel(model_name, provider=provider)
572-
elif provider in ('openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku'):
572+
elif provider in (
573+
'openai',
574+
'deepseek',
575+
'azure',
576+
'openrouter',
577+
'grok',
578+
'fireworks',
579+
'together',
580+
'heroku',
581+
'github',
582+
):
573583
from .openai import OpenAIModel
574584

575585
return OpenAIModel(model_name, provider=provider)

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,9 @@ def __init__(
190190
self,
191191
model_name: OpenAIModelName,
192192
*,
193-
provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku']
193+
provider: Literal[
194+
'openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'heroku', 'github'
195+
]
194196
| Provider[AsyncOpenAI] = 'openai',
195197
profile: ModelProfileSpec | None = None,
196198
system_prompt_role: OpenAISystemPromptRole | None = None,

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
111111
from .heroku import HerokuProvider
112112

113113
return HerokuProvider
114+
elif provider == 'github':
115+
from .github import GitHubProvider
116+
117+
return GitHubProvider
114118
else: # pragma: no cover
115119
raise ValueError(f'Unknown provider: {provider}')
116120

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from __future__ import annotations as _annotations
2+
3+
import os
4+
from typing import overload
5+
6+
from httpx import AsyncClient as AsyncHTTPClient
7+
8+
from pydantic_ai.exceptions import UserError
9+
from pydantic_ai.models import cached_async_http_client
10+
from pydantic_ai.profiles import ModelProfile
11+
from pydantic_ai.profiles.cohere import cohere_model_profile
12+
from pydantic_ai.profiles.deepseek import deepseek_model_profile
13+
from pydantic_ai.profiles.grok import grok_model_profile
14+
from pydantic_ai.profiles.meta import meta_model_profile
15+
from pydantic_ai.profiles.mistral import mistral_model_profile
16+
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile, openai_model_profile
17+
from pydantic_ai.providers import Provider
18+
19+
try:
20+
from openai import AsyncOpenAI
21+
except ImportError as _import_error: # pragma: no cover
22+
raise ImportError(
23+
'Please install the `openai` package to use the GitHub Models provider, '
24+
'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
25+
) from _import_error
26+
27+
28+
class GitHubProvider(Provider[AsyncOpenAI]):
    """Provider for GitHub Models API.

    GitHub Models provides access to various AI models through an OpenAI-compatible API.
    See <https://docs.github.com/en/github-models> for more information.
    """

    @property
    def name(self) -> str:
        return 'github'

    @property
    def base_url(self) -> str:
        return 'https://models.github.ai/inference'

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def model_profile(self, model_name: str) -> ModelProfile | None:
        """Return a model profile for *model_name*, a GitHub Models name like ``'xai/grok-3-mini'``.

        The name is split on the first ``'/'`` into a publisher prefix and the bare model
        name; the publisher selects the profile function. Names without a prefix are
        treated as OpenAI models.
        """
        # Built per call (not hoisted to module level) so the module-level profile
        # functions can be monkeypatched in tests and picked up here.
        provider_to_profile = {
            'xai': grok_model_profile,
            'meta': meta_model_profile,
            'microsoft': openai_model_profile,
            'mistral-ai': mistral_model_profile,
            'cohere': cohere_model_profile,
            'deepseek': deepseek_model_profile,
        }

        profile = None

        # If the model name does not contain a provider prefix, we assume it's an OpenAI model
        if '/' not in model_name:
            return openai_model_profile(model_name)

        provider, model_name = model_name.lower().split('/', 1)
        if provider in provider_to_profile:
            model_name, *_ = model_name.split(':', 1)  # drop tags, e.g. 'model:latest' -> 'model'
            profile = provider_to_profile[provider](model_name)

        # As GitHubProvider is always used with OpenAIModel, which used to unconditionally use OpenAIJsonSchemaTransformer,
        # we need to maintain that behavior unless json_schema_transformer is set explicitly
        return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Create a new GitHub Models provider.

        Args:
            api_key: The GitHub token to use for authentication. If not provided, the `GITHUB_API_KEY`
                environment variable will be used if available.
            openai_client: An existing `AsyncOpenAI` client to use. If provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.

        Raises:
            UserError: If no API key is provided, `GITHUB_API_KEY` is unset, and no
                pre-configured `openai_client` is supplied.
        """
        api_key = api_key or os.getenv('GITHUB_API_KEY')
        # A pre-built client carries its own credentials, so the key check is skipped then.
        if not api_key and openai_client is None:
            raise UserError(
                'Set the `GITHUB_API_KEY` environment variable or pass it via `GitHubProvider(api_key=...)`'
                ' to use the GitHub Models provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        elif http_client is not None:
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
        else:
            # Fall back to the shared cached HTTP client keyed by provider name.
            http_client = cached_async_http_client(provider='github')
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)

tests/models/test_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@
6464
'bedrock',
6565
'BedrockConverseModel',
6666
),
67+
(
68+
'GITHUB_API_KEY',
69+
'github:xai/grok-3-mini',
70+
'xai/grok-3-mini',
71+
'github',
72+
'github',
73+
'OpenAIModel',
74+
),
6775
]
6876

6977

tests/providers/test_github.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import re

import httpx
import pytest
from pytest_mock import MockerFixture

from pydantic_ai.exceptions import UserError
from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer
from pydantic_ai.profiles.cohere import cohere_model_profile
from pydantic_ai.profiles.deepseek import deepseek_model_profile
from pydantic_ai.profiles.grok import grok_model_profile
from pydantic_ai.profiles.meta import meta_model_profile
from pydantic_ai.profiles.mistral import mistral_model_profile
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, openai_model_profile

from ..conftest import TestEnv, try_import

# Guarded import: the whole module is skipped below when `openai` is not installed.
with try_import() as imports_successful:
    import openai

    from pydantic_ai.providers.github import GitHubProvider

pytestmark = pytest.mark.skipif(not imports_successful(), reason='openai not installed')
24+
25+
26+
def test_github_provider():
    """The provider exposes the expected name, base URL, and a client configured with the given token."""
    provider = GitHubProvider(api_key='ghp_test_token')
    assert provider.name == 'github'
    assert provider.base_url == 'https://models.github.ai/inference'
    assert isinstance(provider.client, openai.AsyncOpenAI)
    assert provider.client.api_key == 'ghp_test_token'
33+
34+
def test_github_provider_need_api_key(env: TestEnv) -> None:
    """Constructing the provider with no token and no `GITHUB_API_KEY` raises a UserError."""
    env.remove('GITHUB_API_KEY')
    with pytest.raises(
        UserError,
        match=re.escape(
            'Set the `GITHUB_API_KEY` environment variable or pass it via `GitHubProvider(api_key=...)`'
            ' to use the GitHub Models provider.'
        ),
    ):
        GitHubProvider()
44+
45+
46+
def test_github_provider_pass_http_client() -> None:
    """A user-supplied httpx client is threaded through to the underlying OpenAI client."""
    http_client = httpx.AsyncClient()
    provider = GitHubProvider(http_client=http_client, api_key='ghp_test_token')
    assert provider.client._client == http_client  # type: ignore[reportPrivateUsage]
50+
51+
52+
def test_github_pass_openai_client() -> None:
    """A pre-built AsyncOpenAI client is used as-is, bypassing the API-key requirement."""
    openai_client = openai.AsyncOpenAI(api_key='ghp_test_token')
    provider = GitHubProvider(openai_client=openai_client)
    assert provider.client == openai_client
56+
57+
58+
def test_github_provider_model_profile(mocker: MockerFixture):
    """Each publisher prefix routes to its profile function with a lowercased, tag-stripped name.

    Unknown publishers and unprefixed names fall back to the OpenAI profile, and the
    JSON schema transformer defaults to `OpenAIJsonSchemaTransformer` unless the
    publisher's profile overrides it (meta -> `InlineDefsJsonSchemaTransformer`).
    """
    provider = GitHubProvider(api_key='ghp_test_token')

    # Wrap the module-level profile functions so we can assert on the names they receive.
    ns = 'pydantic_ai.providers.github'
    meta_model_profile_mock = mocker.patch(f'{ns}.meta_model_profile', wraps=meta_model_profile)
    deepseek_model_profile_mock = mocker.patch(f'{ns}.deepseek_model_profile', wraps=deepseek_model_profile)
    mistral_model_profile_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile)
    cohere_model_profile_mock = mocker.patch(f'{ns}.cohere_model_profile', wraps=cohere_model_profile)
    grok_model_profile_mock = mocker.patch(f'{ns}.grok_model_profile', wraps=grok_model_profile)
    openai_model_profile_mock = mocker.patch(f'{ns}.openai_model_profile', wraps=openai_model_profile)

    meta_profile = provider.model_profile('meta/Llama-3.2-11B-Vision-Instruct')
    meta_model_profile_mock.assert_called_with('llama-3.2-11b-vision-instruct')
    assert meta_profile is not None
    assert meta_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    meta_profile = provider.model_profile('meta/Llama-3.1-405B-Instruct')
    meta_model_profile_mock.assert_called_with('llama-3.1-405b-instruct')
    assert meta_profile is not None
    assert meta_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer

    deepseek_profile = provider.model_profile('deepseek/deepseek-coder')
    deepseek_model_profile_mock.assert_called_with('deepseek-coder')
    assert deepseek_profile is not None
    assert deepseek_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    mistral_profile = provider.model_profile('mistral-ai/mixtral-8x7b-instruct')
    mistral_model_profile_mock.assert_called_with('mixtral-8x7b-instruct')
    assert mistral_profile is not None
    assert mistral_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    cohere_profile = provider.model_profile('cohere/command-r-plus')
    cohere_model_profile_mock.assert_called_with('command-r-plus')
    assert cohere_profile is not None
    assert cohere_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    grok_profile = provider.model_profile('xai/grok-3-mini')
    grok_model_profile_mock.assert_called_with('grok-3-mini')
    assert grok_profile is not None
    assert grok_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    microsoft_profile = provider.model_profile('microsoft/Phi-3.5-mini-instruct')
    openai_model_profile_mock.assert_called_with('phi-3.5-mini-instruct')
    assert microsoft_profile is not None
    assert microsoft_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    unknown_profile = provider.model_profile('some-unknown-model')
    openai_model_profile_mock.assert_called_with('some-unknown-model')
    assert unknown_profile is not None
    assert unknown_profile.json_schema_transformer == OpenAIJsonSchemaTransformer

    unknown_profile_with_prefix = provider.model_profile('unknown-publisher/some-unknown-model')
    openai_model_profile_mock.assert_called_with('some-unknown-model')
    assert unknown_profile_with_prefix is not None
    assert unknown_profile_with_prefix.json_schema_transformer == OpenAIJsonSchemaTransformer

tests/providers/test_provider_names.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pydantic_ai.providers.cohere import CohereProvider
2020
from pydantic_ai.providers.deepseek import DeepSeekProvider
2121
from pydantic_ai.providers.fireworks import FireworksProvider
22+
from pydantic_ai.providers.github import GitHubProvider
2223
from pydantic_ai.providers.google_gla import GoogleGLAProvider
2324
from pydantic_ai.providers.google_vertex import GoogleVertexProvider
2425
from pydantic_ai.providers.grok import GrokProvider
@@ -44,6 +45,7 @@
4445
('fireworks', FireworksProvider, 'FIREWORKS_API_KEY'),
4546
('together', TogetherProvider, 'TOGETHER_API_KEY'),
4647
('heroku', HerokuProvider, 'HEROKU_INFERENCE_KEY'),
48+
('github', GitHubProvider, 'GITHUB_API_KEY'),
4749
]
4850

4951
if not imports_successful():

0 commit comments

Comments
 (0)