Skip to content

Commit 0e7a629

Browse files
cfraz8964bit
andauthored
Implement vector store search, retrieve file content operations (#360)
* Implement vector search api * Make ids in ListVectorStoreFilesResponse optional, as they can come back null when there are no files * Implement vector file content api * Add Default derive to RankingOptions, make CompountFilter.type non-optional * Made comparison type non-optional * Make compound filter a Vec of VectorStoreSearchFilter * Implement from conversions for filters * Add vector store retrieval example * Update example readme * Add attributes to create vector store * Update examples/vector-store-retrieval/src/main.rs * Update examples/vector-store-retrieval/src/main.rs --------- Co-authored-by: Himanshu Neema <himanshun.iitkgp@gmail.com>
1 parent aeb6d1f commit 0e7a629

File tree

8 files changed

+407
-4
lines changed

8 files changed

+407
-4
lines changed

async-openai/src/types/vector_store.rs

Lines changed: 249 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ pub struct UpdateVectorStoreRequest {
140140
pub struct ListVectorStoreFilesResponse {
141141
pub object: String,
142142
pub data: Vec<VectorStoreFileObject>,
143-
pub first_id: String,
144-
pub last_id: String,
143+
pub first_id: Option<String>,
144+
pub last_id: Option<String>,
145145
pub has_more: bool,
146146
}
147147

@@ -209,7 +209,10 @@ pub enum VectorStoreFileObjectChunkingStrategy {
209209
pub struct CreateVectorStoreFileRequest {
210210
/// A [File](https://platform.openai.com/docs/api-reference/files) ID that the vector store should use. Useful for tools like `file_search` that can access files.
211211
pub file_id: String,
212+
#[serde(skip_serializing_if = "Option::is_none")]
212213
pub chunking_strategy: Option<VectorStoreChunkingStrategy>,
214+
#[serde(skip_serializing_if = "Option::is_none")]
215+
pub attributes: Option<HashMap<String, AttributeValue>>,
213216
}
214217

215218
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
@@ -269,3 +272,247 @@ pub struct VectorStoreFileBatchObject {
269272
pub status: VectorStoreFileBatchStatus,
270273
pub file_counts: VectorStoreFileBatchCounts,
271274
}
275+
276+
/// Represents the parsed content of a vector store file.
277+
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
278+
pub struct VectorStoreFileContentResponse {
279+
/// The object type, which is always `vector_store.file_content.page`
280+
pub object: String,
281+
282+
/// Parsed content of the file.
283+
pub data: Vec<VectorStoreFileContentObject>,
284+
285+
/// Indicates if there are more content pages to fetch.
286+
pub has_more: bool,
287+
288+
/// The token for the next page, if any.
289+
pub next_page: Option<String>,
290+
}
291+
292+
/// Represents the parsed content of a vector store file.
293+
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
294+
pub struct VectorStoreFileContentObject {
295+
/// The content type (currently only `"text"`)
296+
pub r#type: String,
297+
298+
/// The text content
299+
pub text: String,
300+
}
301+
302+
#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)]
303+
#[builder(name = "VectorStoreSearchRequestArgs")]
304+
#[builder(pattern = "mutable")]
305+
#[builder(setter(into, strip_option), default)]
306+
#[builder(derive(Debug))]
307+
#[builder(build_fn(error = "OpenAIError"))]
308+
pub struct VectorStoreSearchRequest {
309+
/// A query string for a search.
310+
pub query: VectorStoreSearchQuery,
311+
312+
/// Whether to rewrite the natural language query for vector search.
313+
#[serde(skip_serializing_if = "Option::is_none")]
314+
pub rewrite_query: Option<bool>,
315+
316+
/// The maximum number of results to return. This number should be between 1 and 50 inclusive.
317+
#[serde(skip_serializing_if = "Option::is_none")]
318+
pub max_num_results: Option<u8>,
319+
320+
/// A filter to apply based on file attributes.
321+
#[serde(skip_serializing_if = "Option::is_none")]
322+
pub filters: Option<VectorStoreSearchFilter>,
323+
324+
/// Ranking options for search.
325+
#[serde(skip_serializing_if = "Option::is_none")]
326+
pub ranking_options: Option<RankingOptions>,
327+
}
328+
329+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
330+
#[serde(untagged)]
331+
pub enum VectorStoreSearchQuery {
332+
/// A single query to search for.
333+
Text(String),
334+
/// A list of queries to search for.
335+
Array(Vec<String>),
336+
}
337+
338+
impl Default for VectorStoreSearchQuery {
339+
fn default() -> Self {
340+
Self::Text(String::new())
341+
}
342+
}
343+
344+
impl From<String> for VectorStoreSearchQuery {
345+
fn from(query: String) -> Self {
346+
Self::Text(query)
347+
}
348+
}
349+
350+
impl From<&str> for VectorStoreSearchQuery {
351+
fn from(query: &str) -> Self {
352+
Self::Text(query.to_string())
353+
}
354+
}
355+
356+
impl From<Vec<String>> for VectorStoreSearchQuery {
357+
fn from(query: Vec<String>) -> Self {
358+
Self::Array(query)
359+
}
360+
}
361+
362+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
363+
#[serde(untagged)]
364+
pub enum VectorStoreSearchFilter {
365+
Comparison(ComparisonFilter),
366+
Compound(CompoundFilter),
367+
}
368+
369+
impl From<ComparisonFilter> for VectorStoreSearchFilter {
370+
fn from(filter: ComparisonFilter) -> Self {
371+
Self::Comparison(filter)
372+
}
373+
}
374+
375+
impl From<CompoundFilter> for VectorStoreSearchFilter {
376+
fn from(filter: CompoundFilter) -> Self {
377+
Self::Compound(filter)
378+
}
379+
}
380+
381+
/// A filter used to compare a specified attribute key to a given value using a defined comparison operation.
382+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
383+
pub struct ComparisonFilter {
384+
/// Specifies the comparison operator: `eq`, `ne`, `gt`, `gte`, `lt`, `lte`.
385+
pub r#type: ComparisonType,
386+
387+
/// The key to compare against the value.
388+
pub key: String,
389+
390+
/// The value to compare against the attribute key; supports string, number, or boolean types.
391+
pub value: AttributeValue,
392+
}
393+
394+
/// Specifies the comparison operator: `eq`, `ne`, `gt`, `gte`, `lt`, `lte`.
395+
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
396+
#[serde(rename_all = "lowercase")]
397+
pub enum ComparisonType {
398+
Eq,
399+
Ne,
400+
Gt,
401+
Gte,
402+
Lt,
403+
Lte,
404+
}
405+
406+
/// The value to compare against the attribute key; supports string, number, or boolean types.
407+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
408+
#[serde(untagged)]
409+
pub enum AttributeValue {
410+
String(String),
411+
Number(i64),
412+
Boolean(bool),
413+
}
414+
415+
impl From<String> for AttributeValue {
416+
fn from(value: String) -> Self {
417+
Self::String(value)
418+
}
419+
}
420+
421+
impl From<i64> for AttributeValue {
422+
fn from(value: i64) -> Self {
423+
Self::Number(value)
424+
}
425+
}
426+
427+
impl From<bool> for AttributeValue {
428+
fn from(value: bool) -> Self {
429+
Self::Boolean(value)
430+
}
431+
}
432+
433+
impl From<&str> for AttributeValue {
434+
fn from(value: &str) -> Self {
435+
Self::String(value.to_string())
436+
}
437+
}
438+
439+
/// Ranking options for search.
440+
#[derive(Debug, Serialize, Default, Deserialize, Clone, PartialEq)]
441+
pub struct RankingOptions {
442+
#[serde(skip_serializing_if = "Option::is_none")]
443+
pub ranker: Option<Ranker>,
444+
445+
#[serde(skip_serializing_if = "Option::is_none")]
446+
pub score_threshold: Option<f32>,
447+
}
448+
449+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
450+
pub enum Ranker {
451+
#[serde(rename = "auto")]
452+
Auto,
453+
#[serde(rename = "default-2024-11-15")]
454+
Default20241115,
455+
}
456+
457+
/// Combine multiple filters using `and` or `or`.
458+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
459+
pub struct CompoundFilter {
460+
/// Type of operation: `and` or `or`.
461+
pub r#type: CompoundFilterType,
462+
463+
/// Array of filters to combine. Items can be `ComparisonFilter` or `CompoundFilter`
464+
pub filters: Vec<VectorStoreSearchFilter>,
465+
}
466+
467+
/// Type of operation: `and` or `or`.
468+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
469+
#[serde(rename_all = "lowercase")]
470+
pub enum CompoundFilterType {
471+
And,
472+
Or,
473+
}
474+
475+
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
476+
pub struct VectorStoreSearchResultsPage {
477+
/// The object type, which is always `vector_store.search_results.page`.
478+
pub object: String,
479+
480+
/// The query used for this search.
481+
pub search_query: Vec<String>,
482+
483+
/// The list of search result items.
484+
pub data: Vec<VectorStoreSearchResultItem>,
485+
486+
/// Indicates if there are more results to fetch.
487+
pub has_more: bool,
488+
489+
/// The token for the next page, if any.
490+
pub next_page: Option<String>,
491+
}
492+
493+
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
494+
pub struct VectorStoreSearchResultItem {
495+
/// The ID of the vector store file.
496+
pub file_id: String,
497+
498+
/// The name of the vector store file.
499+
pub filename: String,
500+
501+
/// The similarity score for the result.
502+
pub score: f32, // minimum: 0, maximum: 1
503+
504+
/// Attributes of the vector store file.
505+
pub attributes: HashMap<String, AttributeValue>,
506+
507+
/// Content chunks from the file.
508+
pub content: Vec<VectorStoreSearchResultContentObject>,
509+
}
510+
511+
#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
512+
pub struct VectorStoreSearchResultContentObject {
513+
/// The type of content
514+
pub r#type: String,
515+
516+
/// The text content returned from search.
517+
pub text: String,
518+
}

async-openai/src/vector_store_files.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
error::OpenAIError,
66
types::{
77
CreateVectorStoreFileRequest, DeleteVectorStoreFileResponse, ListVectorStoreFilesResponse,
8-
VectorStoreFileObject,
8+
VectorStoreFileContentResponse, VectorStoreFileObject,
99
},
1010
Client,
1111
};
@@ -78,6 +78,20 @@ impl<'c, C: Config> VectorStoreFiles<'c, C> {
7878
)
7979
.await
8080
}
81+
82+
/// Retrieve the parsed contents of a vector store file.
83+
#[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
84+
pub async fn retrieve_file_content(
85+
&self,
86+
file_id: &str,
87+
) -> Result<VectorStoreFileContentResponse, OpenAIError> {
88+
self.client
89+
.get(&format!(
90+
"/vector_stores/{}/files/{file_id}/content",
91+
&self.vector_store_id
92+
))
93+
.await
94+
}
8195
}
8296

