4
4
import aiohttp
5
5
import websockets
6
6
from pydantic import BaseModel , Field
7
+ import logging
8
+
9
+ logger = logging .getLogger (__name__ )
7
10
8
11
class LocalLabConfig (BaseModel ):
9
12
base_url : str
@@ -29,22 +32,23 @@ class Usage(BaseModel):
29
32
total_tokens : int
30
33
31
34
class GenerateResponse(BaseModel):
    """Response payload returned by the server's /generate endpoint.

    Field names mirror the server's JSON keys.
    """

    # Generated text (server's 'text' key).
    text: str
    # Identifier of the model that produced the text (server's 'model' key).
    model: str
    # Token accounting; None when the server does not include it.
    usage: Optional[Usage] = None
35
39
36
40
class ChatChoice(BaseModel):
    """One completion choice inside a chat response."""

    # The message produced for this choice.
    message: ChatMessage
    # Reason generation ended; None when the server omits it.
    finish_reason: Optional[str] = None
39
43
40
44
class ChatResponse(BaseModel):
    """Response payload for a chat completion request."""

    # All choices returned by the server.
    choices: List[ChatChoice]
    # Token accounting; None when the server does not include it.
    usage: Optional[Usage] = None
43
47
44
48
class BatchResponse(BaseModel):
    """Response payload for a batch generation request."""

    # One generated string per prompt in the batch.
    responses: List[str]
    # Identifier of the model used (server's 'model' key).
    model: str
    # Token accounting; None when the server does not include it.
    usage: Optional[Usage] = None
48
52
49
53
class ModelInfo (BaseModel ):
50
54
name : str
@@ -89,8 +93,12 @@ def __init__(self, message: str, retry_after: int):
89
93
self .retry_after = retry_after
90
94
91
95
class LocalLabClient :
92
- def __init__ (self , config : Union [LocalLabConfig , Dict [str , Any ]]):
93
- if isinstance (config , dict ):
96
+ def __init__ (self , config : Union [str , LocalLabConfig , Dict [str , Any ]]):
97
+ """Initialize the client with either a URL string or config object"""
98
+ if isinstance (config , str ):
99
+ # If just a URL string is provided, create a config object
100
+ config = LocalLabConfig (base_url = config )
101
+ elif isinstance (config , dict ):
94
102
config = LocalLabConfig (** config )
95
103
self .config = config
96
104
self .session : Optional [aiohttp .ClientSession ] = None
@@ -157,32 +165,137 @@ async def _request(self, method: str, path: str, **kwargs) -> Any:
157
165
raise LocalLabError (str (e ), "CONNECTION_ERROR" )
158
166
await asyncio .sleep (2 ** attempt )
159
167
160
- async def generate (self , prompt : str , options : Optional [Union [GenerateOptions , Dict ]] = None ) -> GenerateResponse :
161
- """Generate text from prompt"""
162
- if isinstance (options , dict ):
163
- options = GenerateOptions (** options )
164
- data = {"prompt" : prompt , ** (options .model_dump () if options else {})}
165
- response = await self ._request ("POST" , "/generate" , json = data )
166
- return GenerateResponse (** response )
167
-
168
168
async def stream_generate(self, prompt: str, options: Optional[Union[GenerateOptions, Dict]] = None) -> AsyncGenerator[str, None]:
    """Stream generated text from the /generate endpoint.

    Posts the prompt with ``stream: True`` and yields cleaned text chunks as
    they arrive.  Errors are reported to the consumer as yielded
    ``"\\nError: ..."`` strings instead of being raised, so a display loop
    can surface them inline.

    Args:
        prompt: The prompt to generate from.
        options: Generation options as a GenerateOptions instance or a plain
            dict; None means server defaults.

    Yields:
        Text fragments with SSE framing and model special tokens removed.
    """
    if isinstance(options, dict):
        options = GenerateOptions(**options)
    if options is None:
        options = GenerateOptions()

    # Build the request payload; drop unset keys so the server applies its
    # own defaults instead of receiving explicit nulls.
    payload = {
        "prompt": prompt,
        "stream": True,
        "max_tokens": options.max_length,
        "temperature": options.temperature,
        "top_p": options.top_p,
        "model": options.model_id,
    }
    payload = {k: v for k, v in payload.items() if v is not None}

    async with self.session.post("/generate", json=payload) as response:
        if response.status != 200:
            # FIX: was a bare `except:` — narrowed to Exception so
            # KeyboardInterrupt/SystemExit/CancelledError still propagate.
            try:
                error_data = await response.json()
                error_msg = error_data.get("detail", "Streaming failed")
                logger.error(f"Streaming error: {error_msg}")
                yield f"\nError: {error_msg}"
            except Exception:
                yield "\nError: Streaming failed"
            return

        buffer = ""                   # everything yielded so far
        current_sentence = ""         # text accumulated since the last sentence end
        last_token_was_space = False  # did the previous chunk end in whitespace?

        async for raw_line in response.content:
            if not raw_line:
                continue
            try:
                line = raw_line.decode("utf-8").strip()
                if not line:
                    continue

                # Unwrap Server-Sent-Events framing.
                if line.startswith("data: "):
                    line = line[6:]

                # Skip stream control markers.
                if line in ("[DONE]", "[ERROR]"):
                    continue

                # FIX: renamed from `data` — the original shadowed the
                # request payload dict with the decoded chunk.
                try:
                    chunk = json.loads(line)
                    # Prefer the current 'text' key; fall back to legacy 'response'.
                    text = chunk.get("text", chunk.get("response", ""))
                except json.JSONDecodeError:
                    # Not JSON: treat the raw line as plain text.
                    text = line

                if text:
                    # Strip special/control tokens the model may emit.
                    # NOTE(review): this also deletes literal <, >, [, ], {
                    # and } from generated text — confirm callers accept that.
                    for marker in ("<|", "|>", "<", ">", "[", "]",
                                   "{", "}", "data:", "\ufffd\ufffd"):
                        text = text.replace(marker, "")
                    text = text.replace("\\n", "\n")
                    text = text.replace("|user|", "")
                    text = text.replace("|The", "The")
                    text = text.replace("/|assistant|", "").replace("/|user|", "")

                    # Insert a space between chunks when neither side of the
                    # join already carries whitespace.
                    if (buffer
                            and not text.startswith((" ", "\n"))
                            and not last_token_was_space
                            and not buffer.endswith((" ", "\n"))):
                        text = " " + text

                    buffer += text
                    current_sentence += text
                    last_token_was_space = text.endswith((" ", "\n"))

                    # Reset the sentence tracker at sentence boundaries.
                    if current_sentence.endswith((".", "!", "?", "\n")):
                        current_sentence = ""

                    yield text

            except Exception as e:
                logger.error(f"Error processing stream chunk: {str(e)}")
                yield f"\nError: {str(e)}"
                return
