Skip to content

Commit 608e333

Browse files
authored
[ENH]: add Rust server config option for CORS allowed origins (#3861)
## Description of changes Adds a config option for CORS allowed origins to the Rust server which matches the behavior of the existing option on the Python side. ## Test plan *How are these changes tested?* Added a new test. ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* Will need to update docs in #3853.
1 parent 69cdf74 commit 608e333

File tree

8 files changed

+83
-13
lines changed

8 files changed

+83
-13
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ sha2 = "0.10.8"
4545
md5 = "0.7.0"
4646
regex = "1.11.1"
4747
pyo3 = { version = "0.23.3", features = ["abi3-py39"] }
48-
tower-http = { version = "0.6.2", features = ["trace"] }
48+
tower-http = { version = "0.6.2", features = ["trace", "cors"] }
4949
bytemuck = "1.21.0"
5050
validator = { version = "0.19", features = ["derive"] }
5151
rust-embed = { version = "8.5.0", features = ["include-exclude", "debug-embed"] }
5252
hnswlib = { version = "0.8.0", git = "https://github.com/chroma-core/hnswlib.git" }
53+
reqwest = { version = "0.12.9" }
54+
random-port = "0.1.1"
5355

5456
chroma-benchmark = { path = "rust/benchmark" }
5557
chroma-blockstore = { path = "rust/blockstore" }

rust/benchmark/Cargo.toml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,7 @@ path = "src/lib.rs"
99
[dependencies]
1010
anyhow = "1.0.93"
1111
async-tempfile = "0.6.0"
12-
async-compression = { version = "0.4.18", features = [
13-
"tokio",
14-
"gzip",
15-
"bzip2",
16-
] }
12+
async-compression = { version = "0.4.18", features = ["tokio", "gzip", "bzip2"] }
1713

1814
bincode = { workspace = true }
1915
clap = { workspace = true }
@@ -29,7 +25,7 @@ tokio = { workspace = true }
2925
uuid = { workspace = true }
3026

3127
dirs = "5.0.1"
32-
reqwest = { version = "0.12.9", features = ["stream"] }
28+
reqwest = { workspace = true, features = ["stream"] }
3329
tokio-stream = { version = "0.1.16", features = ["full"] }
3430
tokio-util = "0.7.12"
3531
bloom = "0.3.2"

rust/frontend/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,7 @@ chroma-sqlite = { workspace = true }
4646
utoipa = { workspace = true }
4747
utoipa-axum = { version = "0.2.0", features = ["debug"] }
4848
utoipa-swagger-ui = { version = "9", features = ["axum"] }
49+
50+
[dev-dependencies]
51+
reqwest = { workspace = true }
52+
random-port = { workspace = true }

rust/frontend/src/config.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,12 @@ pub struct FrontendServerConfig {
100100
pub open_telemetry: Option<OpenTelemetryConfig>,
101101
#[serde(default)]
102102
pub persist_path: Option<String>,
103+
#[serde(default)]
104+
pub cors_allow_origins: Option<Vec<String>>,
103105
}
104106

105-
const DEFAULT_CONFIG_PATH: &str = "./sample_configs/distributed.yaml";
106-
const DEFAULT_SINGLE_NODE_CONFIG_FILENAME: &str = "./sample_configs/single_node.yaml";
107+
const DEFAULT_CONFIG_PATH: &str = "sample_configs/distributed.yaml";
108+
const DEFAULT_SINGLE_NODE_CONFIG_FILENAME: &str = "sample_configs/single_node.yaml";
107109

108110
#[derive(Embed)]
109111
#[folder = "./"]

rust/frontend/src/server.rs

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use std::sync::{
2929
atomic::{AtomicBool, Ordering},
3030
Arc,
3131
};
32+
use tower_http::cors::CorsLayer;
3233
use utoipa::openapi::security::{ApiKey, ApiKeyValue, SecurityScheme};
3334
use utoipa::ToSchema;
3435
use utoipa::{Modify, OpenApi};
@@ -159,6 +160,7 @@ impl FrontendServer {
159160
listen_address,
160161
max_payload_size_bytes,
161162
circuit_breaker,
163+
cors_allow_origins,
162164
..
163165
} = self.config.clone();
164166

@@ -239,7 +241,20 @@ impl FrontendServer {
239241
.merge(docs_router)
240242
.with_state(self)
241243
.layer(DefaultBodyLimit::max(max_payload_size_bytes));
242-
let app = add_tracing_middleware(app);
244+
let mut app = add_tracing_middleware(app);
245+
246+
if let Some(cors_allow_origins) = cors_allow_origins {
247+
let origins = cors_allow_origins
248+
.into_iter()
249+
.map(|origin| origin.parse().unwrap())
250+
.collect::<Vec<_>>();
251+
252+
let cors = CorsLayer::new()
253+
.allow_origin(origins)
254+
.allow_headers(tower_http::cors::Any)
255+
.allow_methods(tower_http::cors::Any);
256+
app = app.layer(cors);
257+
}
243258

244259
// TODO: tracing
245260
let addr = format!("{}:{}", listen_address, port);
@@ -1628,3 +1643,52 @@ impl Modify for ChromaTokenSecurityAddon {
16281643
modifiers(&ChromaTokenSecurityAddon)
16291644
)]
16301645
struct ApiDoc;
1646+
1647+
#[cfg(test)]
1648+
mod tests {
1649+
use crate::{config::FrontendServerConfig, frontend::Frontend, FrontendServer};
1650+
use chroma_config::{registry::Registry, Configurable};
1651+
use chroma_system::System;
1652+
use std::sync::Arc;
1653+
1654+
#[tokio::test]
1655+
async fn test_cors() {
1656+
let registry = Registry::new();
1657+
let system = System::new();
1658+
1659+
let port = random_port::PortPicker::new().pick().unwrap();
1660+
1661+
let mut config = FrontendServerConfig::single_node_default();
1662+
config.port = port;
1663+
config.cors_allow_origins = Some(vec!["http://localhost:3000".to_string()]);
1664+
1665+
let frontend = Frontend::try_from_config(&(config.clone().frontend, system), &registry)
1666+
.await
1667+
.unwrap();
1668+
let app = FrontendServer::new(config, frontend, vec![], Arc::new(()), Arc::new(()));
1669+
tokio::task::spawn(async move {
1670+
app.run().await;
1671+
});
1672+
1673+
let client = reqwest::Client::new();
1674+
let res = client
1675+
.request(
1676+
reqwest::Method::OPTIONS,
1677+
format!("http://localhost:{}/api/v2/heartbeat", port),
1678+
)
1679+
.header("Origin", "http://localhost:3000")
1680+
.send()
1681+
.await
1682+
.unwrap();
1683+
assert_eq!(res.status(), 200);
1684+
1685+
let allow_origin = res.headers().get("Access-Control-Allow-Origin");
1686+
assert_eq!(allow_origin.unwrap(), "http://localhost:3000");
1687+
1688+
let allow_methods = res.headers().get("Access-Control-Allow-Methods");
1689+
assert_eq!(allow_methods.unwrap(), "*");
1690+
1691+
let allow_headers = res.headers().get("Access-Control-Allow-Headers");
1692+
assert_eq!(allow_headers.unwrap(), "*");
1693+
}
1694+
}

rust/load/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@ opentelemetry_sdk = { workspace = true }
2727
chromadb = { git = "https://github.com/rescrv/chromadb-rs", rev = "540c3e225e92ecea05039b73b69adf5875385c0e" }
2828
guacamole = { version = "0.9", default-features = false }
2929
tower-http = { workspace = true }
30-
reqwest = { version = "0.12", features = ["json"] }
30+
reqwest = { workspace = true, features = ["json"] }
3131
siphasher = "1.0.1"

rust/worker/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ fastrace = "0.7"
6868
fastrace-opentelemetry = "0.8"
6969

7070
[dev-dependencies]
71-
random-port = "0.1.1"
71+
random-port = { workspace = true }
7272
serial_test = { workspace = true }
7373
criterion = { workspace = true }
7474
indicatif = { workspace = true }
@@ -99,4 +99,4 @@ harness = false
9999

100100
[[bench]]
101101
name = "spann"
102-
harness = false
102+
harness = false

0 commit comments

Comments
 (0)