@@ -9,6 +9,7 @@
 import openai
 import tiktoken
 from langfuse.model import InitialGeneration, Usage
+from openai import OpenAI
 from tenacity import *
 
 from pentestgpt.utils.llm_api import LLMAPI
@@ -46,6 +47,8 @@ def __eq__(self, other):
 class ChatGPTAPI(LLMAPI):
     def __init__(self, config_class, use_langfuse_logging=False):
         self.name = str(config_class.model)
+        api_key = os.getenv("OPENAI_API_KEY", None)
+        self.client = OpenAI(api_key=api_key, base_url=config_class.api_base)
 
         if use_langfuse_logging:
             # use langfuse.openai to shadow the default openai library
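For context, this hunk adopts the client-per-instance pattern that openai>=1.0 expects. A minimal standalone sketch, assuming only what the diff shows (the OPENAI_API_KEY variable and the api_base config field); the base URL shown is simply the library default:

import os
from openai import OpenAI

# openai>=1.0 replaces the module-level openai.api_key / openai.api_base
# globals with per-client state, so a client is built once and reused.
api_key = os.getenv("OPENAI_API_KEY", None)
client = OpenAI(api_key=api_key, base_url="https://api.openai.com/v1")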
@@ -58,9 +61,7 @@ def __init__(self, config_class, use_langfuse_logging=False):
             from langfuse import Langfuse
 
             self.langfuse = Langfuse()
-
-        openai.api_key = os.getenv("OPENAI_API_KEY", None)
-        openai.api_base = config_class.api_base
+
         self.model = config_class.model
         self.log_dir = config_class.log_dir
         self.history_length = 5  # maintain 5 messages in the history. (5 chat memory)
@@ -69,7 +70,9 @@ def __init__(self, config_class, use_langfuse_logging=False):
 
         logger.add(sink=os.path.join(self.log_dir, "chatgpt.log"), level="WARNING")
 
-    def _chat_completion(self, history: List, model=None, temperature=0.5) -> str:
+    def _chat_completion(
+        self, history: List, model=None, temperature=0.5, image_url: str = None
+    ) -> str:
         generationStartTime = datetime.now()
         # use model if provided, otherwise use self.model; if self.model is None, use gpt-4-1106-preview
         if model is None:
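The new image_url parameter is not consumed anywhere in the hunks shown, so how it feeds into the request is an assumption. If it is intended for vision-capable models, a user message would be built with the chat-completions content-parts format, roughly as in this hypothetical helper:

def build_user_message(text: str, image_url: str = None) -> dict:
    # Plain text message when no image is supplied.
    if image_url is None:
        return {"role": "user", "content": text}
    # Content-parts form used by vision-capable chat models.
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": text},
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    }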
@@ -78,12 +81,12 @@ def _chat_completion(self, history: List, model=None, temperature=0.5) -> str:
         else:
             model = self.model
         try:
-            response = openai.ChatCompletion.create(
+            response = self.client.chat.completions.create(
                 model=model,
                 messages=history,
                 temperature=temperature,
             )
-        except openai.error.APIConnectionError as e:  # give one more try
+        except openai._exceptions.APIConnectionError as e:  # give one more try
             logger.warning(
                 "API Connection Error. Waiting for {} seconds".format(
                     self.error_wait_time
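Note that openai._exceptions is a private module; the same classes are re-exported at the package top level, so the public spelling works identically. A minimal sketch of the retry-worthy failure modes under openai>=1.0 (model name and message are placeholders):

import openai
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

try:
    response = client.chat.completions.create(
        model="gpt-4-1106-preview",
        messages=[{"role": "user", "content": "ping"}],
        temperature=0.5,
    )
except openai.APIConnectionError:
    # Network-level failure; same class as openai._exceptions.APIConnectionError.
    pass
except openai.RateLimitError:
    # HTTP 429; back off and retry.
    pass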
@@ -96,7 +99,7 @@ def _chat_completion(self, history: List, model=None, temperature=0.5) -> str:
                 messages=history,
                 temperature=temperature,
             )
-        except openai.error.RateLimitError as e:  # give one more try
+        except openai._exceptions.RateLimitError as e:  # give one more try
             logger.warning("Rate limit reached. Waiting for 5 seconds")
             logger.error("Rate Limit Error: ", e)
             time.sleep(5)
@@ -105,7 +108,7 @@ def _chat_completion(self, history: List, model=None, temperature=0.5) -> str:
                 messages=history,
                 temperature=temperature,
             )
-        except openai.error.InvalidRequestError as e:  # token limit reached
+        except openai._exceptions.BadRequestError as e:  # token limit reached
             logger.warning("Token size limit reached. The recent message is compressed")
             logger.error("Token size error; will retry with compressed message ", e)
             # compress the message in two ways.
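In openai>=1.0 the legacy InvalidRequestError maps to BadRequestError (HTTP 400), not RateLimitError; reusing RateLimitError for this branch would make it dead code, since the earlier except clause already catches that class. A small sketch of the token-limit case (model and prompt are placeholders):

import openai
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
history = [{"role": "user", "content": "an over-long prompt ..."}]

try:
    client.chat.completions.create(model="gpt-4-1106-preview", messages=history)
except openai.BadRequestError as e:
    # HTTP 400, the v1 counterpart of the legacy InvalidRequestError;
    # raised e.g. when the prompt exceeds the model's context window.
    print("request rejected:", e)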
@@ -151,14 +154,14 @@ def _chat_completion(self, history: List, model=None, temperature=0.5) -> str:
                 model=self.model,
                 modelParameters={"temperature": str(temperature)},
                 prompt=history,
-                completion=response["choices"][0]["message"]["content"],
+                completion=response.choices[0].message.content,
                 usage=Usage(
-                    promptTokens=response["usage"]["prompt_tokens"],
-                    completionTokens=response["usage"]["completion_tokens"],
+                    promptTokens=response.usage.prompt_tokens,
+                    completionTokens=response.usage.completion_tokens,
                 ),
             )
         )
-        return response["choices"][0]["message"]["content"]
+        return response.choices[0].message.content
 
 
 if __name__ == "__main__":
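The dict-style subscripting disappears because openai>=1.0 returns typed (pydantic) response objects rather than plain dicts; a minimal sketch of the attribute access used in the hunk above (model name and prompt are placeholders):

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

response = client.chat.completions.create(
    model="gpt-4-1106-preview",
    messages=[{"role": "user", "content": "Say hi"}],
)

# Fields are attributes on a typed object, not dict keys.
text = response.choices[0].message.content
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
print(text, prompt_tokens, completion_tokens)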