264
+
265
async def generate(self, prompt: str, options: Optional[Union[GenerateOptions, Dict]] = None) -> GenerateResponse:
    """Generate text for *prompt* with a single non-streaming request.

    Args:
        prompt: The prompt to generate from.
        options: Generation options as a GenerateOptions instance or a plain
            dict; None means server defaults.

    Returns:
        A GenerateResponse with model special tokens stripped from the text.
    """
    if isinstance(options, dict):
        options = GenerateOptions(**options)
    elif options is None:
        options = GenerateOptions()

    # Assemble the payload in the same shape stream_generate uses; unset
    # values are omitted so the server falls back to its own defaults.
    candidates = {
        "prompt": prompt,
        "max_tokens": options.max_length,
        "temperature": options.temperature,
        "top_p": options.top_p,
        "model": options.model_id,
        "stream": False,
    }
    payload = {}
    for key, value in candidates.items():
        if value is not None:
            payload[key] = value

    reply = await self._request("POST", "/generate", json=payload)

    # Accept both the current 'text' key and the legacy 'response' key.
    text = reply.get("text", reply.get("response", ""))
    if isinstance(text, str):
        # Drop special-token markers before handing text back to callers.
        for marker in ("<|", "|>", "<", ">", "[", "]", "{", "}"):
            text = text.replace(marker, "")
        text = text.strip()

    model_name = reply.get("model", reply.get("model_id", ""))
    return GenerateResponse(text=text, model=model_name, usage=reply.get("usage"))
186
299
187
300
async def chat (self , messages : List [Union [ChatMessage , Dict ]], options : Optional [Union [GenerateOptions , Dict ]] = None ) -> ChatResponse :
188
301
"""Chat completion"""
@@ -281,4 +394,4 @@ async def on_message(self, callback: callable) -> None:
281
394
data = json .loads (message )
282
395
await callback (data )
283
396
except json .JSONDecodeError :
284
- await callback (message )
397
+ await callback (message )
0 commit comments