@@ -48,68 +48,74 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
48
48
return None # pragma: no cover
49
49
50
50
51
- def infer_provider (provider : str ) -> Provider [Any ]: # noqa: C901
52
- """Infer the provider from the provider name."""
51
+ def infer_provider_class (provider : str ) -> type [ Provider [Any ] ]: # noqa: C901
52
+ """Infers the provider class from the provider name."""
53
53
if provider == 'openai' :
54
54
from .openai import OpenAIProvider
55
55
56
- return OpenAIProvider ()
56
+ return OpenAIProvider
57
57
elif provider == 'deepseek' :
58
58
from .deepseek import DeepSeekProvider
59
59
60
- return DeepSeekProvider ()
60
+ return DeepSeekProvider
61
61
elif provider == 'openrouter' :
62
62
from .openrouter import OpenRouterProvider
63
63
64
- return OpenRouterProvider ()
64
+ return OpenRouterProvider
65
65
elif provider == 'azure' :
66
66
from .azure import AzureProvider
67
67
68
- return AzureProvider ()
68
+ return AzureProvider
69
69
elif provider == 'google-vertex' :
70
70
from .google_vertex import GoogleVertexProvider
71
71
72
- return GoogleVertexProvider ()
72
+ return GoogleVertexProvider
73
73
elif provider == 'google-gla' :
74
74
from .google_gla import GoogleGLAProvider
75
75
76
- return GoogleGLAProvider ()
76
+ return GoogleGLAProvider
77
77
# NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
78
78
elif provider == 'bedrock' :
79
79
from .bedrock import BedrockProvider
80
80
81
- return BedrockProvider ()
81
+ return BedrockProvider
82
82
elif provider == 'groq' :
83
83
from .groq import GroqProvider
84
84
85
- return GroqProvider ()
85
+ return GroqProvider
86
86
elif provider == 'anthropic' :
87
87
from .anthropic import AnthropicProvider
88
88
89
- return AnthropicProvider ()
89
+ return AnthropicProvider
90
90
elif provider == 'mistral' :
91
91
from .mistral import MistralProvider
92
92
93
- return MistralProvider ()
93
+ return MistralProvider
94
94
elif provider == 'cohere' :
95
95
from .cohere import CohereProvider
96
96
97
- return CohereProvider ()
97
+ return CohereProvider
98
98
elif provider == 'grok' :
99
99
from .grok import GrokProvider
100
100
101
- return GrokProvider ()
101
+ return GrokProvider
102
102
elif provider == 'fireworks' :
103
103
from .fireworks import FireworksProvider
104
104
105
- return FireworksProvider ()
105
+ return FireworksProvider
106
106
elif provider == 'together' :
107
107
from .together import TogetherProvider
108
108
109
- return TogetherProvider ()
109
+ return TogetherProvider
110
110
elif provider == 'heroku' :
111
111
from .heroku import HerokuProvider
112
112
113
- return HerokuProvider ()
113
+ return HerokuProvider
114
114
else : # pragma: no cover
115
115
raise ValueError (f'Unknown provider: { provider } ' )
116
+
117
+
118
+ def infer_provider (provider : str ) -> Provider [Any ]:
119
+ """Infer the provider from the provider name."""
120
+ provider_class = infer_provider_class (provider )
121
+ return provider_class ()
0 commit comments