1515import base64
1616import os
1717from io import BytesIO
18- from typing import List , Literal , Optional , Union
18+ from typing import ClassVar , List , Literal , Optional , Tuple , Union
1919
2020from openai import OpenAI
2121from PIL import Image
2929
3030
3131@MCPServer ()
32- class OpenAIImageToolkit (BaseToolkit ):
33- r"""A class toolkit for image generation using OpenAI's
34- Image Generation API.
35- """
36-
37- @api_keys_required (
38- [
39- ("api_key" , "OPENAI_API_KEY" ),
40- ]
41- )
32+ class ImageGenToolkit (BaseToolkit ):
33+ r"""A class toolkit for image generation using Grok and OpenAI models."""
34+
35+ GROK_MODELS : ClassVar [List [str ]] = [
36+ "grok-2-image" ,
37+ "grok-2-image-latest" ,
38+ "grok-2-image-1212" ,
39+ ]
40+ OPENAI_MODELS : ClassVar [List [str ]] = [
41+ "gpt-image-1" ,
42+ "dall-e-3" ,
43+ "dall-e-2" ,
44+ ]
45+
4246 def __init__ (
4347 self ,
4448 model : Optional [
45- Literal ["gpt-image-1" , "dall-e-3" , "dall-e-2" ]
46- ] = "gpt-image-1" ,
49+ Literal [
50+ "gpt-image-1" ,
51+ "dall-e-3" ,
52+ "dall-e-2" ,
53+ "grok-2-image" ,
54+ "grok-2-image-latest" ,
55+ "grok-2-image-1212" ,
56+ ]
57+ ] = "dall-e-3" ,
4758 timeout : Optional [float ] = None ,
4859 api_key : Optional [str ] = None ,
4960 url : Optional [str ] = None ,
@@ -72,12 +83,12 @@ def __init__(
7283 # NOTE: Some arguments are set in the constructor to prevent the agent
7384 # from making invalid API calls with model-specific parameters. For
7485 # example, the 'style' argument is only supported by 'dall-e-3'.
75- r"""Initializes a new instance of the OpenAIImageToolkit class.
86+ r"""Initializes a new instance of the ImageGenToolkit class.
7687
7788 Args:
7889 api_key (Optional[str]): The API key for authenticating
79- with the OpenAI service. (default: :obj:`None`)
80- url (Optional[str]): The url to the OpenAI service.
90+ with the image model service. (default: :obj:`None`)
91+ url (Optional[str]): The url to the image model service.
8192 (default: :obj:`None`)
8293 model (Optional[str]): The model to use.
8394 (default: :obj:`"dall-e-3"`)
@@ -103,9 +114,23 @@ def __init__(
103114 image.(default: :obj:`"image_save"`)
104115 """
105116 super ().__init__ (timeout = timeout )
106- api_key = api_key or os .environ .get ("OPENAI_API_KEY" )
107- url = url or os .environ .get ("OPENAI_API_BASE_URL" )
108- self .client = OpenAI (api_key = api_key , base_url = url )
117+ if model not in self .GROK_MODELS + self .OPENAI_MODELS :
118+ available_models = sorted (self .OPENAI_MODELS + self .GROK_MODELS )
119+ raise ValueError (
120+ f"Unsupported model: { model } . "
121+ f"Supported models are: { available_models } "
122+ )
123+
124+ # Set default url for Grok models
125+ url = "https://api.x.ai/v1" if model in self .GROK_MODELS else url
126+
127+ api_key , base_url = (
128+ self .get_openai_credentials (url , api_key )
129+ if model in self .OPENAI_MODELS
130+ else self .get_grok_credentials (url , api_key )
131+ )
132+
133+ self .client = OpenAI (api_key = api_key , base_url = base_url )
109134 self .model = model
110135 self .size = size
111136 self .quality = quality
@@ -139,7 +164,7 @@ def base64_to_image(self, base64_string: str) -> Optional[Image.Image]:
139164 return None
140165
141166 def _build_base_params (self , prompt : str , n : Optional [int ] = None ) -> dict :
142- r"""Build base parameters dict for OpenAI API calls.
167+ r"""Build base parameters dict for Image Model API calls.
143168
144169 Args:
145170 prompt (str): The text prompt for the image operation.
@@ -153,6 +178,10 @@ def _build_base_params(self, prompt: str, n: Optional[int] = None) -> dict:
153178 # basic parameters supported by all models
154179 if n is not None :
155180 params ["n" ] = n # type: ignore[assignment]
181+
182+ if self .model in self .GROK_MODELS :
183+ return params
184+
156185 if self .size is not None :
157186 params ["size" ] = self .size
158187
@@ -179,16 +208,18 @@ def _build_base_params(self, prompt: str, n: Optional[int] = None) -> dict:
179208 params ["quality" ] = self .quality
180209 if self .background is not None :
181210 params ["background" ] = self .background
182-
183211 return params
184212
185213 def _handle_api_response (
186- self , response , image_name : Union [str , List [str ]], operation : str
214+ self ,
215+ response ,
216+ image_name : Union [str , List [str ]],
217+ operation : str ,
187218 ) -> str :
188- r"""Handle API response from OpenAI image operations.
219+ r"""Handle API response from image operations.
189220
190221 Args:
191- response: The response object from OpenAI API.
222+ response: The response object from image model API.
192223 image_name (Union[str, List[str]]): Name(s) for the saved image
193224 file(s). If str, the same name is used for all images (will
194225 cause error for multiple images). If list, must have exactly
@@ -198,8 +229,9 @@ def _handle_api_response(
198229 Returns:
199230 str: Success message with image path/URL or error message.
200231 """
232+ source = "Grok" if self .model in self .GROK_MODELS else "OpenAI"
201233 if response .data is None or len (response .data ) == 0 :
202- error_msg = "No image data returned from OpenAI API."
234+ error_msg = f "No image data returned from { source } API."
203235 logger .error (error_msg )
204236 return error_msg
205237
@@ -283,7 +315,7 @@ def generate_image(
283315 image_name : Union [str , List [str ]] = "image.png" ,
284316 n : int = 1 ,
285317 ) -> str :
286- r"""Generate an image using OpenAI's Image Generation models.
318+ r"""Generate an image using image models.
287319 The generated image will be saved locally (for ``b64_json`` response
288320 formats) or an image URL will be returned (for ``url`` response
289321 formats).
@@ -309,15 +341,50 @@ def generate_image(
309341 logger .error (error_msg )
310342 return error_msg
311343
@api_keys_required([("api_key", "XAI_API_KEY")])
def get_grok_credentials(
    self,
    url: Optional[str],
    api_key: Optional[str],
) -> Tuple[str, str]:
    r"""Resolve the API key and base URL used for Grok image models.

    Args:
        url (Optional[str]): The base URL for the Grok API. When it is
            empty or :obj:`None`, ``https://api.x.ai/v1`` is used.
        api_key (Optional[str]): The API key for the Grok API. When it
            is empty or :obj:`None`, the ``XAI_API_KEY`` environment
            variable is read instead.

    Returns:
        Tuple[str, str]: (api_key, base_url)
    """
    # An explicitly passed key wins; the environment is only a fallback.
    api_key = api_key or os.getenv("XAI_API_KEY")

    # Never hand back an empty base URL: a client built without one would
    # not be pointed at the Grok endpoint, so the Grok key would be sent
    # to whatever default endpoint that client uses.
    base_url = url or "https://api.x.ai/v1"
    return api_key, base_url  # type: ignore[return-value]
359+
@api_keys_required([("api_key", "OPENAI_API_KEY")])
def get_openai_credentials(
    self,
    url: Optional[str],
    api_key: Optional[str],
) -> Tuple[str, Optional[str]]:
    r"""Resolve the API key and base URL used for OpenAI image models.

    Args:
        url (Optional[str]): The base URL for the OpenAI API. When it is
            empty or :obj:`None`, the ``OPENAI_API_BASE_URL`` environment
            variable is read instead.
        api_key (Optional[str]): The API key for the OpenAI API. When it
            is empty or :obj:`None`, the ``OPENAI_API_KEY`` environment
            variable is read instead.

    Returns:
        Tuple[str, Optional[str]]: (api_key, base_url). The base URL is
            :obj:`None` when neither the argument nor the environment
            variable provides one.
    """
    # Explicit arguments take precedence over environment variables.
    api_key = api_key or os.getenv("OPENAI_API_KEY")
    base_url = url or os.getenv("OPENAI_API_BASE_URL")
    return api_key, base_url  # type: ignore[return-value]
375+
def get_tools(self) -> List[FunctionTool]:
    r"""Wrap the toolkit's functions as FunctionTool objects.

    Returns:
        List[FunctionTool]: One FunctionTool per function in the toolkit;
            at present this is only ``generate_image``.
    """
    exposed_functions = (self.generate_image,)
    return [FunctionTool(func) for func in exposed_functions]
387+
388+
389+ # Backward compatibility alias
390+ OpenAIImageToolkit = ImageGenToolkit
0 commit comments