Skip to content

Commit e3e435e

Browse files
authored
More flexible method infer_provider (#1945)
Co-authored-by: Karel Hovorka <git@karel-hovorka.eu>
1 parent 352acff commit e3e435e

File tree

2 files changed

+30
-17
lines changed

2 files changed

+30
-17
lines changed

pydantic_ai_slim/pydantic_ai/providers/__init__.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,68 +48,74 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
4848
return None # pragma: no cover
4949

5050

51-
def infer_provider(provider: str) -> Provider[Any]: # noqa: C901
52-
"""Infer the provider from the provider name."""
51+
def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
52+
"""Infers the provider class from the provider name."""
5353
if provider == 'openai':
5454
from .openai import OpenAIProvider
5555

56-
return OpenAIProvider()
56+
return OpenAIProvider
5757
elif provider == 'deepseek':
5858
from .deepseek import DeepSeekProvider
5959

60-
return DeepSeekProvider()
60+
return DeepSeekProvider
6161
elif provider == 'openrouter':
6262
from .openrouter import OpenRouterProvider
6363

64-
return OpenRouterProvider()
64+
return OpenRouterProvider
6565
elif provider == 'azure':
6666
from .azure import AzureProvider
6767

68-
return AzureProvider()
68+
return AzureProvider
6969
elif provider == 'google-vertex':
7070
from .google_vertex import GoogleVertexProvider
7171

72-
return GoogleVertexProvider()
72+
return GoogleVertexProvider
7373
elif provider == 'google-gla':
7474
from .google_gla import GoogleGLAProvider
7575

76-
return GoogleGLAProvider()
76+
return GoogleGLAProvider
7777
# NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
7878
elif provider == 'bedrock':
7979
from .bedrock import BedrockProvider
8080

81-
return BedrockProvider()
81+
return BedrockProvider
8282
elif provider == 'groq':
8383
from .groq import GroqProvider
8484

85-
return GroqProvider()
85+
return GroqProvider
8686
elif provider == 'anthropic':
8787
from .anthropic import AnthropicProvider
8888

89-
return AnthropicProvider()
89+
return AnthropicProvider
9090
elif provider == 'mistral':
9191
from .mistral import MistralProvider
9292

93-
return MistralProvider()
93+
return MistralProvider
9494
elif provider == 'cohere':
9595
from .cohere import CohereProvider
9696

97-
return CohereProvider()
97+
return CohereProvider
9898
elif provider == 'grok':
9999
from .grok import GrokProvider
100100

101-
return GrokProvider()
101+
return GrokProvider
102102
elif provider == 'fireworks':
103103
from .fireworks import FireworksProvider
104104

105-
return FireworksProvider()
105+
return FireworksProvider
106106
elif provider == 'together':
107107
from .together import TogetherProvider
108108

109-
return TogetherProvider()
109+
return TogetherProvider
110110
elif provider == 'heroku':
111111
from .heroku import HerokuProvider
112112

113-
return HerokuProvider()
113+
return HerokuProvider
114114
else: # pragma: no cover
115115
raise ValueError(f'Unknown provider: {provider}')
116+
117+
118+
def infer_provider(provider: str) -> Provider[Any]:
119+
"""Infer the provider from the provider name."""
120+
provider_class = infer_provider_class(provider)
121+
return provider_class()

tests/providers/test_provider_names.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,10 @@ def test_infer_provider(provider: str, provider_cls: type[Provider[Any]], except
6565
infer_provider(provider)
6666
else:
6767
assert isinstance(infer_provider(provider), provider_cls)
68+
69+
70+
@pytest.mark.parametrize(('provider', 'provider_cls', 'exception_has'), test_infer_provider_params)
71+
def test_infer_provider_class(provider: str, provider_cls: type[Provider[Any]], exception_has: str | None):
72+
from pydantic_ai.providers import infer_provider_class
73+
74+
assert infer_provider_class(provider) == provider_cls

0 commit comments

Comments
 (0)