Skip to content

Commit 683d1bf

Browse files
Adding preliminary load
1 parent 6ec0589 commit 683d1bf

File tree

6 files changed

+145
-43
lines changed

6 files changed

+145
-43
lines changed

Cargo.lock

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ urlencoding = "2.1"
3030
packed_simd = {version = "0.3.8", optional=true}
3131
aligned_box = "0.2"
3232
tiktoken-rs = "0.4"
33+
itertools = "0.10"
3334

3435
[features]
3536
simd = ["packed_simd"]

src/indexer.rs

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::{
22
openai::embeddings_for,
33
server::Operation,
4-
vecmath::Embedding,
4+
vecmath::{self, Embedding},
55
vectors::{LoadedVec, VectorStore},
66
};
77
use hnsw::{Hnsw, Searcher};
@@ -62,19 +62,14 @@ impl Metric<Point> for OpenAI {
6262
fn distance(&self, p1: &Point, p2: &Point) -> u32 {
6363
let a = p1.vec();
6464
let b = p2.vec();
65-
let f = a
66-
.iter()
67-
.zip(b.iter())
68-
.map(|(&a, &b)| (a - b).powi(2))
69-
.sum::<f32>()
70-
.sqrt();
65+
let f = vecmath::normalized_cosine_distance(a, b);
7166
f.to_bits()
7267
}
7368
}
7469

