Skip to content

Commit d996130

Browse files
authored
Allow using 3rd party AI services that are compatible with OpenAI API format in the openai-gpt agent (#331)
1 parent 1c4a2e8 commit d996130

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

shell/agents/AIShell.OpenAI.Agent/GPT.cs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ internal enum EndpointType
99
{
1010
AzureOpenAI,
1111
OpenAI,
12+
CompatibleThirdParty,
1213
}
1314

1415
public class GPT
@@ -56,9 +57,16 @@ public GPT(
5657
bool noDeployment = string.IsNullOrEmpty(Deployment);
5758
Type = noEndpoint && noDeployment
5859
? EndpointType.OpenAI
59-
: !noEndpoint && !noDeployment
60-
? EndpointType.AzureOpenAI
61-
: throw new InvalidOperationException($"Invalid setting: {(noEndpoint ? "Endpoint" : "Deployment")} key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys.");
60+
: !noEndpoint && noDeployment
61+
? EndpointType.CompatibleThirdParty
62+
: !noEndpoint && !noDeployment
63+
? EndpointType.AzureOpenAI
64+
: throw new InvalidOperationException($"Invalid setting: 'Deployment' key present but 'Endpoint' key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys.");
65+
66+
if (ModelInfo is null && Type is EndpointType.CompatibleThirdParty)
67+
{
68+
ModelInfo = ModelInfo.ThirdPartyModel;
69+
}
6270
}
6371

6472
/// <summary>
@@ -142,11 +150,18 @@ private void ShowEndpointInfo(IHost host)
142150
new(label: " Model", m => m.ModelName),
143151
},
144152

145-
EndpointType.OpenAI => new CustomElement<GPT>[]
146-
{
153+
EndpointType.OpenAI =>
154+
[
147155
new(label: " Type", m => m.Type.ToString()),
148156
new(label: " Model", m => m.ModelName),
149-
},
157+
],
158+
159+
EndpointType.CompatibleThirdParty =>
160+
[
161+
new(label: " Type", m => m.Type.ToString()),
162+
new(label: " Endpoint", m => m.Endpoint),
163+
new(label: " Model", m => m.ModelName),
164+
],
150165

151166
_ => throw new UnreachableException(),
152167
};

shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ internal class ModelInfo
1313
private static readonly Dictionary<string, ModelInfo> s_modelMap;
1414
private static readonly Dictionary<string, Task<Tokenizer>> s_encodingMap;
1515

16+
// A rough estimate to cover all third-party models.
17+
// - most popular models today support 32K+ context length;
18+
// - use the gpt-4o encoding as an estimate for token count.
19+
internal static readonly ModelInfo ThirdPartyModel = new(32_000, encoding: Gpt4oEncoding);
20+
1621
static ModelInfo()
1722
{
1823
// For reference, see https://platform.openai.com/docs/models and the "Counting tokens" section in

shell/agents/AIShell.OpenAI.Agent/Service.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,10 @@ private void RefreshOpenAIClient()
122122
return;
123123
}
124124

125+
EndpointType type = _gptToUse.Type;
125126
string userKey = Utils.ConvertFromSecureString(_gptToUse.Key);
126127

127-
if (_gptToUse.Type is EndpointType.AzureOpenAI)
128+
if (type is EndpointType.AzureOpenAI)
128129
{
129130
// Create a client that targets Azure OpenAI service or Azure API Management service.
130131
var clientOptions = new AzureOpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() };
@@ -152,6 +153,11 @@ private void RefreshOpenAIClient()
152153
{
153154
// Create a client that targets the non-Azure OpenAI service.
154155
var clientOptions = new OpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() };
156+
if (type is EndpointType.CompatibleThirdParty)
157+
{
158+
clientOptions.Endpoint = new(_gptToUse.Endpoint);
159+
}
160+
155161
var aiClient = new OpenAIClient(new ApiKeyCredential(userKey), clientOptions);
156162
_client = aiClient.GetChatClient(_gptToUse.ModelName);
157163
}

0 commit comments

Comments
 (0)