diff --git a/OpenAI_API/Embedding/EmbeddingEndpoint.cs b/OpenAI_API/Embedding/EmbeddingEndpoint.cs
index a324de9..fe77c89 100644
--- a/OpenAI_API/Embedding/EmbeddingEndpoint.cs
+++ b/OpenAI_API/Embedding/EmbeddingEndpoint.cs
@@ -28,10 +28,11 @@ internal EmbeddingEndpoint(OpenAIAPI api) : base(api) { }
/// Ask the API to embedd text using the default embedding model
///
/// Text to be embedded
+ /// Embeddings model to be used
/// Asynchronously returns the embedding result. Look in its property of to find the vector of floating point numbers
- public async Task CreateEmbeddingAsync(string input)
+ public async Task CreateEmbeddingAsync(string input, Model model = null)
{
- EmbeddingRequest req = new EmbeddingRequest(DefaultEmbeddingRequestArgs.Model, input);
+ EmbeddingRequest req = new EmbeddingRequest(model ?? DefaultEmbeddingRequestArgs.Model, input);
return await CreateEmbeddingAsync(req);
}
@@ -46,14 +47,25 @@ public async Task CreateEmbeddingAsync(EmbeddingRequest request
}
///
- /// Ask the API to embedd text using the default embedding model
+ /// Ask the API to embedd text using the default embedding model in case no other model is specified
///
/// Text to be embedded
+ /// Embeddings model to be used
+ /// Asynchronously returns the first embedding result as an array of floats.
+ public async Task GetEmbeddingsAsync(string input, Model model = null)
+ {
+ EmbeddingRequest req = new EmbeddingRequest(model ?? DefaultEmbeddingRequestArgs.Model, input);
+ return await GetEmbeddingsAsync(req);
+ }
+
+ ///
+ /// Ask the API to embedd text
+ ///
+ /// Request to be send
/// Asynchronously returns the first embedding result as an array of floats.
- public async Task GetEmbeddingsAsync(string input)
+ public async Task GetEmbeddingsAsync(EmbeddingRequest request)
{
- EmbeddingRequest req = new EmbeddingRequest(DefaultEmbeddingRequestArgs.Model, input);
- var embeddingResult = await CreateEmbeddingAsync(req);
+ var embeddingResult = await CreateEmbeddingAsync(request);
return embeddingResult?.Data?[0]?.Embedding;
}
}
diff --git a/OpenAI_API/Embedding/EmbeddingRequest.cs b/OpenAI_API/Embedding/EmbeddingRequest.cs
index 99780eb..cddbe0e 100644
--- a/OpenAI_API/Embedding/EmbeddingRequest.cs
+++ b/OpenAI_API/Embedding/EmbeddingRequest.cs
@@ -20,6 +20,18 @@ public class EmbeddingRequest
[JsonProperty("input")]
public string Input { get; set; }
+ ///
+ /// The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models.
+ ///
+ [JsonProperty("dimensions")]
+ public int? Dimensions { get; set; }
+
+ ///
+ /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
+ ///
+ [JsonProperty("user")]
+ public string User { get; set; }
+
///
/// Cretes a new, empty
///
diff --git a/OpenAI_API/Embedding/IEmbeddingEndpoint.cs b/OpenAI_API/Embedding/IEmbeddingEndpoint.cs
index acd9069..a45cbbf 100644
--- a/OpenAI_API/Embedding/IEmbeddingEndpoint.cs
+++ b/OpenAI_API/Embedding/IEmbeddingEndpoint.cs
@@ -17,8 +17,9 @@ public interface IEmbeddingEndpoint
/// Ask the API to embedd text using the default embedding model
///
/// Text to be embedded
+ /// Embeddings model to be used
/// Asynchronously returns the embedding result. Look in its property of to find the vector of floating point numbers
- Task CreateEmbeddingAsync(string input);
+ Task CreateEmbeddingAsync(string input, Model model = null);
///
/// Ask the API to embedd text using a custom request
@@ -28,10 +29,18 @@ public interface IEmbeddingEndpoint
Task CreateEmbeddingAsync(EmbeddingRequest request);
///
- /// Ask the API to embedd text using the default embedding model
+ /// Ask the API to embedd text using the default embedding model in case no other model is specified
///
/// Text to be embedded
+ /// Embeddings model to be used
+ /// Asynchronously returns the first embedding result as an array of floats.
+ Task GetEmbeddingsAsync(string input, Model model = null);
+
+ ///
+ /// Ask the API to embedd text using the default embedding model in case no other model is specified
+ ///
+ /// Request to be send
/// Asynchronously returns the first embedding result as an array of floats.
- Task GetEmbeddingsAsync(string input);
+ Task GetEmbeddingsAsync(EmbeddingRequest request);
}
}
\ No newline at end of file
diff --git a/OpenAI_API/Model/Model.cs b/OpenAI_API/Model/Model.cs
index 49c0420..0b1fb0e 100644
--- a/OpenAI_API/Model/Model.cs
+++ b/OpenAI_API/Model/Model.cs
@@ -218,9 +218,19 @@ public async Task RetrieveModelDetailsAsync(OpenAI_API.OpenAIAPI api)
#region Embeddings
///
- /// OpenAI offers one second-generation embedding model for use with the embeddings API endpoint.
+ /// This model is not deprecated yet, but OpenAI recommends to use the newer models text-embedding-3-small and text-embedding-3-large.
///
public static Model AdaTextEmbedding => new Model("text-embedding-ada-002") { OwnedBy = "openai" };
+
+ ///
+ /// Highly efficient embedding model which provides a significant upgrade over its predecessor, the text-embedding-ada-002 model
+ ///
+ public static Model SmallTextEmbedding => new Model("text-embedding-3-small") { OwnedBy = "openai" };
+
+ ///
+ /// Next generation larger embedding model which creates embeddings with up to 3072 dimensions
+ ///
+ public static Model LargeTextEmbedding => new Model("text-embedding-3-large") { OwnedBy = "openai" };
#endregion
#region Moderation
diff --git a/OpenAI_Tests/EmbeddingEndpointTests.cs b/OpenAI_Tests/EmbeddingEndpointTests.cs
index 8297907..40e0fe9 100644
--- a/OpenAI_Tests/EmbeddingEndpointTests.cs
+++ b/OpenAI_Tests/EmbeddingEndpointTests.cs
@@ -64,5 +64,19 @@ public void GetSimpleEmbedding()
Assert.IsNotNull(results);
Assert.That(results.Length == 1536);
}
+
+ [Test]
+ public void GetSimpleEmbeddingWithDimensions()
+ {
+ var api = new OpenAI_API.OpenAIAPI();
+
+ Assert.IsNotNull(api.Embeddings);
+
+ var request = new EmbeddingRequest(Model.SmallTextEmbedding, "A test text for embedding");
+ request.Dimensions = 256;
+ var results = api.Embeddings.GetEmbeddingsAsync(request).Result;
+ Assert.IsNotNull(results);
+ Assert.That(results.Length == 256);
+ }
}
}