Skip to content

Commit 0afaf46

Browse files
Specify endpoint for content generator from the command line
1 parent af8f6c2 commit 0afaf46

File tree

2 files changed

+72
-30
lines changed

2 files changed

+72
-30
lines changed

src/main.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ enum Commands {
3737
#[arg(short, long)]
3838
key: Option<String>,
3939
#[arg(short, long)]
40+
content_endpoint: Option<String>,
41+
#[arg(short, long)]
4042
directory: String,
4143
#[arg(short, long, default_value_t = 8080)]
4244
port: u16,
@@ -106,16 +108,30 @@ fn key_or_env(k: Option<String>) -> String {
106108
result.unwrap()
107109
}
108110

111+
/// Resolves the content endpoint to use.
///
/// An explicitly supplied value `c` takes precedence. When `c` is `None`, the
/// `TERMINUSDB_CONTENT_ENDPOINT` environment variable is read instead.
///
/// Returns `None` when `c` is `None` and reading the environment variable
/// fails; the `std::env::var` error is discarded via `.ok()`.
fn content_endpoint_or_env(c: Option<String>) -> Option<String> {
112+
c.or_else(|| std::env::var("TERMINUSDB_CONTENT_ENDPOINT").ok())
113+
}
114+
109115
#[tokio::main]
110116
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
111117
let args = Args::parse();
112118
match args.command {
113119
Commands::Serve {
114120
key,
121+
content_endpoint,
115122
directory,
116123
port,
117124
size,
118-
} => server::serve(directory, port, size, key_or_env(key)).await?,
125+
} => {
126+
server::serve(
127+
directory,
128+
port,
129+
size,
130+
key_or_env(key),
131+
content_endpoint_or_env(content_endpoint),
132+
)
133+
.await?
134+
}
119135
Commands::Embed { key, string } => {
120136
let v = openai::embeddings_for(&key_or_env(key), &[string]).await?;
121137
eprintln!("{:?}", v);
@@ -162,8 +178,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
162178
)
163179
.await?;
164180
let mut calculated = empty_embedding();
165-
for i in 0..calculated.len() {
166-
calculated[i] = v[0][i] - v[1][i] + v[2][i];
181+
for (i, calculated) in calculated.iter_mut().enumerate() {
182+
*calculated = v[0][i] - v[1][i] + v[2][i];
167183
}
168184
let distance = vecmath::normalized_cosine_distance(&v[3], &calculated);
169185
eprintln!("{}", distance);

src/server.rs

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ pub struct QueryResult {
227227
}
228228

229229
pub struct Service {
230+
content_endpoint: Option<String>,
230231
api_key: String,
231232
path: PathBuf,
232233
vector_store: VectorStore,
@@ -235,14 +236,20 @@ pub struct Service {
235236
indexes: RwLock<HashMap<String, Arc<HnswIndex>>>,
236237
}
237238

239+
/// Reasons why an indexing task could not be started.
///
/// `start_indexing` returns this error, and the request handler turns it into
/// a status-400 response whose body is the error's `to_string()` text.
// NOTE(review): `Error` / `#[error(...)]` are presumably thiserror's derive,
// which would make the attribute text the `Display` message; confirm imports.
#[derive(Debug, Error)]
240+
enum StartIndexError {
241+
/// The service holds no content endpoint (`content_endpoint` is `None`),
/// so there is nowhere to fetch indexing operations from.
#[error("No content endpoint found: specify at server startup or supply indexing data from the command line")]
242+
NoContentEndpoint,
243+
}
244+
238245
/// Consumes the request and buffers its entire body into a single `Bytes`.
///
/// # Panics
///
/// Panics if reading the body fails, because the result of
/// `hyper::body::to_bytes` is unwrapped.
async fn extract_body(req: Request<Body>) -> Bytes {
239246
hyper::body::to_bytes(req.into_body()).await.unwrap()
240247
}
241248

242249
/// Error type with no variants, so no value of it can ever be constructed.
// NOTE(review): looks like a placeholder for future error cases; its usage is
// not visible in this diff hunk, so confirm against the full file.
enum TerminusIndexOperationError {}
243250

244-
const TERMINUSDB_INDEX_ENDPOINT: &str = "http://localhost:6363/api/index";
245-
async fn get_operations_from_terminusdb(
251+
async fn get_operations_from_content_endpoint(
252+
content_endpoint: String,
246253
domain: String,
247254
commit: String,
248255
previous: Option<String>,
@@ -251,7 +258,7 @@ async fn get_operations_from_terminusdb(
251258
if let Some(previous) = previous {
252259
params.push(("previous", previous))
253260
}
254-
let endpoint = format!("{}/{}", TERMINUSDB_INDEX_ENDPOINT, &domain);
261+
let endpoint = format!("{}/{}", content_endpoint, &domain);
255262
let url = reqwest::Url::parse_with_params(&endpoint, &params).unwrap();
256263
let res = reqwest::get(url)
257264
.await
@@ -322,9 +329,15 @@ impl Service {
322329
s
323330
}
324331

325-
fn new<P: Into<PathBuf>>(path: P, num_bufs: usize, key: String) -> Self {
332+
fn new<P: Into<PathBuf>>(
333+
path: P,
334+
num_bufs: usize,
335+
key: String,
336+
content_endpoint: Option<String>,
337+
) -> Self {
326338
let path = path.into();
327339
Service {
340+
content_endpoint,
328341
api_key: key,
329342
path: path.clone(),
330343
vector_store: VectorStore::new(path, num_bufs),
@@ -360,26 +373,33 @@ impl Service {
360373
commit: String,
361374
previous: Option<String>,
362375
task_id: String,
363-
) {
364-
tokio::spawn(async move {
365-
let index_id = create_index_name(&domain, &commit);
366-
if self.test_and_set_pending(index_id.clone()).await {
367-
let opstream = get_operations_from_terminusdb(
368-
domain.clone(),
369-
commit.clone(),
370-
previous.clone(),
371-
)
372-
.await
373-
.unwrap()
374-
.chunks(100);
375-
let (id, hnsw) = self
376-
.process_operation_chunks(opstream, domain, commit, previous, &index_id)
377-
.await;
378-
self.set_index(id, hnsw.into()).await;
379-
self.clear_pending(&index_id).await;
380-
}
381-
self.set_task_status(task_id, TaskStatus::Completed).await;
382-
});
376+
) -> Result<(), StartIndexError> {
377+
let content_endpoint = self.content_endpoint.clone();
378+
if let Some(content_endpoint) = content_endpoint {
379+
tokio::spawn(async move {
380+
let index_id = create_index_name(&domain, &commit);
381+
if self.test_and_set_pending(index_id.clone()).await {
382+
let opstream = get_operations_from_content_endpoint(
383+
content_endpoint.to_string(),
384+
domain.clone(),
385+
commit.clone(),
386+
previous.clone(),
387+
)
388+
.await
389+
.unwrap()
390+
.chunks(100);
391+
let (id, hnsw) = self
392+
.process_operation_chunks(opstream, domain, commit, previous, &index_id)
393+
.await;
394+
self.set_index(id, hnsw.into()).await;
395+
self.clear_pending(&index_id).await;
396+
}
397+
self.set_task_status(task_id, TaskStatus::Completed).await;
398+
});
399+
Ok(())
400+
} else {
401+
Err(StartIndexError::NoContentEndpoint)
402+
}
383403
}
384404

385405
async fn assign_index(
@@ -454,8 +474,13 @@ impl Service {
454474
let task_id = Service::generate_task();
455475
self.set_task_status(task_id.clone(), TaskStatus::Pending)
456476
.await;
457-
self.start_indexing(domain, commit, previous, task_id.clone());
458-
Ok(Response::builder().body(task_id.into()).unwrap())
477+
match self.start_indexing(domain, commit, previous, task_id.clone()) {
478+
Ok(()) => Ok(Response::builder().body(task_id.into()).unwrap()),
479+
Err(e) => Ok(Response::builder()
480+
.status(400)
481+
.body(e.to_string().into())
482+
.unwrap()),
483+
}
459484
}
460485
Ok(ResourceSpec::AssignIndex {
461486
domain,
@@ -598,9 +623,10 @@ pub async fn serve<P: Into<PathBuf>>(
598623
port: u16,
599624
num_bufs: usize,
600625
key: String,
626+
content_endpoint: Option<String>,
601627
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
602628
let addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
603-
let service = Arc::new(Service::new(directory, num_bufs, key));
629+
let service = Arc::new(Service::new(directory, num_bufs, key, content_endpoint));
604630
let make_svc = make_service_fn(move |_conn| {
605631
let s = service.clone();
606632
async {

0 commit comments

Comments
 (0)