@@ -88,8 +88,19 @@ class AzureOpenAIModel(BaseModelBackend):
8888 (default: :obj:`None`)
8989 max_retries (int, optional): Maximum number of retries for API calls.
9090 (default: :obj:`3`)
91+ client (Optional[Any], optional): A custom synchronous AzureOpenAI
92+ client instance. If provided, this client will be used instead of
93+ creating a new one. Useful for RL frameworks like AReaL or rLLM
94+ that provide Azure OpenAI-compatible clients. The client should
95+ implement the AzureOpenAI client interface with
96+ `.chat.completions.create()` and `.beta.chat.completions.parse()`
97+ methods. (default: :obj:`None`)
98+ async_client (Optional[Any], optional): A custom asynchronous
99+ AzureOpenAI client instance. If provided, this client will be
100+ used instead of creating a new one. The client should implement
101+ the AsyncAzureOpenAI client interface. (default: :obj:`None`)
91102 **kwargs (Any): Additional arguments to pass to the client
92- initialization.
103+ initialization. Only used for clients created here; not applied to a provided custom client.
93104
94105 References:
95106 https://learn.microsoft.com/en-us/azure/ai-services/openai/
@@ -108,6 +119,8 @@ def __init__(
108119 azure_ad_token_provider : Optional ["AzureADTokenProvider" ] = None ,
109120 azure_ad_token : Optional [str ] = None ,
110121 max_retries : int = 3 ,
122+ client : Optional [Any ] = None ,
123+ async_client : Optional [Any ] = None ,
111124 ** kwargs : Any ,
112125 ) -> None :
113126 if model_config_dict is None :
@@ -138,56 +151,72 @@ def __init__(
138151 "or `AZURE_DEPLOYMENT_NAME` environment variable."
139152 )
140153
141- if is_langfuse_available ():
142- from langfuse .openai import AsyncAzureOpenAI as LangfuseAsyncOpenAI
143- from langfuse .openai import AzureOpenAI as LangfuseOpenAI
144-
145- self ._client = LangfuseOpenAI (
146- azure_endpoint = str (self ._url ),
147- azure_deployment = self ._azure_deployment_name ,
148- api_version = self .api_version ,
149- api_key = self ._api_key ,
150- azure_ad_token = self ._azure_ad_token ,
151- azure_ad_token_provider = self .azure_ad_token_provider ,
152- timeout = self ._timeout ,
153- max_retries = max_retries ,
154- ** kwargs ,
155- )
156- self ._async_client = LangfuseAsyncOpenAI (
157- azure_endpoint = str (self ._url ),
158- azure_deployment = self ._azure_deployment_name ,
159- api_version = self .api_version ,
160- api_key = self ._api_key ,
161- azure_ad_token = self ._azure_ad_token ,
162- azure_ad_token_provider = self .azure_ad_token_provider ,
163- timeout = self ._timeout ,
164- max_retries = max_retries ,
165- ** kwargs ,
166- )
154+ # Use custom clients if provided, otherwise create new ones
155+ if client is not None :
156+ # Use the provided custom sync client
157+ self ._client = client
167158 else :
168- self ._client = AzureOpenAI (
169- azure_endpoint = str (self ._url ),
170- azure_deployment = self ._azure_deployment_name ,
171- api_version = self .api_version ,
172- api_key = self ._api_key ,
173- azure_ad_token = self ._azure_ad_token ,
174- azure_ad_token_provider = self .azure_ad_token_provider ,
175- timeout = self ._timeout ,
176- max_retries = max_retries ,
177- ** kwargs ,
178- )
159+ # Create default sync client
160+ if is_langfuse_available ():
161+ from langfuse .openai import AzureOpenAI as LangfuseOpenAI
162+
163+ self ._client = LangfuseOpenAI (
164+ azure_endpoint = str (self ._url ),
165+ azure_deployment = self ._azure_deployment_name ,
166+ api_version = self .api_version ,
167+ api_key = self ._api_key ,
168+ azure_ad_token = self ._azure_ad_token ,
169+ azure_ad_token_provider = self .azure_ad_token_provider ,
170+ timeout = self ._timeout ,
171+ max_retries = max_retries ,
172+ ** kwargs ,
173+ )
174+ else :
175+ self ._client = AzureOpenAI (
176+ azure_endpoint = str (self ._url ),
177+ azure_deployment = self ._azure_deployment_name ,
178+ api_version = self .api_version ,
179+ api_key = self ._api_key ,
180+ azure_ad_token = self ._azure_ad_token ,
181+ azure_ad_token_provider = self .azure_ad_token_provider ,
182+ timeout = self ._timeout ,
183+ max_retries = max_retries ,
184+ ** kwargs ,
185+ )
179186
180- self ._async_client = AsyncAzureOpenAI (
181- azure_endpoint = str (self ._url ),
182- azure_deployment = self ._azure_deployment_name ,
183- api_version = self .api_version ,
184- api_key = self ._api_key ,
185- azure_ad_token = self ._azure_ad_token ,
186- azure_ad_token_provider = self .azure_ad_token_provider ,
187- timeout = self ._timeout ,
188- max_retries = max_retries ,
189- ** kwargs ,
190- )
187+ if async_client is not None :
188+ # Use the provided custom async client
189+ self ._async_client = async_client
190+ else :
191+ # Create default async client
192+ if is_langfuse_available ():
193+ from langfuse .openai import (
194+ AsyncAzureOpenAI as LangfuseAsyncOpenAI ,
195+ )
196+
197+ self ._async_client = LangfuseAsyncOpenAI (
198+ azure_endpoint = str (self ._url ),
199+ azure_deployment = self ._azure_deployment_name ,
200+ api_version = self .api_version ,
201+ api_key = self ._api_key ,
202+ azure_ad_token = self ._azure_ad_token ,
203+ azure_ad_token_provider = self .azure_ad_token_provider ,
204+ timeout = self ._timeout ,
205+ max_retries = max_retries ,
206+ ** kwargs ,
207+ )
208+ else :
209+ self ._async_client = AsyncAzureOpenAI (
210+ azure_endpoint = str (self ._url ),
211+ azure_deployment = self ._azure_deployment_name ,
212+ api_version = self .api_version ,
213+ api_key = self ._api_key ,
214+ azure_ad_token = self ._azure_ad_token ,
215+ azure_ad_token_provider = self .azure_ad_token_provider ,
216+ timeout = self ._timeout ,
217+ max_retries = max_retries ,
218+ ** kwargs ,
219+ )
191220
192221 @property
193222 def token_counter (self ) -> BaseTokenCounter :
0 commit comments