- from typing import Any, Iterator, List, Optional
+ import functools
+ import os
+ from typing import Any, Dict, Iterator, List, Optional

- import openai
  from loguru import logger
+ from openai import OpenAI, Stream
+ from openai.types import Completion
  from transformers import AutoTokenizer

- from guidellm.backend import Backend, BackendTypes, GenerativeResponse
- from guidellm.core.request import TextGenerationRequest
+ from guidellm.backend import Backend, BackendEngine, GenerativeResponse
+ from guidellm.core import TextGenerationRequest


  __all__ = ["OpenAIBackend"]


- @Backend.register_backend(BackendTypes.OPENAI_SERVER)
+ @Backend.register(BackendEngine.OPENAI_SERVER)
  class OpenAIBackend(Backend):
      """
      An OpenAI backend implementation for the generative AI result.
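
(Editor's note: the guidellm registry itself is not shown in this diff. The sketch below illustrates the kind of decorator-based registration that `@Backend.register(BackendEngine.OPENAI_SERVER)` implies; every name here is an illustrative stand-in, not guidellm's actual internals.)

from enum import Enum
from typing import Callable, Dict, Type


class Engine(Enum):  # stand-in for guidellm's BackendEngine
    OPENAI_SERVER = "openai_server"


class BaseBackend:  # stand-in for guidellm's Backend
    _registry: Dict[Engine, Type["BaseBackend"]] = {}

    @classmethod
    def register(cls, engine: Engine) -> Callable:
        def decorator(backend_cls: Type["BaseBackend"]) -> Type["BaseBackend"]:
            cls._registry[engine] = backend_cls  # map engine -> implementation
            return backend_cls  # the class itself is returned unmodified

        return decorator


@BaseBackend.register(Engine.OPENAI_SERVER)
class StubBackend(BaseBackend):
    pass


assert BaseBackend._registry[Engine.OPENAI_SERVER] is StubBackend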
@@ -33,34 +36,37 @@ class OpenAIBackend(Backend):
      def __init__(
          self,
-         target: Optional[str] = None,
-         host: Optional[str] = None,
-         port: Optional[int] = None,
-         path: Optional[str] = None,
+         openai_api_key: Optional[str] = None,
+         internal_callback_url: Optional[str] = None,
          model: Optional[str] = None,
-         api_key: Optional[str] = None,
-         **request_args,
+         **request_args: Any,
      ):
-         self.target = target
-         self.model = model
-         self.request_args = request_args
-
-         if not self.target:
-             if not host:
-                 raise ValueError("Host is required if target is not provided.")
-
-             port_incl = f":{port}" if port else ""
-             path_incl = path if path else ""
-             self.target = f"http://{host}{port_incl}{path_incl}"
+         """
+         Initialize an OpenAI Client
+         """

-         openai.api_base = self.target
-         openai.api_key = api_key
+         self.request_args = request_args

-         if not model:
-             self.model = self.default_model()
+         if not (_api_key := (openai_api_key or os.getenv("OPENAI_API_KEY", None))):
+             raise ValueError(
+                 "`OPENAI_API_KEY` environment variable "
+                 "or --openai-api-key CLI parameter "
+                 "must be specified for the OpenAI backend"
+             )
+
+         if not (
+             _base_url := (internal_callback_url or os.getenv("OPENAI_BASE_URL", None))
+         ):
+             raise ValueError(
+                 "`OPENAI_BASE_URL` environment variable "
+                 "or --openai-base-url CLI parameter "
+                 "must be specified for the OpenAI backend"
+             )
+         self.openai_client = OpenAI(api_key=_api_key, base_url=_base_url)
+         self.model = model or self.default_model

          logger.info(
-             f"Initialized OpenAIBackend with target: {self.target} "
+             f"Initialized OpenAIBackend with callback url: {internal_callback_url} "
              f"and model: {self.model}"
          )
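
(Editor's note: the precedence rule the new constructor applies, explicit argument first, then environment variable, otherwise a hard error, is worth stating once. The helper below is a hypothetical distillation for illustration only; `resolve_setting` is not part of this commit.)

import os
from typing import Optional


def resolve_setting(cli_value: Optional[str], env_var: str, flag: str) -> str:
    """Prefer the explicit CLI value, fall back to the environment, else fail."""
    if value := (cli_value or os.getenv(env_var)):
        return value
    raise ValueError(
        f"`{env_var}` environment variable or {flag} CLI parameter "
        "must be specified for the OpenAI backend"
    )


# Mirrors the two lookups performed in __init__ above:
api_key = resolve_setting("sk-example", "OPENAI_API_KEY", "--openai-api-key")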
@@ -75,52 +81,46 @@ def make_request(
          :return: An iterator over the generative responses.
          :rtype: Iterator[GenerativeResponse]
          """
+
          logger.debug(f"Making request to OpenAI backend with prompt: {request.prompt}")
-         num_gen_tokens = request.params.get("generated_tokens", None)
-         request_args = {
-             "n": 1,
-         }

-         if num_gen_tokens:
-             request_args["max_tokens"] = num_gen_tokens
-             request_args["stop"] = None
+         # How many completions to generate for each prompt
+         request_args: Dict = {"n": 1}
+
+         if (num_gen_tokens := request.params.get("generated_tokens", None)) is not None:
+             request_args.update(max_tokens=num_gen_tokens, stop=None)

          if self.request_args:
              request_args.update(self.request_args)

-         response = openai.Completion.create(
-             engine=self.model,
+         response: Stream[Completion] = self.openai_client.completions.create(
+             model=self.model,
              prompt=request.prompt,
              stream=True,
              **request_args,
          )

          for chunk in response:
-             if chunk.get("choices"):
-                 choice = chunk["choices"][0]
-                 if choice.get("finish_reason") == "stop":
-                     logger.debug("Received final response from OpenAI backend")
-                     yield GenerativeResponse(
-                         type_="final",
-                         output=choice["text"],
-                         prompt=request.prompt,
-                         prompt_token_count=(
-                             request.token_count
-                             if request.token_count
-                             else self._token_count(request.prompt)
-                         ),
-                         output_token_count=(
-                             num_gen_tokens
-                             if num_gen_tokens
-                             else self._token_count(choice["text"])
-                         ),
-                     )
-                     break
-                 else:
-                     logger.debug("Received token from OpenAI backend")
-                     yield GenerativeResponse(
-                         type_="token_iter", add_token=choice["text"]
-                     )
+             chunk_content: str = getattr(chunk, "content", "")
+
+             if getattr(chunk, "stop", True) is True:
+                 logger.debug("Received final response from OpenAI backend")
+
+                 yield GenerativeResponse(
+                     type_="final",
+                     prompt=getattr(chunk, "prompt", request.prompt),
+                     prompt_token_count=(
+                         request.prompt_token_count or self._token_count(request.prompt)
+                     ),
+                     output_token_count=(
+                         num_gen_tokens
+                         if num_gen_tokens
+                         else self._token_count(chunk_content)
+                     ),
+                 )
+             else:
+                 logger.debug("Received token from OpenAI backend")
+                 yield GenerativeResponse(type_="token_iter", add_token=chunk_content)

      def available_models(self) -> List[str]:
          """
@@ -129,21 +129,28 @@ def available_models(self) -> List[str]:
          :return: A list of available models.
          :rtype: List[str]
          """
-         models = [model["id"] for model in openai.Engine.list()["data"]]
+
+         models: list[str] = [
+             model.id for model in self.openai_client.models.list().data
+         ]
          logger.info(f"Available models: {models}")
+
          return models

+     @property
+     @functools.lru_cache(maxsize=1)
      def default_model(self) -> str:
          """
          Get the default model for the backend.

          :return: The default model.
          :rtype: str
          """
-         models = self.available_models()
-         if models:
+
+         if models := self.available_models():
              logger.info(f"Default model: {models[0]}")
              return models[0]
+
          logger.error("No models available.")
          raise ValueError("No models available.")
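
(Editor's note: one caveat on the new `default_model`. Stacking `@property` on `@functools.lru_cache` caches with `self` as the key, so the function-level cache holds a strong reference to each instance, and with `maxsize=1` a second backend instance evicts the first. Since Python 3.8, `functools.cached_property` expresses the same compute-once intent per instance. A minimal sketch, with a hypothetical stand-in class:)

import functools


class BackendSketch:  # hypothetical stand-in for OpenAIBackend
    def available_models(self) -> list:
        return ["model-a", "model-b"]  # placeholder data

    @functools.cached_property
    def default_model(self) -> str:
        # Computed on first access, then stored on the instance itself,
        # so no module-level cache keeps the instance alive.
        if models := self.available_models():
            return models[0]
        raise ValueError("No models available.")


assert BackendSketch().default_model == "model-a"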