Skip to content

Commit c017954

Browse files
committed
simplify embedding format in response
1 parent df4fc93 commit c017954

File tree

2 files changed

+78
-45
lines changed

2 files changed

+78
-45
lines changed

async-openai/src/embedding.rs

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
use crate::{
22
config::Config,
33
error::OpenAIError,
4-
types::{
5-
CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse,
6-
EncodingFormat,
7-
},
4+
types::{CreateEmbeddingRequest, CreateEmbeddingResponse, EncodingFormat},
85
Client,
96
};
107

@@ -25,6 +22,16 @@ impl<'c, C: Config> Embeddings<'c, C> {
2522
pub async fn create(
2623
&self,
2724
request: CreateEmbeddingRequest,
25+
) -> Result<CreateEmbeddingResponse, OpenAIError> {
26+
self.client.post("/embeddings", request).await
27+
}
28+
29+
/// Creates an embedding vector representing the input text.
30+
///
31+
/// The response will contain the embedding in float-vector format.
32+
pub async fn create_float(
33+
&self,
34+
request: CreateEmbeddingRequest,
2835
) -> Result<CreateEmbeddingResponse, OpenAIError> {
2936
if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
3037
return Err(OpenAIError::InvalidArgument(
@@ -40,10 +47,10 @@ impl<'c, C: Config> Embeddings<'c, C> {
4047
pub async fn create_base64(
4148
&self,
4249
request: CreateEmbeddingRequest,
43-
) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
50+
) -> Result<CreateEmbeddingResponse, OpenAIError> {
4451
if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
4552
return Err(OpenAIError::InvalidArgument(
46-
"When encoding_format is not base64, use Embeddings::create".into(),
53+
"When encoding_format is not base64, use Embeddings::create_float".into(),
4754
));
4855
}
4956

@@ -166,7 +173,7 @@ mod tests {
166173
.encoding_format(EncodingFormat::Base64)
167174
.build()
168175
.unwrap();
169-
let b64_response = client.embeddings().create(b64_request).await;
176+
let b64_response = client.embeddings().create_float(b64_request).await;
170177
assert!(matches!(b64_response, Err(OpenAIError::InvalidArgument(_))));
171178
}
172179

@@ -196,8 +203,8 @@ mod tests {
196203
.input(INPUT)
197204
.build()
198205
.unwrap();
199-
let response = client.embeddings().create(request).await.unwrap();
200-
let embedding = response.data.into_iter().next().unwrap().embedding;
206+
let response = client.embeddings().create_float(request).await.unwrap();
207+
let embedding: Vec<f32> = response.data.into_iter().next().unwrap().embedding.into();
201208

202209
assert_eq!(b64_embedding.len(), embedding.len());
203210
for (b64, normal) in b64_embedding.iter().zip(embedding.iter()) {

async-openai/src/types/embedding.rs

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -53,42 +53,79 @@ pub struct CreateEmbeddingRequest {
5353
pub dimensions: Option<u32>,
5454
}
5555

56-
/// Represents an embedding vector returned by embedding endpoint.
5756
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
58-
pub struct Embedding {
59-
/// The index of the embedding in the list of embeddings.
60-
pub index: u32,
61-
/// The object type, which is always "embedding".
62-
pub object: String,
63-
/// The embedding vector, which is a list of floats. The length of vector
64-
/// depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
65-
pub embedding: Vec<f32>,
57+
#[serde(untagged)]
58+
pub enum EmbeddingVector {
59+
Float(Vec<f32>),
60+
Base64(String),
6661
}
6762

68-
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
69-
pub struct Base64EmbeddingVector(pub String);
70-
71-
impl From<Base64EmbeddingVector> for Vec<f32> {
72-
fn from(value: Base64EmbeddingVector) -> Self {
73-
let bytes = general_purpose::STANDARD
74-
.decode(value.0)
75-
.expect("openai base64 encoding to be valid");
76-
let chunks = bytes.chunks_exact(4);
77-
chunks
78-
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
79-
.collect()
63+
impl From<EmbeddingVector> for Vec<f32> {
64+
fn from(val: EmbeddingVector) -> Self {
65+
match val {
66+
EmbeddingVector::Float(v) => v,
67+
EmbeddingVector::Base64(s) => {
68+
let bytes = general_purpose::STANDARD
69+
.decode(s)
70+
.expect("openai base64 encoding to be valid");
71+
let chunks = bytes.chunks_exact(4);
72+
chunks
73+
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
74+
.collect::<Vec<f32>>()
75+
}
76+
}
77+
}
78+
}
79+
80+
/// Converts an embedding vector to a base64-encoded string.
81+
impl From<EmbeddingVector> for String {
82+
fn from(val: EmbeddingVector) -> Self {
83+
match val {
84+
EmbeddingVector::Float(v) => {
85+
let mut bytes = Vec::with_capacity(v.len() * 4);
86+
for f in v {
87+
bytes.extend_from_slice(&f.to_le_bytes());
88+
}
89+
general_purpose::STANDARD.encode(&bytes)
90+
}
91+
EmbeddingVector::Base64(s) => s,
92+
}
93+
}
94+
}
95+
96+
impl EmbeddingVector {
97+
pub fn is_empty(&self) -> bool {
98+
match self {
99+
EmbeddingVector::Float(v) => v.is_empty(),
100+
101+
// Don't use .len() to avoid decoding the base64 string
102+
EmbeddingVector::Base64(v) => v.is_empty(),
103+
}
104+
}
105+
106+
pub fn len(&self) -> usize {
107+
match self {
108+
EmbeddingVector::Float(v) => v.len(),
109+
EmbeddingVector::Base64(v) => {
110+
let bytes = general_purpose::STANDARD
111+
.decode(v)
112+
.expect("openai base64 encoding to be valid");
113+
bytes.len() / 4
114+
}
115+
}
80116
}
81117
}
82118

83-
/// Represents an base64-encoded embedding vector returned by embedding endpoint.
119+
/// Represents an embedding vector returned by embedding endpoint.
84120
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
85-
pub struct Base64Embedding {
121+
pub struct Embedding {
86122
/// The index of the embedding in the list of embeddings.
87123
pub index: u32,
88124
/// The object type, which is always "embedding".
89125
pub object: String,
90-
/// The embedding vector, encoded in base64.
91-
pub embedding: Base64EmbeddingVector,
126+
/// The embedding vector, which is a list of floats. The length of vector
127+
/// depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
128+
pub embedding: EmbeddingVector,
92129
}
93130

94131
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
@@ -109,14 +146,3 @@ pub struct CreateEmbeddingResponse {
109146
/// The usage information for the request.
110147
pub usage: EmbeddingUsage,
111148
}
112-
113-
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
114-
pub struct CreateBase64EmbeddingResponse {
115-
pub object: String,
116-
/// The name of the model used to generate the embedding.
117-
pub model: String,
118-
/// The list of embeddings generated by the model.
119-
pub data: Vec<Base64Embedding>,
120-
/// The usage information for the request.
121-
pub usage: EmbeddingUsage,
122-
}

0 commit comments

Comments
 (0)