@@ -53,42 +53,79 @@ pub struct CreateEmbeddingRequest {
53
53
pub dimensions : Option < u32 > ,
54
54
}
55
55
56
- /// Represents an embedding vector returned by embedding endpoint.
57
56
#[ derive( Debug , Deserialize , Serialize , Clone , PartialEq ) ]
58
- pub struct Embedding {
59
- /// The index of the embedding in the list of embeddings.
60
- pub index : u32 ,
61
- /// The object type, which is always "embedding".
62
- pub object : String ,
63
- /// The embedding vector, which is a list of floats. The length of vector
64
- /// depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
65
- pub embedding : Vec < f32 > ,
57
+ #[ serde( untagged) ]
58
+ pub enum EmbeddingVector {
59
+ Float ( Vec < f32 > ) ,
60
+ Base64 ( String ) ,
66
61
}
67
62
68
- #[ derive( Debug , Deserialize , Serialize , Clone , PartialEq ) ]
69
- pub struct Base64EmbeddingVector ( pub String ) ;
70
-
71
- impl From < Base64EmbeddingVector > for Vec < f32 > {
72
- fn from ( value : Base64EmbeddingVector ) -> Self {
73
- let bytes = general_purpose:: STANDARD
74
- . decode ( value. 0 )
75
- . expect ( "openai base64 encoding to be valid" ) ;
76
- let chunks = bytes. chunks_exact ( 4 ) ;
77
- chunks
78
- . map ( |chunk| f32:: from_le_bytes ( [ chunk[ 0 ] , chunk[ 1 ] , chunk[ 2 ] , chunk[ 3 ] ] ) )
79
- . collect ( )
63
+ impl From < EmbeddingVector > for Vec < f32 > {
64
+ fn from ( val : EmbeddingVector ) -> Self {
65
+ match val {
66
+ EmbeddingVector :: Float ( v) => v,
67
+ EmbeddingVector :: Base64 ( s) => {
68
+ let bytes = general_purpose:: STANDARD
69
+ . decode ( s)
70
+ . expect ( "openai base64 encoding to be valid" ) ;
71
+ let chunks = bytes. chunks_exact ( 4 ) ;
72
+ chunks
73
+ . map ( |chunk| f32:: from_le_bytes ( [ chunk[ 0 ] , chunk[ 1 ] , chunk[ 2 ] , chunk[ 3 ] ] ) )
74
+ . collect :: < Vec < f32 > > ( )
75
+ }
76
+ }
77
+ }
78
+ }
79
+
80
+ /// Converts an embedding vector to a base64-encoded string.
81
+ impl From < EmbeddingVector > for String {
82
+ fn from ( val : EmbeddingVector ) -> Self {
83
+ match val {
84
+ EmbeddingVector :: Float ( v) => {
85
+ let mut bytes = Vec :: with_capacity ( v. len ( ) * 4 ) ;
86
+ for f in v {
87
+ bytes. extend_from_slice ( & f. to_le_bytes ( ) ) ;
88
+ }
89
+ general_purpose:: STANDARD . encode ( & bytes)
90
+ }
91
+ EmbeddingVector :: Base64 ( s) => s,
92
+ }
93
+ }
94
+ }
95
+
96
+ impl EmbeddingVector {
97
+ pub fn is_empty ( & self ) -> bool {
98
+ match self {
99
+ EmbeddingVector :: Float ( v) => v. is_empty ( ) ,
100
+
101
+ // Don't use .len() to avoid decoding the base64 string
102
+ EmbeddingVector :: Base64 ( v) => v. is_empty ( ) ,
103
+ }
104
+ }
105
+
106
+ pub fn len ( & self ) -> usize {
107
+ match self {
108
+ EmbeddingVector :: Float ( v) => v. len ( ) ,
109
+ EmbeddingVector :: Base64 ( v) => {
110
+ let bytes = general_purpose:: STANDARD
111
+ . decode ( v)
112
+ . expect ( "openai base64 encoding to be valid" ) ;
113
+ bytes. len ( ) / 4
114
+ }
115
+ }
80
116
}
81
117
}
82
118
83
- /// Represents an base64-encoded embedding vector returned by embedding endpoint.
119
+ /// Represents an embedding vector returned by embedding endpoint.
84
120
#[ derive( Debug , Deserialize , Serialize , Clone , PartialEq ) ]
85
- pub struct Base64Embedding {
121
+ pub struct Embedding {
86
122
/// The index of the embedding in the list of embeddings.
87
123
pub index : u32 ,
88
124
/// The object type, which is always "embedding".
89
125
pub object : String ,
90
- /// The embedding vector, encoded in base64.
91
- pub embedding : Base64EmbeddingVector ,
126
+ /// The embedding vector, which is a list of floats. The length of vector
127
+ /// depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
128
+ pub embedding : EmbeddingVector ,
92
129
}
93
130
94
131
#[ derive( Debug , Deserialize , Serialize , Clone , PartialEq ) ]
@@ -109,14 +146,3 @@ pub struct CreateEmbeddingResponse {
109
146
/// The usage information for the request.
110
147
pub usage : EmbeddingUsage ,
111
148
}
112
-
113
- #[ derive( Debug , Deserialize , Clone , PartialEq , Serialize ) ]
114
- pub struct CreateBase64EmbeddingResponse {
115
- pub object : String ,
116
- /// The name of the model used to generate the embedding.
117
- pub model : String ,
118
- /// The list of embeddings generated by the model.
119
- pub data : Vec < Base64Embedding > ,
120
- /// The usage information for the request.
121
- pub usage : EmbeddingUsage ,
122
- }
0 commit comments