Skip to content

Commit 2ad8aca

Browse files
Adding better error handling
1 parent 182a445 commit 2ad8aca

File tree

2 files changed

+157
-93
lines changed

2 files changed

+157
-93
lines changed

src/indexer.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use std::{
1515
iter::{self, zip},
1616
path::PathBuf,
1717
};
18+
use thiserror::Error;
1819
use urlencoding::{decode, encode};
1920

2021
pub type HnswIndex = Hnsw<OpenAI, Point, Lcg128Xsl64, 24, 48>;
@@ -172,8 +173,9 @@ pub fn start_indexing_from_operations(
172173
Ok(hnsw)
173174
}
174175

175-
#[derive(Debug)]
176+
#[derive(Debug, Error)]
176177
pub enum SearchError {
178+
#[error("Search failed for unknown reason")]
177179
SearchFailed,
178180
}
179181

src/server.rs

Lines changed: 154 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use bytes::Bytes;
33
use futures::StreamExt;
44
use futures::TryStreamExt;
55
use hnsw::Hnsw;
6+
use hyper::HeaderMap;
67
use hyper::StatusCode;
78
use hyper::{
89
service::{make_service_fn, service_fn},
@@ -28,6 +29,7 @@ use std::{
2829
};
2930
use thiserror::Error;
3031
use tokio::sync::Mutex;
32+
use tokio::task;
3133
use tokio::{io::AsyncBufReadExt, sync::RwLock};
3234
use tokio_stream::{wrappers::LinesStream, Stream};
3335
use tokio_util::io::StreamReader;
@@ -39,6 +41,7 @@ use crate::indexer::search;
3941
use crate::indexer::serialize_index;
4042
use crate::indexer::Point;
4143
use crate::indexer::PointOperation;
44+
use crate::indexer::SearchError;
4245
use crate::indexer::{start_indexing_from_operations, HnswIndex, IndexIdentifier, OpenAI};
4346
use crate::openai::embeddings_for;
4447
use crate::vectors::VectorStore;
@@ -113,6 +116,28 @@ fn query_map(uri: &Uri) -> HashMap<String, String> {
113116
.unwrap_or_else(|| HashMap::with_capacity(0))
114117
}
115118

119+
#[derive(Debug, Error)]
120+
enum HeaderError {
121+
#[error("Key was not valid utf8")]
122+
KeyNotUtf8,
123+
#[error("Missing the key {0}")]
124+
MissingKey(String),
125+
}
126+
127+
fn get_header_value(header: &HeaderMap, key: &str) -> Result<String, HeaderError> {
128+
let value = header.get(key);
129+
match value {
130+
Some(value) => {
131+
let value = String::from_utf8(value.as_bytes().to_vec());
132+
match value {
133+
Ok(value) => Ok(value),
134+
Err(_) => Err(HeaderError::KeyNotUtf8),
135+
}
136+
}
137+
None => Err(HeaderError::MissingKey(key.to_string())),
138+
}
139+
}
140+
116141
fn uri_to_spec(uri: &Uri) -> Result<ResourceSpec, SpecParseError> {
117142
lazy_static! {
118143
static ref RE_INDEX: Regex = Regex::new(r"^/index(/?)$").unwrap();
@@ -278,6 +303,20 @@ async fn get_operations_from_content_endpoint(
278303
Ok(fp)
279304
}
280305

306+
#[derive(Debug, Error)]
307+
enum ResponseError {
308+
#[error("{0:?}")]
309+
HeaderError(#[from] HeaderError),
310+
#[error("{0:?}")]
311+
IoError(#[from] std::io::Error),
312+
#[error("{0:?}")]
313+
SerdeError(#[from] serde_json::Error),
314+
#[error("{0:?}")]
315+
StartIndexError(#[from] StartIndexError),
316+
#[error("{0:?}")]
317+
SearchError(#[from] SearchError),
318+
}
319+
281320
fn add_to_duplicates(duplicates: &mut HashMap<usize, usize>, id1: usize, id2: usize) {
282321
if id1 < id2 {
283322
duplicates.insert(id1, id2);
@@ -293,15 +332,12 @@ impl Service {
293332
self.tasks.write().await.insert(task_id, status);
294333
}
295334

296-
async fn get_index(&self, index_id: &str) -> Option<Arc<HnswIndex>> {
335+
async fn get_index(&self, index_id: &str) -> io::Result<Arc<HnswIndex>> {
297336
if let Some(hnsw) = self.indexes.read().await.get(index_id) {
298-
Some(hnsw).cloned()
337+
Ok(hnsw).cloned()
299338
} else {
300339
let mut path = self.path.clone();
301-
match deserialize_index(&mut path, index_id, &self.vector_store) {
302-
Ok(res) => Some(res.into()),
303-
Err(_) => None,
304-
}
340+
Ok(deserialize_index(&mut path, index_id, &self.vector_store)?.into())
305341
}
306342
}
307343

@@ -416,24 +452,16 @@ impl Service {
416452
) -> Result<(), AssignIndexError> {
417453
let source_name = create_index_name(&domain, &source_commit);
418454
let target_name = create_index_name(&domain, &target_commit);
419-
420-
if self.get_index(&target_name).await.is_some() {
421-
return Err(AssignIndexError::TargetCommitAlreadyHasIndex);
422-
}
423-
if let Some(index) = self.get_index(&source_name).await {
424-
let mut indexes = self.indexes.write().await;
425-
indexes.insert(target_name.clone(), index.clone());
426-
427-
std::mem::drop(indexes);
428-
tokio::task::block_in_place(move || {
429-
let path = self.path.clone();
430-
serialize_index(path, &target_name, (*index).clone()).unwrap();
431-
});
432-
433-
Ok(())
434-
} else {
435-
Err(AssignIndexError::SourceCommitNotFound)
436-
}
455+
self.get_index(&target_name).await?;
456+
let index = self.get_index(&source_name).await?;
457+
let mut indexes = self.indexes.write().await;
458+
indexes.insert(target_name.clone(), index.clone());
459+
std::mem::drop(indexes);
460+
tokio::task::block_in_place(move || {
461+
let path = self.path.clone();
462+
serialize_index(path, &target_name, (*index).clone()).unwrap();
463+
});
464+
Ok(())
437465
}
438466

439467
async fn process_operation_chunks(
@@ -475,6 +503,20 @@ impl Service {
475503
(id, hnsw)
476504
}
477505

506+
async fn get_start_index(
507+
self: Arc<Self>,
508+
req: Request<Body>,
509+
domain: String,
510+
commit: String,
511+
previous: Option<String>,
512+
) -> Result<String, ResponseError> {
513+
let task_id = Service::generate_task();
514+
let api_key = get_header_value(req.headers(), "VECTORLINK_EMBEDDING_API_KEY")?;
515+
self.set_task_status(task_id.clone(), TaskStatus::Pending(0.0));
516+
self.start_indexing(domain, commit, previous, task_id.clone(), api_key)?;
517+
Ok(task_id)
518+
}
519+
478520
async fn get(self: Arc<Self>, req: Request<Body>) -> Result<Response<Body>, Infallible> {
479521
let uri = req.uri();
480522
match dbg!(uri_to_spec(uri)) {
@@ -483,37 +525,8 @@ impl Service {
483525
commit,
484526
previous,
485527
}) => {
486-
let task_id = Service::generate_task();
487-
let headers = req.headers();
488-
let openai_key = headers.get("VECTORLINK_EMBEDDING_API_KEY");
489-
match openai_key {
490-
Some(openai_key) => {
491-
let openai_key = String::from_utf8(openai_key.as_bytes().to_vec()).unwrap();
492-
self.set_task_status(task_id.clone(), TaskStatus::Pending(0.0))
493-
.await;
494-
match self.start_indexing(
495-
domain,
496-
commit,
497-
previous,
498-
task_id.clone(),
499-
openai_key,
500-
) {
501-
Ok(()) => Ok(Response::builder().body(task_id.into()).unwrap()),
502-
Err(e) => Ok(Response::builder()
503-
.status(400)
504-
.body(e.to_string().into())
505-
.unwrap()),
506-
}
507-
}
508-
None => Ok(Response::builder()
509-
.status(400)
510-
.body(
511-
"No API key supplied in header (VECTORLINK_EMBEDDING_API_KEY)"
512-
.to_string()
513-
.into(),
514-
)
515-
.unwrap()),
516-
}
528+
let result = self.get_start_index(req, domain, commit, previous).await;
529+
fun_name(result)
517530
}
518531
Ok(ResourceSpec::AssignIndex {
519532
domain,
@@ -553,27 +566,13 @@ impl Service {
553566
commit,
554567
threshold,
555568
}) => {
556-
let index_id = create_index_name(&domain, &commit);
557-
// if None, then return 404
558-
let hnsw = self.get_index(&index_id).await.unwrap();
559-
let mut duplicates: HashMap<usize, usize> = HashMap::new();
560-
let elts = hnsw.layer_len(0);
561-
for i in 0..elts {
562-
let current_point = &hnsw.feature(i);
563-
let results = search(current_point, 2, &hnsw).unwrap();
564-
for result in results.iter() {
565-
if f32::from_bits(result.distance()) < threshold {
566-
add_to_duplicates(&mut duplicates, i, result.internal_id())
567-
}
568-
}
569+
let result = self
570+
.get_duplicate_candidates(domain, commit, threshold)
571+
.await;
572+
match result {
573+
Ok(result) => todo!(),
574+
Err(e) => todo!(),
569575
}
570-
let mut v: Vec<(&str, &str)> = duplicates
571-
.into_iter()
572-
.map(|(i, j)| (hnsw.feature(i).id(), hnsw.feature(j).id()))
573-
.collect();
574-
Ok(Response::builder()
575-
.body(serde_json::to_string(&v).unwrap().into())
576-
.unwrap())
577576
}
578577
Ok(ResourceSpec::Similar {
579578
domain,
@@ -618,6 +617,34 @@ impl Service {
618617
}
619618
}
620619

620+
async fn get_duplicate_candidates(
621+
self: Arc<Self>,
622+
domain: String,
623+
commit: String,
624+
threshold: f32,
625+
) -> Result<String, ResponseError> {
626+
let index_id = create_index_name(&domain, &commit);
627+
// if None, then return 404
628+
let hnsw = self.get_index(&index_id).await?;
629+
let mut duplicates: HashMap<usize, usize> = HashMap::new();
630+
let elts = hnsw.layer_len(0);
631+
for i in 0..elts {
632+
let current_point = &hnsw.feature(i);
633+
let results = search(current_point, 2, &hnsw)?;
634+
for result in results.iter() {
635+
if f32::from_bits(result.distance()) < threshold {
636+
add_to_duplicates(&mut duplicates, i, result.internal_id())
637+
}
638+
}
639+
}
640+
let mut v: Vec<(&str, &str)> = duplicates
641+
.into_iter()
642+
.map(|(i, j)| (hnsw.feature(i).id(), hnsw.feature(j).id()))
643+
.collect();
644+
let result = serde_json::to_string(&v)?;
645+
Ok(result)
646+
}
647+
621648
async fn post(&self, req: Request<Body>) -> Result<Response<Body>, Infallible> {
622649
let uri = req.uri();
623650
match uri_to_spec(uri) {
@@ -626,28 +653,63 @@ impl Service {
626653
commit,
627654
count,
628655
}) => {
629-
let body_bytes = hyper::body::to_bytes(req.into_body()).await.unwrap();
656+
let headers = req.headers().clone();
657+
let body = req.into_body();
658+
let body_bytes = hyper::body::to_bytes(body).await.unwrap();
630659
let q = String::from_utf8(body_bytes.to_vec()).unwrap();
631-
let vec = Box::new((embeddings_for(&self.api_key, &[q]).await.unwrap())[0]);
632-
let qp = Point::Mem { vec };
633-
let index_id = create_index_name(&domain, &commit);
634-
// if None, then return 404
635-
let hnsw = self.get_index(&index_id).await.unwrap();
636-
let res = search(&qp, count, &hnsw).unwrap();
637-
let ids: Vec<QueryResult> = res
638-
.iter()
639-
.map(|p| QueryResult {
640-
id: p.id().to_string(),
641-
distance: f32::from_bits(p.distance()),
642-
})
643-
.collect();
644-
let s = serde_json::to_string(&ids).unwrap();
645-
Ok(Response::builder().body(s.into()).unwrap())
660+
let api_key = get_header_value(&headers, "VECTORLINK_EMBEDDING_API_KEY");
661+
let result = self.index_response(api_key, q, domain, commit, count).await;
662+
match result {
663+
Ok(body) => Ok(body),
664+
Err(e) => Ok(Response::builder()
665+
.status(StatusCode::NOT_FOUND)
666+
.body(e.to_string().into())
667+
.unwrap()),
668+
}
646669
}
647670
Ok(_) => todo!(),
648-
Err(_) => todo!(),
671+
Err(e) => Ok(Response::builder()
672+
.status(StatusCode::NOT_FOUND)
673+
.body(e.to_string().into())
674+
.unwrap()),
649675
}
650676
}
677+
678+
async fn index_response(
679+
&self,
680+
api_key: Result<String, HeaderError>,
681+
q: String,
682+
domain: String,
683+
commit: String,
684+
count: usize,
685+
) -> Result<Response<Body>, ResponseError> {
686+
let api_key = api_key?;
687+
let vec = Box::new((embeddings_for(&self.api_key, &[q]).await.unwrap())[0]);
688+
let qp = Point::Mem { vec };
689+
let index_id = create_index_name(&domain, &commit);
690+
// if None, then return 404
691+
let hnsw = self.get_index(&index_id).await?;
692+
let res = search(&qp, count, &hnsw).unwrap();
693+
let ids: Vec<QueryResult> = res
694+
.iter()
695+
.map(|p| QueryResult {
696+
id: p.id().to_string(),
697+
distance: f32::from_bits(p.distance()),
698+
})
699+
.collect();
700+
let s = serde_json::to_string(&ids)?;
701+
Ok(Response::builder().body(s.into()).unwrap())
702+
}
703+
}
704+
705+
fn fun_name(result: Result<String, ResponseError>) -> Result<Response<Body>, Infallible> {
706+
match result {
707+
Ok(task_id) => Ok(Response::builder().body(task_id.into()).unwrap()),
708+
Err(e) => Ok(Response::builder()
709+
.status(400)
710+
.body(e.to_string().into())
711+
.unwrap()),
712+
}
651713
}
652714

653715
#[derive(Debug, Error)]

0 commit comments

Comments
 (0)