diff --git a/OpenAI_API/Embedding/EmbeddingEndpoint.cs b/OpenAI_API/Embedding/EmbeddingEndpoint.cs index a324de9..fe77c89 100644 --- a/OpenAI_API/Embedding/EmbeddingEndpoint.cs +++ b/OpenAI_API/Embedding/EmbeddingEndpoint.cs @@ -28,10 +28,11 @@ internal EmbeddingEndpoint(OpenAIAPI api) : base(api) { } /// Ask the API to embedd text using the default embedding model /// /// Text to be embedded + /// Embeddings model to be used /// Asynchronously returns the embedding result. Look in its property of to find the vector of floating point numbers - public async Task CreateEmbeddingAsync(string input) + public async Task CreateEmbeddingAsync(string input, Model model = null) { - EmbeddingRequest req = new EmbeddingRequest(DefaultEmbeddingRequestArgs.Model, input); + EmbeddingRequest req = new EmbeddingRequest(model ?? DefaultEmbeddingRequestArgs.Model, input); return await CreateEmbeddingAsync(req); } @@ -46,14 +47,25 @@ public async Task CreateEmbeddingAsync(EmbeddingRequest request } /// - /// Ask the API to embedd text using the default embedding model + /// Ask the API to embedd text using the default embedding model in case no other model is specified /// /// Text to be embedded + /// Embeddings model to be used + /// Asynchronously returns the first embedding result as an array of floats. + public async Task GetEmbeddingsAsync(string input, Model model = null) + { + EmbeddingRequest req = new EmbeddingRequest(model ?? DefaultEmbeddingRequestArgs.Model, input); + return await GetEmbeddingsAsync(req); + } + + /// + /// Ask the API to embedd text + /// + /// Request to be send /// Asynchronously returns the first embedding result as an array of floats. - public async Task GetEmbeddingsAsync(string input) + public async Task GetEmbeddingsAsync(EmbeddingRequest request) { - EmbeddingRequest req = new EmbeddingRequest(DefaultEmbeddingRequestArgs.Model, input); - var embeddingResult = await CreateEmbeddingAsync(req); + var embeddingResult = await CreateEmbeddingAsync(request); return embeddingResult?.Data?[0]?.Embedding; } } diff --git a/OpenAI_API/Embedding/EmbeddingRequest.cs b/OpenAI_API/Embedding/EmbeddingRequest.cs index 99780eb..cddbe0e 100644 --- a/OpenAI_API/Embedding/EmbeddingRequest.cs +++ b/OpenAI_API/Embedding/EmbeddingRequest.cs @@ -20,6 +20,18 @@ public class EmbeddingRequest [JsonProperty("input")] public string Input { get; set; } + /// + /// The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. + /// + [JsonProperty("dimensions")] + public int? Dimensions { get; set; } + + /// + /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + /// + [JsonProperty("user")] + public string User { get; set; } + /// /// Cretes a new, empty /// diff --git a/OpenAI_API/Embedding/IEmbeddingEndpoint.cs b/OpenAI_API/Embedding/IEmbeddingEndpoint.cs index acd9069..a45cbbf 100644 --- a/OpenAI_API/Embedding/IEmbeddingEndpoint.cs +++ b/OpenAI_API/Embedding/IEmbeddingEndpoint.cs @@ -17,8 +17,9 @@ public interface IEmbeddingEndpoint /// Ask the API to embedd text using the default embedding model /// /// Text to be embedded + /// Embeddings model to be used /// Asynchronously returns the embedding result. Look in its property of to find the vector of floating point numbers - Task CreateEmbeddingAsync(string input); + Task CreateEmbeddingAsync(string input, Model model = null); /// /// Ask the API to embedd text using a custom request @@ -28,10 +29,18 @@ public interface IEmbeddingEndpoint Task CreateEmbeddingAsync(EmbeddingRequest request); /// - /// Ask the API to embedd text using the default embedding model + /// Ask the API to embedd text using the default embedding model in case no other model is specified /// /// Text to be embedded + /// Embeddings model to be used + /// Asynchronously returns the first embedding result as an array of floats. + Task GetEmbeddingsAsync(string input, Model model = null); + + /// + /// Ask the API to embedd text using the default embedding model in case no other model is specified + /// + /// Request to be send /// Asynchronously returns the first embedding result as an array of floats. - Task GetEmbeddingsAsync(string input); + Task GetEmbeddingsAsync(EmbeddingRequest request); } } \ No newline at end of file diff --git a/OpenAI_API/Model/Model.cs b/OpenAI_API/Model/Model.cs index 49c0420..0b1fb0e 100644 --- a/OpenAI_API/Model/Model.cs +++ b/OpenAI_API/Model/Model.cs @@ -218,9 +218,19 @@ public async Task RetrieveModelDetailsAsync(OpenAI_API.OpenAIAPI api) #region Embeddings /// - /// OpenAI offers one second-generation embedding model for use with the embeddings API endpoint. + /// This model is not deprecated yet, but OpenAI recommends to use the newer models text-embedding-3-small and text-embedding-3-large. /// public static Model AdaTextEmbedding => new Model("text-embedding-ada-002") { OwnedBy = "openai" }; + + /// + /// Highly efficient embedding model which provides a significant upgrade over its predecessor, the text-embedding-ada-002 model + /// + public static Model SmallTextEmbedding => new Model("text-embedding-3-small") { OwnedBy = "openai" }; + + /// + /// Next generation larger embedding model which creates embeddings with up to 3072 dimensions + /// + public static Model LargeTextEmbedding => new Model("text-embedding-3-large") { OwnedBy = "openai" }; #endregion #region Moderation diff --git a/OpenAI_Tests/EmbeddingEndpointTests.cs b/OpenAI_Tests/EmbeddingEndpointTests.cs index 8297907..40e0fe9 100644 --- a/OpenAI_Tests/EmbeddingEndpointTests.cs +++ b/OpenAI_Tests/EmbeddingEndpointTests.cs @@ -64,5 +64,19 @@ public void GetSimpleEmbedding() Assert.IsNotNull(results); Assert.That(results.Length == 1536); } + + [Test] + public void GetSimpleEmbeddingWithDimensions() + { + var api = new OpenAI_API.OpenAIAPI(); + + Assert.IsNotNull(api.Embeddings); + + var request = new EmbeddingRequest(Model.SmallTextEmbedding, "A test text for embedding"); + request.Dimensions = 256; + var results = api.Embeddings.GetEmbeddingsAsync(request).Result; + Assert.IsNotNull(results); + Assert.That(results.Length == 256); + } } }