Skip to content

Commit 603a796

Browse files
Slightly more robust output
1 parent cb93db6 commit 603a796

File tree

2 files changed

+74
-3
lines changed

2 files changed

+74
-3
lines changed

src/indexer.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pub struct IndexPoint {
3333
}
3434

3535
impl Point {
36-
pub fn id(&self) -> &String {
36+
pub fn id(&self) -> &str {
3737
match self {
3838
Point::Stored { id, vec } => id,
3939
Point::Mem { vec } => panic!("You can not get the external id of a memory point"),
@@ -100,10 +100,18 @@ pub async fn operations_to_point_operations(
100100
Operation::Inserted { string, id } => Some((Op::Insert, string.into(), id.into())),
101101
Operation::Changed { string, id } => Some((Op::Changed, string.into(), id.into())),
102102
Operation::Deleted { id: _ } => None,
103+
Operation::Error { message } => {
104+
eprintln!("{}", message);
105+
None
106+
}
103107
})
104108
.collect();
105109
let strings: Vec<String> = tuples.iter().map(|(_, s, _)| s.to_string()).collect();
106-
let vecs: Vec<Embedding> = embeddings_for(key, &strings).await.unwrap();
110+
let vecs: Vec<Embedding> = if strings.is_empty() {
111+
Vec::new()
112+
} else {
113+
embeddings_for(key, &strings).await.unwrap()
114+
};
107115
let domain = vector_store.get_domain(domain).unwrap();
108116
let loaded_vecs = vector_store
109117
.add_and_load_vecs(&domain, vecs.iter())
@@ -170,12 +178,17 @@ pub enum SearchError {
170178

171179
#[derive(Clone, Debug, PartialEq)]
172180
pub struct PointQuery {
181+
id: usize,
173182
point: Point,
174183
distance: u32,
175184
}
176185

177186
impl PointQuery {
178-
pub fn id(&self) -> &String {
187+
pub fn internal_id(&self) -> usize {
188+
self.id
189+
}
190+
191+
pub fn id(&self) -> &str {
179192
self.point.id()
180193
}
181194

@@ -197,6 +210,7 @@ pub fn search(p: &Point, num: usize, hnsw: &HnswIndex) -> Result<Vec<PointQuery>
197210
let mut points = Vec::with_capacity(num);
198211
for elt in output {
199212
points.push(PointQuery {
213+
id: elt.index,
200214
point: hnsw.feature(elt.index).clone(),
201215
distance: elt.distance,
202216
})

src/server.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ pub enum Operation {
4848
Inserted { string: String, id: String },
4949
Changed { string: String, id: String },
5050
Deleted { id: String },
51+
Error { message: String },
5152
}
5253

5354
#[derive(Deserialize, Debug)]
@@ -79,6 +80,11 @@ enum ResourceSpec {
7980
id: String,
8081
count: usize,
8182
},
83+
DuplicateCandidates {
84+
domain: String,
85+
commit: String,
86+
threshold: f32,
87+
},
8288
}
8389

8490
#[derive(Debug)]
@@ -104,6 +110,7 @@ fn uri_to_spec(uri: &Uri) -> Result<ResourceSpec, SpecParseError> {
104110
static ref RE_CHECK: Regex = Regex::new(r"^/check(/?)$").unwrap();
105111
static ref RE_SEARCH: Regex = Regex::new(r"^/search(/?)$").unwrap();
106112
static ref RE_SIMILAR: Regex = Regex::new(r"^/similar(/?)$").unwrap();
113+
static ref RE_DUPLICATES: Regex = Regex::new(r"^/duplicates(/?)$").unwrap();
107114
}
108115
let path = uri.path();
109116

@@ -163,6 +170,22 @@ fn uri_to_spec(uri: &Uri) -> Result<ResourceSpec, SpecParseError> {
163170
}
164171
_ => Err(SpecParseError::NoCommitIdOrDomain),
165172
}
173+
} else if RE_DUPLICATES.is_match(path) {
174+
let query = query_map(uri);
175+
let domain = query.get("domain").map(|v| v.to_string());
176+
let commit = query.get("commit").map(|v| v.to_string());
177+
let threshold = query.get("threshold").map(|v| v.parse::<f32>().unwrap());
178+
match (domain, commit) {
179+
(Some(domain), Some(commit)) => {
180+
let threshold = threshold.unwrap_or(0.0);
181+
Ok(ResourceSpec::DuplicateCandidates {
182+
domain,
183+
commit,
184+
threshold,
185+
})
186+
}
187+
_ => Err(SpecParseError::NoCommitIdOrDomain),
188+
}
166189
} else {
167190
Err(SpecParseError::UnknownPath)
168191
}
@@ -216,13 +239,20 @@ async fn get_operations_from_terminusdb(
216239
let lines = StreamReader::new(res).lines();
217240
let lines_stream = LinesStream::new(lines);
218241
let fp = lines_stream.and_then(|l| {
242+
dbg!(&l);
219243
future::ready(
220244
serde_json::from_str(&l).map_err(|e| std::io::Error::new(ErrorKind::Other, e)),
221245
)
222246
});
223247
Ok(fp)
224248
}
225249

250+
fn add_to_duplicates(duplicates: &mut HashMap<usize, usize>, id1: usize, id2: usize) {
251+
if id1 < id2 {
252+
duplicates.insert(id1, id2);
253+
}
254+
}
255+
226256
impl Service {
227257
async fn get_task_status(&self, task_id: &str) -> Option<TaskStatus> {
228258
self.tasks.read().await.get(task_id).cloned()
@@ -389,6 +419,33 @@ impl Service {
389419
.unwrap())
390420
}
391421
}
422+
Ok(ResourceSpec::DuplicateCandidates {
423+
domain,
424+
commit,
425+
threshold,
426+
}) => {
427+
let index_id = create_index_name(&domain, &commit);
428+
// if None, then return 404
429+
let hnsw = self.get_index(&index_id).await.unwrap();
430+
let mut duplicates: HashMap<usize, usize> = HashMap::new();
431+
let elts = hnsw.layer_len(0);
432+
for i in 0..elts {
433+
let current_point = &hnsw.feature(i);
434+
let results = search(current_point, 2, &hnsw).unwrap();
435+
for result in results.iter() {
436+
if f32::from_bits(result.distance()) < threshold {
437+
add_to_duplicates(&mut duplicates, i, result.internal_id())
438+
}
439+
}
440+
}
441+
let mut v: Vec<(&str, &str)> = duplicates
442+
.into_iter()
443+
.map(|(i, j)| (hnsw.feature(i).id(), hnsw.feature(j).id()))
444+
.collect();
445+
Ok(Response::builder()
446+
.body(serde_json::to_string(&v).unwrap().into())
447+
.unwrap())
448+
}
392449
Ok(ResourceSpec::Similar {
393450
domain,
394451
commit,

0 commit comments

Comments
 (0)