diff --git a/Runtime/DataTypes.cs b/Runtime/DataTypes.cs index 392188f..77771b4 100644 --- a/Runtime/DataTypes.cs +++ b/Runtime/DataTypes.cs @@ -1,5 +1,6 @@ using System.Collections.Generic; using Newtonsoft.Json; +using Newtonsoft.Json.Linq; namespace OpenAI { @@ -83,6 +84,13 @@ public class OpenAIModelResponse : OpenAIModel, IResponse #endregion #region Chat API Data Types + + public enum ResponseFormat + { + Text, + JsonObject + } + public sealed class CreateChatCompletionRequest { public string Model { get; set; } @@ -97,6 +105,40 @@ public sealed class CreateChatCompletionRequest public Dictionary LogitBias { get; set; } public string User { get; set; } public string SystemFingerprint { get; set; } + + [JsonConverter(typeof(ResponseFormatJsonConverter))] + public ResponseFormat? ResponseFormat { get; set; } + } + + public class ResponseFormatJsonConverter : JsonConverter + { + public override void WriteJson(JsonWriter writer, ResponseFormat value, JsonSerializer serializer) + { + if (value == ResponseFormat.JsonObject) + { + writer.WriteStartObject(); + writer.WritePropertyName("type"); + writer.WriteValue("json_object"); + writer.WriteEndObject(); + } else + { + writer.WriteNull(); + } + } + + public override ResponseFormat ReadJson(JsonReader reader, System.Type objectType, ResponseFormat existingValue, bool hasExistingValue, JsonSerializer serializer) + { + if (reader.TokenType == JsonToken.StartObject) + { + JObject obj = JObject.Load(reader); + if (obj.TryGetValue("type", out JToken typeToken) && typeToken.ToString() == "json_object") + { + return ResponseFormat.JsonObject; + } + } + + return ResponseFormat.Text; + } } public struct CreateChatCompletionResponse : IResponse diff --git a/Runtime/OpenAIApi.cs b/Runtime/OpenAIApi.cs index 5dd54f6..f5c9456 100644 --- a/Runtime/OpenAIApi.cs +++ b/Runtime/OpenAIApi.cs @@ -35,11 +35,18 @@ private Configuration Configuration } } + private const string OFFICIAL_BASE_PATH = "https://api.openai.com/v1"; + /// OpenAI API base path for requests. - private const string BASE_PATH = "https://api.openai.com/v1"; - - public OpenAIApi(string apiKey = null, string organization = null) + private readonly string _basePath; + + public OpenAIApi(string apiKey = null, string organization = null, string baseUrl = null) { + if (string.IsNullOrEmpty(baseUrl)) + _basePath = OFFICIAL_BASE_PATH; + else + _basePath = baseUrl; + if (apiKey != null) { configuration = new Configuration(apiKey, organization); @@ -202,7 +209,7 @@ private byte[] CreatePayload(T request) /// public async Task ListModels() { - var path = $"{BASE_PATH}/models"; + var path = $"{_basePath}/models"; return await DispatchRequest(path, UnityWebRequest.kHttpVerbGET); } @@ -213,7 +220,7 @@ public async Task ListModels() /// See public async Task RetrieveModel(string id) { - var path = $"{BASE_PATH}/models/{id}"; + var path = $"{_basePath}/models/{id}"; return await DispatchRequest(path, UnityWebRequest.kHttpVerbGET); } @@ -224,9 +231,8 @@ public async Task RetrieveModel(string id) /// See public async Task CreateChatCompletion(CreateChatCompletionRequest request) { - var path = $"{BASE_PATH}/chat/completions"; + var path = $"{_basePath}/chat/completions"; var payload = CreatePayload(request); - return await DispatchRequest(path, UnityWebRequest.kHttpVerbPOST, payload); } @@ -240,7 +246,7 @@ public async Task CreateChatCompletion(CreateChatC public void CreateChatCompletionAsync(CreateChatCompletionRequest request, Action> onResponse, Action onComplete, CancellationTokenSource token) { request.Stream = true; - var path = $"{BASE_PATH}/chat/completions"; + var path = $"{_basePath}/chat/completions"; var payload = CreatePayload(request); DispatchRequest(path, UnityWebRequest.kHttpVerbPOST, onResponse, onComplete, token, payload); @@ -253,7 +259,7 @@ public void CreateChatCompletionAsync(CreateChatCompletionRequest request, Actio /// See public async Task CreateImage(CreateImageRequest request) { - var path = $"{BASE_PATH}/images/generations"; + var path = $"{_basePath}/images/generations"; var payload = CreatePayload(request); return await DispatchRequest(path, UnityWebRequest.kHttpVerbPOST, payload); } @@ -265,7 +271,7 @@ public async Task CreateImage(CreateImageRequest request) /// See public async Task CreateImageEdit(CreateImageEditRequest request) { - var path = $"{BASE_PATH}/images/edits"; + var path = $"{_basePath}/images/edits"; var form = new List(); form.AddFile(request.Image, "image", "image/png"); @@ -285,7 +291,7 @@ public async Task CreateImageEdit(CreateImageEditRequest re /// See public async Task CreateImageVariation(CreateImageVariationRequest request) { - var path = $"{BASE_PATH}/images/variations"; + var path = $"{_basePath}/images/variations"; var form = new List(); form.AddFile(request.Image, "image", "image/png"); @@ -304,7 +310,7 @@ public async Task CreateImageVariation(CreateImageVariation /// See public async Task CreateEmbeddings(CreateEmbeddingsRequest request) { - var path = $"{BASE_PATH}/embeddings"; + var path = $"{_basePath}/embeddings"; var payload = CreatePayload(request); return await DispatchRequest(path, UnityWebRequest.kHttpVerbPOST, payload); } @@ -316,7 +322,7 @@ public async Task CreateEmbeddings(CreateEmbeddingsReq /// See public async Task CreateAudioTranscription(CreateAudioTranscriptionsRequest request) { - var path = $"{BASE_PATH}/audio/transcriptions"; + var path = $"{_basePath}/audio/transcriptions"; var form = new List(); if (string.IsNullOrEmpty(request.File)) @@ -343,7 +349,7 @@ public async Task CreateAudioTranscription(CreateAudioTrans /// See public async Task CreateAudioTranslation(CreateAudioTranslationRequest request) { - var path = $"{BASE_PATH}/audio/translations"; + var path = $"{_basePath}/audio/translations"; var form = new List(); if (string.IsNullOrEmpty(request.File)) @@ -368,7 +374,7 @@ public async Task CreateAudioTranslation(CreateAudioTransla /// See public async Task ListFiles() { - var path = $"{BASE_PATH}/files"; + var path = $"{_basePath}/files"; return await DispatchRequest(path, UnityWebRequest.kHttpVerbGET); } @@ -381,7 +387,7 @@ public async Task ListFiles() /// See public async Task CreateFile(CreateFileRequest request) { - var path = $"{BASE_PATH}/files"; + var path = $"{_basePath}/files"; var form = new List(); form.AddFile(request.File, "file", "application/json"); @@ -397,7 +403,7 @@ public async Task CreateFile(CreateFileRequest request) /// See public async Task DeleteFile(string id) { - var path = $"{BASE_PATH}/files/{id}"; + var path = $"{_basePath}/files/{id}"; return await DispatchRequest(path, UnityWebRequest.kHttpVerbDELETE); } @@ -408,7 +414,7 @@ public async Task DeleteFile(string id) /// See public async Task RetrieveFile(string id) { - var path = $"{BASE_PATH}/files/{id}"; + var path = $"{_basePath}/files/{id}"; return await DispatchRequest(path, UnityWebRequest.kHttpVerbGET); } @@ -419,7 +425,7 @@ public async Task RetrieveFile(string id) /// See public async Task DownloadFile(string id) { - var path = $"{BASE_PATH}/files/{id}/content"; + var path = $"{_basePath}/files/{id}/content"; return await DispatchRequest(path, UnityWebRequest.kHttpVerbGET); } @@ -431,7 +437,7 @@ public async Task DownloadFile(string id) /// See public async Task CreateFineTune(CreateFineTuneRequest request) { - var path = $"{BASE_PATH}/fine-tunes"; + var path = $"{_basePath}/fine-tunes"; var payload = CreatePayload(request); return await DispatchRequest(path, UnityWebRequest.kHttpVerbPOST, payload); } @@ -442,7 +448,7 @@ public async Task CreateFineTune(CreateFineTuneRequest request) /// See public async Task ListFineTunes() { - var path = $"{BASE_PATH}/fine-tunes"; + var path = $"{_basePath}/fine-tunes"; return await DispatchRequest(path, UnityWebRequest.kHttpVerbGET); } @@ -453,7 +459,7 @@ public async Task ListFineTunes() /// See public async Task RetrieveFineTune(string id) { - var path = $"{BASE_PATH}/fine-tunes/{id}"; + var path = $"{_basePath}/fine-tunes/{id}"; return await DispatchRequest(path, UnityWebRequest.kHttpVerbGET); } @@ -464,7 +470,7 @@ public async Task RetrieveFineTune(string id) /// See public async Task CancelFineTune(string id) { - var path = $"{BASE_PATH}/fine-tunes/{id}/cancel"; + var path = $"{_basePath}/fine-tunes/{id}/cancel"; return await DispatchRequest(path, UnityWebRequest.kHttpVerbPOST); } @@ -479,7 +485,7 @@ public async Task CancelFineTune(string id) /// See public async Task ListFineTuneEvents(string id, bool stream = false) { - var path = $"{BASE_PATH}/fine-tunes/{id}/events?stream={stream}"; + var path = $"{_basePath}/fine-tunes/{id}/events?stream={stream}"; return await DispatchRequest(path, UnityWebRequest.kHttpVerbGET); } @@ -490,7 +496,7 @@ public async Task ListFineTuneEvents(string id, bool /// See public async Task DeleteFineTunedModel(string model) { - var path = $"{BASE_PATH}/models/{model}"; + var path = $"{_basePath}/models/{model}"; return await DispatchRequest(path, UnityWebRequest.kHttpVerbDELETE); } @@ -501,7 +507,7 @@ public async Task DeleteFineTunedModel(string model) /// See public async Task CreateModeration(CreateModerationRequest request) { - var path = $"{BASE_PATH}/moderations"; + var path = $"{_basePath}/moderations"; var payload = CreatePayload(request); return await DispatchRequest(path, UnityWebRequest.kHttpVerbPOST, payload); }