8397
#[cfg(test)]

async-openai/src/vector_stores.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use crate::{
55
error::OpenAIError,
66
types::{
77
CreateVectorStoreRequest, DeleteVectorStoreResponse, ListVectorStoresResponse,
8-
UpdateVectorStoreRequest, VectorStoreObject,
8+
UpdateVectorStoreRequest, VectorStoreObject, VectorStoreSearchRequest,
9+
VectorStoreSearchResultsPage,
910
},
1011
vector_store_file_batches::VectorStoreFileBatches,
1112
Client, VectorStoreFiles,
@@ -78,4 +79,16 @@ impl<'c, C: Config> VectorStores<'c, C> {
7879
.post(&format!("/vector_stores/{vector_store_id}"), request)
7980
.await
8081
}
82+
83+
/// Searches a vector store.
84+
#[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
85+
pub async fn search(
86+
&self,
87+
vector_store_id: &str,
88+
request: VectorStoreSearchRequest,
89+
) -> Result<VectorStoreSearchResultsPage, OpenAIError> {
90+
self.client
91+
.post(&format!("/vector_stores/{vector_store_id}/search"), request)
92+
.await
93+
}
8194
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[package]
2+
name = "vector-store-retrieval"
3+
version = "0.1.0"
4+
edition = "2021"
5+
publish = false
6+
7+
[dependencies]
8+
async-openai = { path = "../../async-openai" }
9+
tokio = { version = "1.43.0", features = ["full"] }
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
## Intro
2+
3+
This example is based on https://platform.openai.com/docs/guides/retrieval
4+
5+
6+
## Data
7+
8+
Uber Annual Report obtained from https://investor.uber.com/financials/
9+
10+
Lyft Annual Report obtained from https://investor.lyft.com/financials-and-reports/annual-reports/default.aspx
11+
12+
13+
## Output
14+
15+
```
16+
Waiting for vector store to be[] ready...
17+
Search results: VectorStoreSearchResultsPage {
18+
object: "vector_store.search_results.page",
19+
search_query: [
20+
"uber profit",
21+
],
22+
data: [
23+
VectorStoreSearchResultItem {
24+
file_id: "file-1XFoSYUzJudwJLkAazLdjd",
25+
filename: "uber-10k.pdf",
26+
score: 0.5618923,
27+
attributes: {},
28+
content: [
29+
VectorStoreSearchResultContentObject {
30+
type: "text",
31+
text: "(In millions) Q1 2022 Q2 2022 Q3 2022 Q4 2022 Q1 2023 Q2 2023 Q3 2023 Q4 2023\n\nMobility $ 10,723 $ 13,364 $ 13,684 $ 14,894 $ 14,981 $ 16,728 $ 17,903 $ 19,285 \nDelivery 13,903 13,876 13,684 14,315 15,026 15,595 16,094 17,011 \nFreight 1,823 1,838 1,751 1,540 1,401 1,278 1,284 1,279 \n\nAdjusted EBITDA.
32+
...
33+
```
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)