+from collections.abc import AsyncIterator
 from dataclasses import dataclass
 from os import getenv
-from typing import Optional
+from typing import Iterator, List, Optional
 
+from agno.exceptions import ModelProviderError
+from agno.models.message import Message
 from agno.models.openai.like import OpenAILike
+from agno.utils.log import log_error
 
+try:
+    from openai import APIConnectionError, APIStatusError, RateLimitError
+    from openai.types.chat.chat_completion_chunk import (
+        ChatCompletionChunk,
+    )
+except (ImportError, ModuleNotFoundError):
+    raise ImportError("`openai` not installed. Please install using `pip install openai`")
 
 @dataclass
 class xAI(OpenAILike):
@@ -24,3 +35,118 @@ class xAI(OpenAILike):
 
     api_key: Optional[str] = getenv("XAI_API_KEY")
     base_url: Optional[str] = "https://api.x.ai/v1"
+
+
+    def invoke_stream(self, messages: List[Message]) -> Iterator[ChatCompletionChunk]:
+        """
+        Send a streaming chat completion request to the xAI API.
+
+        Args:
+            messages (List[Message]): A list of messages to send to the model.
+
+        Returns:
+            Iterator[ChatCompletionChunk]: An iterator of chat completion chunks.
+        """
+
+        try:
+            yield from self.get_client().chat.completions.create(
+                model=self.id,
+                messages=[self._format_message(m) for m in messages],  # type: ignore
+                stream=True,
+                **self.request_kwargs,
+            )  # type: ignore
+        except RateLimitError as e:
+            log_error(f"Rate limit error from xAI API: {e}")
+            error_message = e.response.json().get("error", {})
+            error_message = (
+                error_message.get("message", "Unknown model error")
+                if isinstance(error_message, dict)
+                else error_message
+            )
+            raise ModelProviderError(
+                message=error_message,
+                status_code=e.response.status_code,
+                model_name=self.name,
+                model_id=self.id,
+            ) from e
+        except APIConnectionError as e:
+            log_error(f"API connection error from xAI API: {e}")
+            raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
+        except APIStatusError as e:
+            log_error(f"API status error from xAI API: {e}")
+            try:
+                error_message = e.response.json().get("error", {})
+            except Exception:
+                error_message = e.response.text
+            error_message = (
+                error_message.get("message", "Unknown model error")
+                if isinstance(error_message, dict)
+                else error_message
+            )
+            raise ModelProviderError(
+                message=error_message,
+                status_code=e.response.status_code,
+                model_name=self.name,
+                model_id=self.id,
+            ) from e
+        except Exception as e:
+            log_error(f"Error from xAI API: {e}")
+            raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
+
+    async def ainvoke_stream(self, messages: List[Message]) -> AsyncIterator[ChatCompletionChunk]:
+        """
+        Send an asynchronous streaming chat completion request to the xAI API.
+
+        Args:
+            messages (List[Message]): A list of messages to send to the model.
+
+        Returns:
+            AsyncIterator[ChatCompletionChunk]: An asynchronous iterator of chat completion chunks.
+        """
+
+        try:
+            async_stream = await self.get_async_client().chat.completions.create(
+                model=self.id,
+                messages=[self._format_message(m) for m in messages],  # type: ignore
+                stream=True,
+                **self.request_kwargs,
+            )
+            async for chunk in async_stream:
+                yield chunk
+        except RateLimitError as e:
+            log_error(f"Rate limit error from xAI API: {e}")
+            error_message = e.response.json().get("error", {})
+            error_message = (
+                error_message.get("message", "Unknown model error")
+                if isinstance(error_message, dict)
+                else error_message
+            )
+            raise ModelProviderError(
+                message=error_message,
+                status_code=e.response.status_code,
+                model_name=self.name,
+                model_id=self.id,
+            ) from e
+        except APIConnectionError as e:
+            log_error(f"API connection error from xAI API: {e}")
+            raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
+        except APIStatusError as e:
+            log_error(f"API status error from xAI API: {e}")
+            try:
+                error_message = e.response.json().get("error", {})
+            except Exception:
+                error_message = e.response.text
+            error_message = (
+                error_message.get("message", "Unknown model error")
+                if isinstance(error_message, dict)
+                else error_message
+            )
+            raise ModelProviderError(
+                message=error_message,
+                status_code=e.response.status_code,
+                model_name=self.name,
+                model_id=self.id,
+            ) from e
+        except Exception as e:
+            log_error(f"Error from xAI API: {e}")
+            raise ModelProviderError(message=str(e), model_name=self.name, model_id=self.id) from e
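
A minimal usage sketch of the new streaming methods, assuming the class is importable as `agno.models.xai.xAI`, that `XAI_API_KEY` is exported in the environment, and that the model id and `Message(role=..., content=...)` constructor shown here are illustrative rather than taken from this commit:

```python
from agno.models.message import Message
from agno.models.xai import xAI  # import path assumed

# Model id is illustrative; the API key is read from XAI_API_KEY by default.
model = xAI(id="grok-beta")

prompt = [Message(role="user", content="Say hello in one sentence.")]

# invoke_stream yields OpenAI-style ChatCompletionChunk objects.
for chunk in model.invoke_stream(prompt):
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```

With this change, provider failures during streaming surface as a single `ModelProviderError` (with the provider's error message and status code where available) instead of the raw `openai` exceptions, so callers only need to catch one exception type.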