7570
impl Metric<IndexPoint> for OpenAI {
7671
type Unit = u32;
77-
fn distance(&self, p1: &IndexPoint, p2: &IndexPoint) -> u32 {
72+
fn distance(&self, _p1: &IndexPoint, _p2: &IndexPoint) -> u32 {
7873
unimplemented!()
7974
}
8075
}
@@ -86,8 +81,6 @@ pub enum PointOperation {
8681
Delete { id: String },
8782
}
8883

89-
pub const OPENAI_API_KEY: &str = "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";
90-
9184
enum Op {
9285
Insert,
9386
Changed,
@@ -97,6 +90,7 @@ pub async fn operations_to_point_operations(
9790
domain: &str,
9891
vector_store: &VectorStore,
9992
structs: Vec<Result<Operation, std::io::Error>>,
93+
key: &str,
10094
) -> Vec<PointOperation> {
10195
let ops: Vec<Operation> = structs.into_iter().map(|ro| ro.unwrap()).collect();
10296
let tuples: Vec<(Op, String, String)> = ops
@@ -108,7 +102,7 @@ pub async fn operations_to_point_operations(
108102
})
109103
.collect();
110104
let strings: Vec<String> = tuples.iter().map(|(_, s, _)| s.to_string()).collect();
111-
let vecs: Vec<Embedding> = embeddings_for(OPENAI_API_KEY, &strings).await.unwrap();
105+
let vecs: Vec<Embedding> = embeddings_for(key, &strings).await.unwrap();
112106
let domain = vector_store.get_domain(domain).unwrap();
113107
let loaded_vecs = vector_store
114108
.add_and_load_vecs(&domain, vecs.iter())
@@ -213,7 +207,7 @@ pub fn search(
213207
}
214208

215209
pub fn serialize_index(mut path: PathBuf, name: &str, hnsw: HnswIndex) -> io::Result<()> {
216-
let name = encode(name);
210+
//let name = encode(name);
217211
path.push(format!("{name}.hnsw"));
218212
let write_file = File::options().write(true).create(true).open(&path)?;
219213

src/main.rs

Lines changed: 77 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,27 @@
1+
use std::future;
2+
use std::io::ErrorKind;
3+
use std::path::Path;
4+
15
use clap::{Parser, Subcommand, ValueEnum};
6+
use hnsw::Hnsw;
7+
use indexer::start_indexing_from_operations;
28
use indexer::Point;
9+
use indexer::{operations_to_point_operations, OpenAI};
10+
use server::Operation;
311
use space::Metric;
4-
use terminusdb_semantic_indexer::vecmath::empty_embedding;
5-
6-
use crate::indexer::OpenAI;
7-
12+
use std::fs::File;
13+
use std::io::{self, BufRead};
14+
use {
15+
indexer::{create_index_name, HnswIndex},
16+
vecmath::empty_embedding,
17+
vectors::VectorStore,
18+
};
819
mod indexer;
920
mod openai;
1021
mod server;
1122
mod vecmath;
1223
mod vectors;
24+
use itertools::Itertools;
1325

1426
#[derive(Parser, Debug)]
1527
#[command(author, version, about, long_about = None)]
@@ -21,13 +33,29 @@ struct Args {
2133
#[derive(Subcommand, Debug)]
2234
enum Commands {
2335
Serve {
36+
#[arg(short, long)]
37+
key: String,
2438
#[arg(short, long)]
2539
directory: String,
2640
#[arg(short, long, default_value_t = 8080)]
2741
port: u16,
2842
#[arg(short, long, default_value_t = 10000)]
2943
size: usize,
3044
},
45+
Load {
46+
#[arg(short, long)]
47+
key: String,
48+
#[arg(short, long)]
49+
commit: String,
50+
#[arg(short, long)]
51+
domain: String,
52+
#[arg(short, long)]
53+
directory: String,
54+
#[arg(short, long)]
55+
input: String,
56+
#[arg(short, long, default_value_t = 10000)]
57+
size: usize,
58+
},
3159
Embed {
3260
#[arg(short, long)]
3361
key: String,
@@ -70,10 +98,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
7098
let args = Args::parse();
7199
match args.command {
72100
Commands::Serve {
101+
key,
73102
directory,
74103
port,
75104
size,
76-
} => server::serve(directory, port, size).await?,
105+
} => server::serve(directory, port, size, key).await?,
77106
Commands::Embed { key, string } => {
78107
let v = openai::embeddings_for(&key, &[string]).await?;
79108
eprintln!("{:?}", v);
@@ -108,15 +137,56 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
108137
};
109138
println!("distance: {}", distance);
110139
}
111-
Commands::Test {key } => {
112-
let v = openai::embeddings_for(&key, &["king".to_string(), "man".to_string(), "woman".to_string(), "queen".to_string()]).await?;
140+
Commands::Test { key } => {
141+
let v = openai::embeddings_for(
142+
&key,
143+
&[
144+
"king".to_string(),
145+
"man".to_string(),
146+
"woman".to_string(),
147+
"queen".to_string(),
148+
],
149+
)
150+
.await?;
113151
let mut calculated = empty_embedding();
114152
for i in 0..calculated.len() {
115153
calculated[i] = v[0][i] - v[1][i] + v[2][i];
116154
}
117155
let distance = vecmath::normalized_cosine_distance(&v[3], &calculated);
118156
eprintln!("{}", distance);
119157
}
158+
Commands::Load {
159+
key,
160+
domain,
161+
commit,
162+
directory,
163+
input,
164+
size,
165+
} => {
166+
let path = Path::new(&input);
167+
let dirpath = Path::new(&directory);
168+
let mut hnsw: HnswIndex = Hnsw::new(OpenAI);
169+
let store = VectorStore::new(dirpath, size);
170+
171+
let f = File::options().read(true).create(true).open(path)?;
172+
173+
let lines = io::BufReader::new(f).lines();
174+
175+
let opstream = &lines
176+
.map(|l| {
177+
let ro: io::Result<Operation> = serde_json::from_str(&l.unwrap())
178+
.map_err(|e| std::io::Error::new(ErrorKind::Other, e));
179+
ro
180+
})
181+
.chunks(100);
182+
183+
for structs in opstream {
184+
let structs: Vec<_> = structs.collect();
185+
let new_ops =
186+
operations_to_point_operations(&domain.clone(), &store, structs, &key).await;
187+
hnsw = start_indexing_from_operations(hnsw, new_ops).unwrap();
188+
}
189+
}
120190
}
121191

122192
Ok(())

src/openai.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,9 @@ pub enum EmbeddingError {
8787
}
8888

8989
lazy_static! {
90-
static ref ENCODER: CoreBPE = cl100k_base().unwrap();
90+
static ref ENCODER: CoreBPE = cl100k_base().unwrap();
9191
}
92+
9293
fn tokens_for(s: &str) -> Vec<usize> {
9394
ENCODER.encode_with_special_tokens(s)
9495
}
@@ -100,7 +101,6 @@ fn truncated_tokens_for(s: &str) -> Vec<usize> {
100101
tokens.truncate(MAX_TOKEN_COUNT);
101102
let decoded = ENCODER.decode(tokens.clone()).unwrap();
102103
eprintln!("truncating to {decoded}");
103-
104104
}
105105

106106
tokens
@@ -115,7 +115,7 @@ pub async fn embeddings_for(
115115
static ref CLIENT: Client = Client::new();
116116
}
117117

118-
let token_lists: Vec<_> = strings.iter().map(|s|truncated_tokens_for(s)).collect();
118+
let token_lists: Vec<_> = strings.iter().map(|s| truncated_tokens_for(s)).collect();
119119

120120
let mut req = Request::new(Method::POST, ENDPOINT.clone());
121121
let headers = req.headers_mut();

src/server.rs

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ use crate::indexer::search;
3737
use crate::indexer::serialize_index;
3838
use crate::indexer::Point;
3939
use crate::indexer::PointOperation;
40-
use crate::indexer::OPENAI_API_KEY;
4140
use crate::indexer::{start_indexing_from_operations, HnswIndex, IndexIdentifier, OpenAI};
4241
use crate::openai::embeddings_for;
4342
use crate::vectors::VectorStore;
@@ -182,6 +181,7 @@ pub struct QueryResult {
182181
}
183182

184183
pub struct Service {
184+
api_key: String,
185185
path: PathBuf,
186186
vector_store: VectorStore,
187187
pending: Mutex<HashSet<String>>,
@@ -271,9 +271,10 @@ impl Service {
271271
s
272272
}
273273

274-
fn new<P: Into<PathBuf>>(path: P, num_bufs: usize) -> Self {
274+
fn new<P: Into<PathBuf>>(path: P, num_bufs: usize, key: String) -> Self {
275275
let path = path.into();
276276
Service {
277+
api_key: key,
277278
path: path.clone(),
278279
vector_store: VectorStore::new(path, num_bufs),
279280
pending: Mutex::new(HashSet::new()),
@@ -306,37 +307,56 @@ impl Service {
306307
tokio::spawn(async move {
307308
let index_id = create_index_name(&domain, &commit);
308309
if self.test_and_set_pending(index_id.clone()).await {
309-
let mut opstream = get_operations_from_terminusdb(
310+
let opstream = get_operations_from_terminusdb(
310311
domain.clone(),
311312
commit.clone(),
312313
previous.clone(),
313314
)
314315
.await
315316
.unwrap()
316317
.chunks(100);
317-
let mut point_ops: Vec<PointOperation> = Vec::new();
318-
while let Some(structs) = opstream.next().await {
319-
let mut new_ops =
320-
operations_to_point_operations(&domain, &self.vector_store, structs).await;
321-
point_ops.append(&mut new_ops)
322-
}
323-
let id = create_index_name(&domain, &commit);
324-
let hnsw = self
325-
.load_hnsw_for_indexing(IndexIdentifier {
326-
domain,
327-
commit,
328-
previous,
329-
})
318+
let (id, hnsw) = self
319+
.process_operation_chunks(opstream, domain, commit, previous, &index_id)
330320
.await;
331-
let hnsw = start_indexing_from_operations(hnsw, point_ops).unwrap();
332-
let path = self.path.clone();
333-
serialize_index(path, &index_id, hnsw.clone()).unwrap();
334321
self.set_index(id, hnsw.into()).await;
335322
self.clear_pending(&index_id).await;
336323
}
337324
});
338325
}
339326

327+
async fn process_operation_chunks(
328+
self: &Arc<Self>,
329+
mut opstream: futures::stream::Chunks<
330+
impl Stream<Item = Result<Operation, io::Error>> + Unpin,
331+
>,
332+
domain: String,
333+
commit: String,
334+
previous: Option<String>,
335+
index_id: &str,
336+
) -> (String, Hnsw<OpenAI, Point, rand_pcg::Lcg128Xsl64, 12, 24>) {
337+
let id = create_index_name(&domain, &commit);
338+
let mut hnsw = self
339+
.load_hnsw_for_indexing(IndexIdentifier {
340+
domain: domain.clone(),
341+
commit,
342+
previous,
343+
})
344+
.await;
345+
while let Some(structs) = opstream.next().await {
346+
let new_ops = operations_to_point_operations(
347+
&domain.clone(),
348+
&self.vector_store,
349+
structs,
350+
&self.api_key,
351+
)
352+
.await;
353+
hnsw = start_indexing_from_operations(hnsw, new_ops).unwrap();
354+
}
355+
let path = self.path.clone();
356+
serialize_index(path, index_id, hnsw.clone()).unwrap();
357+
(id, hnsw)
358+
}
359+
340360
async fn get(self: Arc<Self>, req: Request<Body>) -> Result<Response<Body>, Infallible> {
341361
let uri = req.uri();
342362
match dbg!(uri_to_spec(uri)) {
@@ -395,7 +415,7 @@ impl Service {
395415
}) => {
396416
let body_bytes = hyper::body::to_bytes(req.into_body()).await.unwrap();
397417
let q = String::from_utf8(body_bytes.to_vec()).unwrap();
398-
let vec = Box::new((embeddings_for(OPENAI_API_KEY, &[q]).await.unwrap())[0]);
418+
let vec = Box::new((embeddings_for(&self.api_key, &[q]).await.unwrap())[0]);
399419
let qp = Point::Mem { vec };
400420
let index_id = create_index_name(&domain, &commit);
401421
// if None, then return 404
@@ -421,9 +441,10 @@ pub async fn serve<P: Into<PathBuf>>(
421441
directory: P,
422442
port: u16,
423443
num_bufs: usize,
444+
key: String,
424445
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
425446
let addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
426-
let service = Arc::new(Service::new(directory, num_bufs));
447+
let service = Arc::new(Service::new(directory, num_bufs, key));
427448
let make_svc = make_service_fn(move |_conn| {
428449
let s = service.clone();
429450
async {

0 commit comments

Comments
 (0)