Skip to content

Commit ebe078a

Browse files
authored
fix: allow health check w/o auth (#360)
1 parent dbeb3ab commit ebe078a

File tree

1 file changed

+32
-25
lines changed

1 file changed

+32
-25
lines changed

router/src/http/server.rs

Lines changed: 32 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1669,9 +1669,7 @@ pub async fn run(
16691669
ApiDoc::openapi()
16701670
};
16711671

1672-
// Create router
1673-
let mut app = Router::new()
1674-
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
1672+
let mut routes = Router::new()
16751673
// Base routes
16761674
.route("/info", get(get_model_info))
16771675
.route("/embed", post(embed))
@@ -1686,74 +1684,72 @@ pub async fn run(
16861684
.route("/embeddings", post(openai_embed))
16871685
.route("/v1/embeddings", post(openai_embed))
16881686
// Vertex compat route
1689-
.route("/vertex", post(vertex_compatibility))
1687+
.route("/vertex", post(vertex_compatibility));
1688+
1689+
#[allow(unused_mut)]
1690+
let mut public_routes = Router::new()
16901691
// Base Health route
16911692
.route("/health", get(health))
16921693
// Inference API health route
16931694
.route("/", get(health))
16941695
// AWS Sagemaker health route
16951696
.route("/ping", get(health))
16961697
// Prometheus metrics route
1697-
.route("/metrics", get(metrics))
1698-
// Update payload limit
1699-
.layer(DefaultBodyLimit::max(payload_limit));
1698+
.route("/metrics", get(metrics));
17001699

17011700
#[cfg(feature = "google")]
17021701
{
17031702
tracing::info!("Built with `google` feature");
17041703

17051704
if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
17061705
tracing::info!("Serving Vertex compatible route on {env_predict_route}");
1707-
app = app.route(&env_predict_route, post(vertex_compatibility));
1706+
routes = routes.route(&env_predict_route, post(vertex_compatibility));
17081707
}
17091708

17101709
if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
17111710
tracing::info!("Serving Vertex compatible health route on {env_health_route}");
1712-
app = app.route(&env_health_route, get(health));
1711+
public_routes = public_routes.route(&env_health_route, get(health));
17131712
}
17141713
}
17151714
#[cfg(not(feature = "google"))]
17161715
{
17171716
// Set default routes
1718-
app = match &info.model_type {
1717+
routes = match &info.model_type {
17191718
ModelType::Classifier(_) => {
1720-
app.route("/", post(predict))
1719+
routes
1720+
.route("/", post(predict))
17211721
// AWS Sagemaker route
17221722
.route("/invocations", post(predict))
17231723
}
17241724
ModelType::Reranker(_) => {
1725-
app.route("/", post(rerank))
1725+
routes
1726+
.route("/", post(rerank))
17261727
// AWS Sagemaker route
17271728
.route("/invocations", post(rerank))
17281729
}
17291730
ModelType::Embedding(model) => {
17301731
if std::env::var("TASK").ok() == Some("sentence-similarity".to_string()) {
1731-
app.route("/", post(similarity))
1732+
routes
1733+
.route("/", post(similarity))
17321734
// AWS Sagemaker route
17331735
.route("/invocations", post(similarity))
17341736
} else if model.pooling == "splade" {
1735-
app.route("/", post(embed_sparse))
1737+
routes
1738+
.route("/", post(embed_sparse))
17361739
// AWS Sagemaker route
17371740
.route("/invocations", post(embed_sparse))
17381741
} else {
1739-
app.route("/", post(embed))
1742+
routes
1743+
.route("/", post(embed))
17401744
// AWS Sagemaker route
17411745
.route("/invocations", post(embed))
17421746
}
17431747
}
17441748
};
17451749
}
17461750

1747-
app = app
1748-
.layer(Extension(infer))
1749-
.layer(Extension(info))
1750-
.layer(Extension(prom_handle.clone()))
1751-
.layer(OtelAxumLayer::default())
1752-
.layer(cors_layer);
1753-
17541751
if let Some(api_key) = api_key {
1755-
let mut prefix = "Bearer ".to_string();
1756-
prefix.push_str(&api_key);
1752+
let prefix = format!("Bearer {}", api_key);
17571753

17581754
// Leak to allow FnMut
17591755
let api_key: &'static str = prefix.leak();
@@ -1770,9 +1766,20 @@ pub async fn run(
17701766
}
17711767
};
17721768

1773-
app = app.layer(axum::middleware::from_fn(auth));
1769+
routes = routes.layer(axum::middleware::from_fn(auth));
17741770
}
17751771

1772+
let app = Router::new()
1773+
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
1774+
.merge(routes)
1775+
.merge(public_routes)
1776+
.layer(Extension(infer))
1777+
.layer(Extension(info))
1778+
.layer(Extension(prom_handle.clone()))
1779+
.layer(OtelAxumLayer::default())
1780+
.layer(DefaultBodyLimit::max(payload_limit))
1781+
.layer(cors_layer);
1782+
17761783
// Run server
17771784
let listener = tokio::net::TcpListener::bind(&addr)
17781785
.await

0 commit comments

Comments
 (0)