@@ -1669,9 +1669,7 @@ pub async fn run(
1669
1669
ApiDoc :: openapi ( )
1670
1670
} ;
1671
1671
1672
- // Create router
1673
- let mut app = Router :: new ( )
1674
- . merge ( SwaggerUi :: new ( "/docs" ) . url ( "/api-doc/openapi.json" , doc) )
1672
+ let mut routes = Router :: new ( )
1675
1673
// Base routes
1676
1674
. route ( "/info" , get ( get_model_info) )
1677
1675
. route ( "/embed" , post ( embed) )
@@ -1686,74 +1684,72 @@ pub async fn run(
1686
1684
. route ( "/embeddings" , post ( openai_embed) )
1687
1685
. route ( "/v1/embeddings" , post ( openai_embed) )
1688
1686
// Vertex compat route
1689
- . route ( "/vertex" , post ( vertex_compatibility) )
1687
+ . route ( "/vertex" , post ( vertex_compatibility) ) ;
1688
+
1689
+ #[ allow( unused_mut) ]
1690
+ let mut public_routes = Router :: new ( )
1690
1691
// Base Health route
1691
1692
. route ( "/health" , get ( health) )
1692
1693
// Inference API health route
1693
1694
. route ( "/" , get ( health) )
1694
1695
// AWS Sagemaker health route
1695
1696
. route ( "/ping" , get ( health) )
1696
1697
// Prometheus metrics route
1697
- . route ( "/metrics" , get ( metrics) )
1698
- // Update payload limit
1699
- . layer ( DefaultBodyLimit :: max ( payload_limit) ) ;
1698
+ . route ( "/metrics" , get ( metrics) ) ;
1700
1699
1701
1700
#[ cfg( feature = "google" ) ]
1702
1701
{
1703
1702
tracing:: info!( "Built with `google` feature" ) ;
1704
1703
1705
1704
if let Ok ( env_predict_route) = std:: env:: var ( "AIP_PREDICT_ROUTE" ) {
1706
1705
tracing:: info!( "Serving Vertex compatible route on {env_predict_route}" ) ;
1707
- app = app . route ( & env_predict_route, post ( vertex_compatibility) ) ;
1706
+ routes = routes . route ( & env_predict_route, post ( vertex_compatibility) ) ;
1708
1707
}
1709
1708
1710
1709
if let Ok ( env_health_route) = std:: env:: var ( "AIP_HEALTH_ROUTE" ) {
1711
1710
tracing:: info!( "Serving Vertex compatible health route on {env_health_route}" ) ;
1712
- app = app . route ( & env_health_route, get ( health) ) ;
1711
+ public_routes = public_routes . route ( & env_health_route, get ( health) ) ;
1713
1712
}
1714
1713
}
1715
1714
#[ cfg( not( feature = "google" ) ) ]
1716
1715
{
1717
1716
// Set default routes
1718
- app = match & info. model_type {
1717
+ routes = match & info. model_type {
1719
1718
ModelType :: Classifier ( _) => {
1720
- app. route ( "/" , post ( predict) )
1719
+ routes
1720
+ . route ( "/" , post ( predict) )
1721
1721
// AWS Sagemaker route
1722
1722
. route ( "/invocations" , post ( predict) )
1723
1723
}
1724
1724
ModelType :: Reranker ( _) => {
1725
- app. route ( "/" , post ( rerank) )
1725
+ routes
1726
+ . route ( "/" , post ( rerank) )
1726
1727
// AWS Sagemaker route
1727
1728
. route ( "/invocations" , post ( rerank) )
1728
1729
}
1729
1730
ModelType :: Embedding ( model) => {
1730
1731
if std:: env:: var ( "TASK" ) . ok ( ) == Some ( "sentence-similarity" . to_string ( ) ) {
1731
- app. route ( "/" , post ( similarity) )
1732
+ routes
1733
+ . route ( "/" , post ( similarity) )
1732
1734
// AWS Sagemaker route
1733
1735
. route ( "/invocations" , post ( similarity) )
1734
1736
} else if model. pooling == "splade" {
1735
- app. route ( "/" , post ( embed_sparse) )
1737
+ routes
1738
+ . route ( "/" , post ( embed_sparse) )
1736
1739
// AWS Sagemaker route
1737
1740
. route ( "/invocations" , post ( embed_sparse) )
1738
1741
} else {
1739
- app. route ( "/" , post ( embed) )
1742
+ routes
1743
+ . route ( "/" , post ( embed) )
1740
1744
// AWS Sagemaker route
1741
1745
. route ( "/invocations" , post ( embed) )
1742
1746
}
1743
1747
}
1744
1748
} ;
1745
1749
}
1746
1750
1747
- app = app
1748
- . layer ( Extension ( infer) )
1749
- . layer ( Extension ( info) )
1750
- . layer ( Extension ( prom_handle. clone ( ) ) )
1751
- . layer ( OtelAxumLayer :: default ( ) )
1752
- . layer ( cors_layer) ;
1753
-
1754
1751
if let Some ( api_key) = api_key {
1755
- let mut prefix = "Bearer " . to_string ( ) ;
1756
- prefix. push_str ( & api_key) ;
1752
+ let prefix = format ! ( "Bearer {}" , api_key) ;
1757
1753
1758
1754
// Leak to allow FnMut
1759
1755
let api_key: & ' static str = prefix. leak ( ) ;
@@ -1770,9 +1766,20 @@ pub async fn run(
1770
1766
}
1771
1767
} ;
1772
1768
1773
- app = app . layer ( axum:: middleware:: from_fn ( auth) ) ;
1769
+ routes = routes . layer ( axum:: middleware:: from_fn ( auth) ) ;
1774
1770
}
1775
1771
1772
+ let app = Router :: new ( )
1773
+ . merge ( SwaggerUi :: new ( "/docs" ) . url ( "/api-doc/openapi.json" , doc) )
1774
+ . merge ( routes)
1775
+ . merge ( public_routes)
1776
+ . layer ( Extension ( infer) )
1777
+ . layer ( Extension ( info) )
1778
+ . layer ( Extension ( prom_handle. clone ( ) ) )
1779
+ . layer ( OtelAxumLayer :: default ( ) )
1780
+ . layer ( DefaultBodyLimit :: max ( payload_limit) )
1781
+ . layer ( cors_layer) ;
1782
+
1776
1783
// Run server
1777
1784
let listener = tokio:: net:: TcpListener :: bind ( & addr)
1778
1785
. await
0 commit comments